/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

public final class CLALibRexpand {
    public static boolean ALLOW_COMPRESSED_TABLE_SEQ = false;
    protected static final Log LOG = LogFactory.getLog((String)CLALibRexpand.class.getName());

    private CLALibRexpand() {
    }

    public static MatrixBlock rexpand(CompressedMatrixBlock in, MatrixBlock ret, double max, boolean rows, boolean cast, boolean ignore, int k) {
        if (rows) {
            return in.getUncompressed("Rexpand in rows direction (one hot encode)").rexpandOperations(ret, max, rows, cast, ignore, k);
        }
        return CLALibRexpand.rexpandCols(in, max, cast, ignore, k);
    }

    public static MatrixBlock rexpand(int seqHeight, MatrixBlock A) {
        return CLALibRexpand.rexpand(seqHeight, A, -1);
    }

    public static MatrixBlock rexpand(int seqHeight, MatrixBlock A, int nColOut) {
        return CLALibRexpand.rexpand(seqHeight, A, nColOut, 1);
    }

    public static MatrixBlock rexpand(int seqHeight, MatrixBlock A, int nColOut, int k) {
        try {
            int[] map = new int[seqHeight];
            Pair<Integer, Integer> meta = CLALibRexpand.constructInitialMapping(map, A, k, nColOut);
            int maxCol = meta.getKey();
            int nZeros = meta.getValue();
            boolean containsNull = maxCol < 0;
            maxCol = Math.abs(maxCol);
            boolean cutOff = false;
            if (nColOut == -1) {
                nColOut = maxCol;
            } else if (nColOut < maxCol) {
                cutOff = true;
            }
            if (containsNull) {
                CLALibRexpand.correctNulls(map, nColOut);
            }
            if (nColOut == 0) {
                return new MatrixBlock(seqHeight, 0, 0.0);
            }
            return CLALibRexpand.createCompressedReturn(map, nColOut, seqHeight, nZeros, containsNull || cutOff, k);
        }
        catch (Exception e) {
            throw new RuntimeException("Failed table seq operator", e);
        }
    }

    private static MatrixBlock rexpandCols(CompressedMatrixBlock in, double max, boolean cast, boolean ignore, int k) {
        return CLALibRexpand.rexpandCols(in, UtilFunctions.toInt(max), cast, ignore, k);
    }

    private static MatrixBlock rexpandCols(CompressedMatrixBlock in, int max, boolean cast, boolean ignore, int k) {
        LibMatrixReorg.checkRexpand(in, ignore);
        int nRows = in.getNumRows();
        if (in.isEmptyBlock(false)) {
            return new MatrixBlock(nRows, max, true);
        }
        if (in.isOverlapping() || in.getColGroups().size() > 1) {
            return LibMatrixReorg.rexpand(in.getUncompressed("Rexpand (one hot encode)"), new MatrixBlock(), max, false, cast, ignore, k);
        }
        CompressedMatrixBlock retC = new CompressedMatrixBlock(nRows, max);
        AColGroup g = in.getColGroups().get(0).rexpandCols(max, ignore, cast, nRows);
        if (g == null) {
            return new MatrixBlock(nRows, 0, 0L);
        }
        retC.setNumColumns(g.getNumCols());
        retC.allocateColGroup(g);
        retC.recomputeNonZeros();
        return retC;
    }

    private static CompressedMatrixBlock createCompressedReturn(int[] map, int nColOut, int seqHeight, int nNulls, boolean containsNull, int k) throws Exception {
        IColIndex i = ColIndexFactory.create(0, nColOut);
        IDictionary d = IdentityDictionary.create(nColOut, containsNull);
        AMapToData m = MapToFactory.create(seqHeight, map, nColOut + (containsNull ? 1 : 0), k);
        AColGroup g = ColGroupDDC.create(i, d, m, null);
        CompressedMatrixBlock cmb = new CompressedMatrixBlock(seqHeight, nColOut);
        cmb.allocateColGroup(g);
        cmb.setNonZeros(seqHeight - nNulls);
        return cmb;
    }

    private static int correctNulls(int[] map, int nColOut) {
        int nNulls = 0;
        for (int i = 0; i < map.length; ++i) {
            if (map[i] != -1) continue;
            map[i] = nColOut;
            ++nNulls;
        }
        return nNulls;
    }

    private static Pair<Integer, Integer> constructInitialMapping(int[] map, MatrixBlock A, int k, int maxOutCol) {
        MatrixBlock Ac;
        if (A.isEmpty() || A.isInSparseFormat()) {
            throw new DMLRuntimeException("not supported empty or sparse construction of seq table");
        }
        if (A instanceof CompressedMatrixBlock) {
            LOG.warn((Object)"Decompression of right side input to CLALibTable, please implement alternative.");
            Ac = ((CompressedMatrixBlock)A).getUncompressed("rexpand", k);
        } else {
            Ac = A;
        }
        ExecutorService pool = CommonThreadPool.get(k);
        try {
            int blkz = Math.max(map.length / k, 1000);
            ArrayList<Future<Pair>> tasks = new ArrayList<Future<Pair>>();
            for (int i = 0; i < map.length; i += blkz) {
                int start = i;
                int end = Math.min(i + blkz, map.length);
                tasks.add(pool.submit(() -> CLALibRexpand.partialMapping(map, Ac, start, end, maxOutCol)));
            }
            int maxCol = 0;
            int zeros = 0;
            for (Future future : tasks) {
                int tmpMaxCol = (Integer)((Pair)future.get()).getKey();
                int tmpZeros = (Integer)((Pair)future.get()).getValue();
                if (Math.abs(tmpMaxCol) > Math.abs(maxCol)) {
                    maxCol = tmpMaxCol;
                }
                zeros += tmpZeros;
            }
            Pair<Integer, Integer> pair = new Pair<Integer, Integer>(maxCol, zeros);
            return pair;
        }
        catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }
    }

    private static Pair<Integer, Integer> partialMapping(int[] map, MatrixBlock A, int start, int end, int maxOutCol) {
        int maxCol = 0;
        int zeros = 0;
        double[] aVals = A.getDenseBlockValues();
        for (int i = start; i < end; ++i) {
            double v2 = aVals[i];
            int colUnsafe = UtilFunctions.toInt(v2);
            if (!Double.isNaN(v2) && colUnsafe < 0) {
                throw new DMLRuntimeException("Erroneous input while computing the contingency table (value <= zero): " + v2);
            }
            int invalid = Double.isNaN(v2) || maxOutCol != -1 && colUnsafe > maxOutCol ? 1 : 0;
            int colSafe = maxOutCol * invalid + (colUnsafe - 1) * (1 - invalid);
            zeros += invalid;
            maxCol = Math.max(colUnsafe, maxCol);
            map[i] = colSafe;
        }
        if (maxOutCol == -1 && zeros > 0) {
            maxCol *= -1;
        }
        return new Pair<Integer, Integer>(maxCol, zeros);
    }

    public static boolean compressedTableSeq() {
        return ALLOW_COMPRESSED_TABLE_SEQ || ConfigurationManager.isCompressionEnabled();
    }
}

