/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data.sketch.countdistinct;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.OptionalInt;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
import org.apache.sysds.runtime.matrix.data.sketch.countdistinct.BitMapValueCombiner;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Count-distinct sketch that records every distinct cell value losslessly.
 *
 * <p>Each cell value is converted to its IEEE-754 bit pattern
 * ({@link Double#doubleToLongBits(double)}). The upper 12 bits (sign and exponent) form a
 * {@code short} key. The lower 52 bits (mantissa) are collected in a per-key set.
 *
 * <p>The correction block of the resulting {@link CorrMatrixBlock} has one row per key, laid out
 * as {@code [key, count, fraction_0, ..., fraction_(count-1)]} and zero-padded to the widest row.
 * Because the 12/52 split covers all 64 bits and a 52-bit fraction is exactly representable as a
 * double, the count read back from the sketch is exact. Two exceptions are intentional:
 * {@code -0.0} is folded into {@code 0.0}, and all NaNs collapse into one value.
 */
public class CountDistinctFunctionSketch extends CountDistinctSketch {
    /** Number of low-order bits (IEEE-754 mantissa) stored as the per-key fraction. */
    private static final int FRACTION_BITS = 52;
    /** Number of high-order bits (sign + IEEE-754 exponent) used as the row key. */
    private static final int KEY_BITS = 12;
    /** Columns preceding the fractions in each serialized row: key and count. */
    private static final int HEADER_COLUMNS = 2;

    public CountDistinctFunctionSketch(Operator op) {
        super(op);
    }

    /**
     * Not implemented for this sketch type.
     *
     * @param blkIn ignored
     * @return always {@code null}
     */
    @Override
    public MatrixBlock getValue(MatrixBlock blkIn) {
        return null;
    }

    /**
     * Reads the distinct count out of a sketch.
     *
     * @param blkIn sketch whose correction block uses the layout described on this class
     * @return a 1x1 block holding the sum of the per-key counts (column 1 of every row)
     */
    @Override
    public MatrixBlock getValueFromSketch(CorrMatrixBlock blkIn) {
        MatrixBlock blkInCorr = blkIn.getCorrection();
        long res = 0L;
        for (int i = 0; i < blkInCorr.getNumRows(); ++i) {
            res += (long) blkInCorr.getValue(i, 1);
        }
        MatrixBlock blkOut = new MatrixBlock(1, 1, false);
        blkOut.setValue(0, 0, res);
        return blkOut;
    }

    /**
     * Builds a sketch of all distinct values in {@code blkIn}.
     *
     * @param blkIn input block; every cell (including implicit zeros) contributes its value
     * @return a sketch whose value is {@code blkIn} and whose correction block encodes the
     *     distinct values
     */
    @Override
    public CorrMatrixBlock create(MatrixBlock blkIn) {
        int R = blkIn.getNumRows();
        int C = blkIn.getNumColumns();
        if (R == 0 || C == 0) {
            // No cells at all: emit a single placeholder row (key 0, count 0) so the sketch is
            // never row-less and union() can still merge it.
            return new CorrMatrixBlock(blkIn, new MatrixBlock(1, HEADER_COLUMNS, false));
        }

        Map<Short, Set<Long>> bitMap = new HashMap<>();
        if (blkIn.isEmpty()) {
            // Every cell is zero, so the only distinct value is 0.0 (not "no values").
            addCellValue(bitMap, 0.0);
        } else {
            for (int i = 0; i < R; ++i) {
                for (int j = 0; j < C; ++j) {
                    addCellValue(bitMap, blkIn.getValue(i, j));
                }
            }
        }
        // Size the block to the widest row actually needed, not a fixed large width.
        MatrixBlock blkOutCorr = this.serialize(bitMap, maxFractionCount(bitMap));
        return new CorrMatrixBlock(blkIn, blkOutCorr);
    }

    /**
     * Splits one cell value into (key, fraction) and records it in {@code bitMap}.
     *
     * @param bitMap key -> fractions map being built
     * @param v cell value
     */
    private void addCellValue(Map<Short, Set<Long>> bitMap, double v) {
        // Use the raw bit pattern: casting to long would truncate 0.1, 0.2, ... to the same value.
        // -0.0 == 0.0 numerically, so it is normalized to avoid counting zero twice.
        long bits = Double.doubleToLongBits(v == 0.0 ? 0.0 : v);
        short key = (short) this.extractRightKBitsFromIndex(bits, FRACTION_BITS, KEY_BITS);
        long fraction = this.extractRightKBitsFromIndex(bits, 0, FRACTION_BITS);
        bitMap.computeIfAbsent(key, k -> new HashSet<>()).add(fraction);
    }

    /**
     * Returns the size of the largest fraction set, or 0 for an empty map.
     *
     * @param bitMap key -> fractions map
     * @return maximum set size
     */
    private static int maxFractionCount(Map<Short, Set<Long>> bitMap) {
        return bitMap.values().stream().mapToInt(Set::size).max().orElse(0);
    }

    /**
     * Extracts {@code k} consecutive bits of {@code n}, starting at bit {@code startingIndex}
     * (0 = least significant).
     *
     * @param n source bits
     * @param startingIndex position of the lowest extracted bit
     * @param k number of bits to extract; must be in [0, 63]
     * @return the extracted bits, right-aligned
     */
    private long extractRightKBitsFromIndex(long n, int startingIndex, int k) {
        // 1L, not 1: an int shift masks the distance to 5 bits, so 1 << 52 would be 1 << 20.
        long kMask = (1L << k) - 1L;
        return (n >>> startingIndex) & kMask;
    }

    /**
     * Encodes a key -> fractions map as a matrix.
     *
     * @param bitMap key -> fractions map
     * @param maxWidth number of fraction columns to allocate; must be at least the largest set size
     * @return a block with one row {@code [key, count, fractions...]} per map entry
     */
    private MatrixBlock serialize(Map<Short, Set<Long>> bitMap, int maxWidth) {
        MatrixBlock blkOut = new MatrixBlock(bitMap.size(), maxWidth + HEADER_COLUMNS, false);
        int i = 0;
        for (Map.Entry<Short, Set<Long>> entry : bitMap.entrySet()) {
            short key = entry.getKey();
            Set<Long> fractions = entry.getValue();
            blkOut.setValue(i, 0, key);
            blkOut.setValue(i, 1, fractions.size());
            int j = HEADER_COLUMNS;
            for (long fraction : fractions) {
                // Fractions are < 2^52 and therefore exactly representable as double.
                blkOut.setValue(i, j++, fraction);
            }
            ++i;
        }
        return blkOut;
    }

    /**
     * Decodes a matrix produced by {@link #serialize(Map, int)}. Rows sharing a key are merged.
     *
     * @param blkIn serialized sketch correction block
     * @return key -> fractions map
     * @throws IllegalArgumentException if a row declares more fractions than it has columns
     */
    private Map<Short, Set<Long>> deserialize(MatrixBlock blkIn) {
        int R = blkIn.getNumRows();
        int width = blkIn.getNumColumns();
        Map<Short, Set<Long>> bitMap = new HashMap<>();
        for (int i = 0; i < R; ++i) {
            short key = (short) blkIn.getValue(i, 0);
            Set<Long> fractions = bitMap.computeIfAbsent(key, k -> new HashSet<>());
            int count = (int) blkIn.getValue(i, 1);
            if (count < 0 || count + HEADER_COLUMNS > width) {
                throw new IllegalArgumentException("Corrupt sketch: row " + i + " declares "
                    + count + " values but has only " + (width - HEADER_COLUMNS) + " value columns");
            }
            for (int j = 0; j < count; ++j) {
                fractions.add((long) blkIn.getValue(i, j + HEADER_COLUMNS));
            }
        }
        return bitMap;
    }

    /**
     * Merges two sketches; the distinct values of the result are the union of both inputs.
     *
     * @param arg0 first sketch; its value block is carried over to the result
     * @param arg1 second sketch
     * @return merged sketch
     * @throws IllegalArgumentException if both correction blocks contain no rows
     */
    @Override
    public CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
        Map<Short, Set<Long>> bitMap0 = this.deserialize(arg0.getCorrection());
        Map<Short, Set<Long>> bitMap1 = this.deserialize(arg1.getCorrection());
        Map<Short, Set<Long>> bitMapOut = Stream.concat(bitMap0.entrySet().stream(), bitMap1.entrySet().stream())
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, new BitMapValueCombiner()));
        OptionalInt maxWidthOpt = bitMapOut.values().stream().mapToInt(Set::size).max();
        if (maxWidthOpt.isEmpty()) {
            throw new IllegalArgumentException("Corrupt sketch: metadata is invalid");
        }
        MatrixBlock blkOutCorr = this.serialize(bitMapOut, maxWidthOpt.getAsInt());
        return new CorrMatrixBlock(arg0.getValue(), blkOutCorr);
    }

    /**
     * Not implemented for this sketch type.
     *
     * @param arg0 ignored
     * @param arg1 ignored
     * @return always {@code null}
     */
    @Override
    public CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
        return null;
    }
}

