/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.lineage;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import jcuda.Pointer;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheEntry;
import org.apache.sysds.runtime.lineage.LineageCacheEviction;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/**
 * Bookkeeping for lineage-cache entries whose values are held in GPU pointers.
 *
 * <p>Each cached pointer is either <em>live</em> (it has a positive reference count in
 * {@code livePointers}) or <em>free</em> (it is queued in {@code freeQueues} under its allocated
 * size). Free entries can be recycled via {@link #pollFirstFreeEntry(long)} and
 * {@link #pollFistFreeNotExact(long)}, or released via {@link #removeAllEntries(double)}.
 *
 * <p>Invariant maintained by this class: every queue stored in {@code freeQueues} is non-empty.
 *
 * <p>The maps below are plain {@link HashMap}s with no internal synchronization.
 */
public class LineageGPUCacheEviction {
    private static GPUContext _gpuContext = null;
    public static ExecutorService gpuEvictionThread = null;
    // Free (live count 0) cached entries, keyed by the pointer's allocated size as reported by
    // the GPU memory manager; each queue is ordered by LineageCacheConfig.LineageGPUCacheComparator.
    private static final HashMap<Long, TreeSet<LineageCacheEntry>> freeQueues = new HashMap<>();
    // Live reference count per cached pointer; a pointer that is absent here is not live.
    private static final HashMap<Pointer, Integer> livePointers = new HashMap<>();
    // Reverse lookup from a cached GPU pointer to its lineage cache entry.
    private static final HashMap<Pointer, LineageCacheEntry> GPUCacheEntries = new HashMap<>();

    /** Clears all GPU cache bookkeeping and drops the reference to the eviction executor. */
    protected static void resetEviction() {
        gpuEvictionThread = null;
        freeQueues.clear();
        livePointers.clear();
        GPUCacheEntries.clear();
    }

    /**
     * Increments the live count of a cached pointer. On a 0 -&gt; 1 transition the pointer's entry
     * is taken out of its free queue, so it can no longer be recycled.
     *
     * @param ptr cached GPU pointer
     */
    protected static void incrementLiveCount(Pointer ptr) {
        if (livePointers.merge(ptr, 1, Integer::sum) != 1) {
            return; // was already live, so it is not in any free queue
        }
        long size = getPointerSize(ptr);
        TreeSet<LineageCacheEntry> freeList = freeQueues.get(size);
        LineageCacheEntry entry = GPUCacheEntries.get(ptr);
        // Guard against a missing queue/entry: TreeSet.remove(null) would hit the comparator.
        if (freeList == null || entry == null) {
            return;
        }
        freeList.remove(entry);
        if (freeList.isEmpty()) {
            freeQueues.remove(size); // keep the "no empty queues" invariant
        }
    }

    /**
     * Decrements the live count of a cached pointer. When it drops to zero, the pointer's cache
     * entry is appended to the free queue for its allocated size.
     *
     * @param ptr cached GPU pointer
     * @throws DMLRuntimeException if {@code ptr} is not currently live
     */
    public static void decrementLiveCount(Pointer ptr) {
        Integer count = livePointers.get(ptr);
        if (count == null) {
            throw new DMLRuntimeException("Decrementing live count of a non-live GPU pointer");
        }
        if (count > 1) {
            livePointers.put(ptr, count - 1);
            return;
        }
        livePointers.remove(ptr);
        LineageCacheEntry entry = GPUCacheEntries.get(ptr);
        if (entry == null) {
            return; // no cache entry to make reusable
        }
        long size = getPointerSize(ptr);
        freeQueues.computeIfAbsent(size,
            k -> new TreeSet<LineageCacheEntry>(LineageCacheConfig.LineageGPUCacheComparator)).add(entry);
    }

    /**
     * @param ptr GPU pointer
     * @return true if {@code ptr} has a positive live count
     */
    public static boolean probeLiveCachedPointers(Pointer ptr) {
        return livePointers.containsKey(ptr);
    }

    /**
     * Registers a newly cached GPU entry as live, with a count of 1. Null-valued entries are
     * ignored.
     *
     * @param entry lineage cache entry holding a GPU pointer
     * @throws DMLRuntimeException if the entry holds a scalar
     */
    protected static void addEntry(LineageCacheEntry entry) {
        if (entry.isNullVal()) {
            return;
        }
        if (entry.isScalarValue()) {
            throw new DMLRuntimeException("Scalars are never stored in GPU. Lineage: " + entry._key);
        }
        entry.initiateScoreGPU(LineageCacheEviction._removelist);
        livePointers.put(entry.getGPUPointer(), 1);
        GPUCacheEntries.put(entry.getGPUPointer(), entry);
    }

    /** Refreshes the timestamp of an entry whose status is GPUCACHED; other entries are ignored. */
    protected static void maintainOrder(LineageCacheEntry entry) {
        if (entry.getCacheStatus() == LineageCacheConfig.LineageCacheStatus.GPUCACHED) {
            entry.updateTimestamp();
        }
    }

    /**
     * Removes one entry from the lineage cache map and increments its key's counter in
     * {@code LineageCacheEviction._removelist}.
     */
    protected static void removeSingleEntry(Map<LineageItem, LineageCacheEntry> cache, LineageCacheEntry e) {
        cache.remove(e._key);
        LineageCacheEviction._removelist.merge(e._key, 1, Integer::sum);
    }

    // Removes e from the lineage cache. If e has an _origItem, it instead removes every entry on
    // the _nextEntry chain that starts at cache.get(e._origItem).
    private static void removeEntry(LineageCacheEntry e) {
        Map<LineageItem, LineageCacheEntry> cache = LineageCache.getLineageCache();
        if (e._origItem == null) {
            removeSingleEntry(cache, e);
            return;
        }
        for (LineageCacheEntry tmp = cache.get(e._origItem); tmp != null; tmp = tmp._nextEntry) {
            removeSingleEntry(cache, tmp);
        }
    }

    /**
     * Frees cached free pointers on the GPU, queue by queue.
     *
     * <p>From each free queue of n entries this evicts up to {@code floor(n * evictFrac) + 1}
     * entries (always at least one). Every polled entry is freed, so no pointer is lost.
     *
     * @param evictFrac fraction of each free queue to evict
     */
    public static void removeAllEntries(double evictFrac) {
        // Snapshot the keys: pollFirstFreeEntry removes exhausted queues from freeQueues.
        ArrayList<Long> sizes = new ArrayList<Long>(freeQueues.keySet());
        for (Long size : sizes) {
            TreeSet<LineageCacheEntry> freeList = freeQueues.get(size);
            if (freeList == null) {
                continue;
            }
            int evictLim = (int) ((double) freeList.size() * evictFrac);
            // NOTE(review): "+1" keeps the historical eviction volume; confirm it is intended.
            int toEvict = Math.max(evictLim, 0) + 1;
            for (int evicted = 0; evicted < toEvict; evicted++) {
                // Poll only when the entry will be freed; a polled entry is already detached
                // from every bookkeeping structure and would leak if left unfreed.
                LineageCacheEntry le = pollFirstFreeEntry(size);
                if (le == null) {
                    break;
                }
                _gpuContext.getMemoryManager().guardedCudaFree(le.getGPUPointer());
                if (DMLScript.STATISTICS) {
                    LineageCacheStatistics.incrementGpuDel();
                }
            }
        }
    }

    /** Sets the GPU context used to query pointer sizes, free pointers and copy data to host. */
    public static void setGPUContext(GPUContext gpuCtx) {
        _gpuContext = gpuCtx;
    }

    /** @return true if no free (recyclable) cached pointer is queued */
    public static boolean isGPUCacheFreeQEmpty() {
        return freeQueues.isEmpty();
    }

    /**
     * Takes the first free entry (in comparator order) of the given size. The entry is removed
     * from the free queue, from the lineage cache, and from the pointer-to-entry map; its pointer
     * is not freed.
     *
     * @param size allocated pointer size to look up
     * @return the polled entry, or null if no free entry of that size exists
     * @throws DMLRuntimeException if the polled entry's pointer is still live
     */
    public static LineageCacheEntry pollFirstFreeEntry(long size) {
        TreeSet<LineageCacheEntry> freeList = freeQueues.get(size);
        if (freeList == null) {
            return null;
        }
        if (freeList.isEmpty()) {
            freeQueues.remove(size);
            return null;
        }
        LineageCacheEntry e = freeList.pollFirst();
        if (freeList.isEmpty()) {
            freeQueues.remove(size); // keep the "no empty queues" invariant
        }
        if (probeLiveCachedPointers(e.getGPUPointer())) {
            throw new DMLRuntimeException("Recycling live pointer: " + e._key);
        }
        removeEntry(e);
        GPUCacheEntries.remove(e.getGPUPointer());
        return e;
    }

    /**
     * Polls a free entry from the smallest queue whose size is at least {@code size}. If
     * {@code size} exceeds every queued size, polls from the largest queue instead.
     *
     * @param size requested size
     * @return a polled entry, or null if nothing is queued
     */
    public static LineageCacheEntry pollFistFreeNotExact(long size) {
        if (freeQueues.isEmpty()) {
            return null;
        }
        ArrayList<Long> sortedSizes = new ArrayList<Long>(freeQueues.keySet());
        Collections.sort(sortedSizes);
        long maxSize = sortedSizes.get(sortedSizes.size() - 1);
        if (size > maxSize) {
            return pollFirstFreeEntry(maxSize);
        }
        for (long fSize : sortedSizes) {
            if (fSize >= size) {
                return pollFirstFreeEntry(fSize);
            }
        }
        return null; // unreachable: size <= maxSize guarantees a match above
    }

    /** @return number of free cached entries across all queues */
    public static int numPointersCached() {
        return freeQueues.values().stream().mapToInt(TreeSet::size).sum();
    }

    /**
     * @return sum of (queue size key * queue length) over all free queues, in the units of
     *     {@link #getPointerSize(Pointer)} (presumably bytes; verify against the memory manager)
     */
    public static long totalMemoryCached() {
        long totFree = 0L;
        for (Map.Entry<Long, TreeSet<LineageCacheEntry>> entry : freeQueues.entrySet()) {
            totFree += entry.getKey() * (long) entry.getValue().size();
        }
        return totFree;
    }

    /** @return allocated size of {@code ptr} as reported by the GPU memory manager */
    protected static long getPointerSize(Pointer ptr) {
        return _gpuContext.getMemoryManager().getSizeAllocatedGPUPointer(ptr);
    }

    /** @return a new set holding the GPU pointers of all free cached entries */
    public static Set<Pointer> getAllCachedPointers() {
        HashSet<Pointer> cachedPointers = new HashSet<Pointer>();
        for (TreeSet<LineageCacheEntry> queue : freeQueues.values()) {
            for (LineageCacheEntry e : queue) {
                cachedPointers.add(e.getGPUPointer());
            }
        }
        return cachedPointers;
    }

    /**
     * Copies a GPU entry's dense data into a host MatrixBlock. The measured copy time is used to
     * update the D2H bandwidth estimate. The block is then stored in the entry and the entry is
     * registered with the host lineage cache, making space first if it is above threshold.
     *
     * @param entry GPU-resident lineage cache entry
     * @return the entry's GPU pointer, which this method does not free
     */
    public static Pointer copyToHostCache(LineageCacheEntry entry) {
        long t0 = System.nanoTime();
        MatrixBlock mb = pointerToMatrixBlock(entry);
        long t1 = System.nanoTime();
        adjustD2HTransferSpeed(entry.getSize(), (double) (t1 - t0) / 1.0E9);
        Pointer ptr = entry.getGPUPointer();
        long size = mb.getInMemorySize();
        Map<LineageItem, LineageCacheEntry> cache = LineageCache.getLineageCache();
        // A single (reentrant) lock on the lineage cache covers make-space, size update and insert.
        synchronized (cache) {
            if (!LineageCacheEviction.isBelowThreshold(size)) {
                LineageCacheEviction.makeSpace(cache, size);
            }
            LineageCacheEviction.updateSize(size, true);
            entry.setValue(mb);
            LineageCacheEviction.addEntry(entry);
        }
        return ptr;
    }

    // Updates LineageCacheConfig.D2HCOPYBANDWIDTH (MB/s) with an exponentially smoothed sample.
    // Samples with no size or no elapsed time, and NaN/above-max speeds, are discarded so that
    // they cannot poison the estimate.
    private static void adjustD2HTransferSpeed(double sizeByte, double copyTime) {
        if (sizeByte <= 0 || copyTime <= 0) {
            return;
        }
        double sizeMB = sizeByte / 1048576.0;
        double newTSpeed = sizeMB / copyTime;
        if (!(newTSpeed <= LineageCacheConfig.D2HMAXBANDWIDTH)) {
            return; // rejects NaN as well as speeds above the configured maximum
        }
        double smFactor = 0.5;
        LineageCacheConfig.D2HCOPYBANDWIDTH = smFactor * newTSpeed + (1.0 - smFactor) * LineageCacheConfig.D2HCOPYBANDWIDTH;
    }

    // Copies a dense GPU pointer into a newly allocated dense host MatrixBlock sized from the
    // entry's data characteristics, then recomputes non-zeros.
    private static MatrixBlock pointerToMatrixBlock(LineageCacheEntry le) {
        DataCharacteristics dc = le.getDataCharacteristics();
        if (!le.isDensePointer()) {
            throw new DMLRuntimeException("Sparse pointers should not be cached in GPU. Lineage: " + le._key);
        }
        MatrixBlock ret = new MatrixBlock(GPUObject.toIntExact(dc.getRows()), GPUObject.toIntExact(dc.getCols()), false);
        ret.allocateDenseBlock();
        LibMatrixCUDA.cudaSupportFunctions.deviceToHost(_gpuContext, le.getGPUPointer(), ret.getDenseBlockValues(), null, true);
        ret.recomputeNonZeros();
        return ret;
    }

    /**
     * Drops the pointer-to-entry mapping for {@code ptr}, and optionally removes the entry's key
     * from the lineage cache. Live counts and free queues are not touched here.
     */
    public static void removeFromDeviceCache(LineageCacheEntry entry, Pointer ptr, boolean removeFromCache) {
        if (removeFromCache) {
            LineageCache.removeEntry(entry._key);
        }
        GPUCacheEntries.remove(ptr);
    }
}

