/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.estim;

import java.io.Serializable;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.commons.lang.ArrayUtils;
import org.apache.directory.api.util.exception.NotImplementedException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.estim.MMNode;
import org.apache.sysml.hops.estim.SparsityEstimator;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.DenseBlock;
import org.apache.sysml.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;

public class EstimatorMatrixHistogram
extends SparsityEstimator {
    /** Default value of the {@code useExcepts} flag (enabled). */
    private static final boolean DEFAULT_USE_EXCEPTS = true;
    /** Passed to every {@code MatrixHistogram} that this estimator builds from a matrix block. */
    private final boolean _useExcepts;

    /**
     * Creates an estimator that uses the default setting of the
     * {@code useExcepts} flag, {@link #DEFAULT_USE_EXCEPTS}.
     */
    public EstimatorMatrixHistogram() {
        this(DEFAULT_USE_EXCEPTS);
    }

    /**
     * Creates an estimator with an explicit {@code useExcepts} setting.
     *
     * @param useExcepts flag handed to each {@code MatrixHistogram} built from
     *                   a matrix block; when set, the histogram additionally
     *                   computes its {@code rNnz1e}/{@code cNnz1e} counts
     */
    public EstimatorMatrixHistogram(boolean useExcepts) {
        _useExcepts = useExcepts;
    }

    /**
     * Estimates the output of the operation at {@code root} and stores the
     * derived histogram on the node as its synopsis.
     *
     * <p>Non-leaf children are estimated recursively first and their stored
     * synopsis is reused; leaf children get a fresh histogram built from their
     * data. Note that the output histogram can only be derived for
     * MM, MULT, PLUS, RBIND and CBIND; any other operation ends in a
     * {@code NotImplementedException}.
     *
     * @param root node whose operation is estimated
     * @return the value returned by {@code root.setMatrixCharacteristics(...)}
     *         for the estimated rows, columns and non-zeros
     */
    @Override
    public MatrixCharacteristics estim(MMNode root) {
        // Recurse into inner nodes so that their synopsis is available below.
        if (!root.getLeft().isLeaf()) {
            estim(root.getLeft());
        }
        if (!root.getRight().isLeaf()) {
            estim(root.getRight());
        }

        MatrixHistogram h1;
        if (root.getLeft().isLeaf()) {
            h1 = new MatrixHistogram(root.getLeft().getData(), _useExcepts);
        } else {
            h1 = (MatrixHistogram) root.getLeft().getSynopsis();
        }

        MatrixHistogram h2;
        if (root.getRight().isLeaf()) {
            h2 = new MatrixHistogram(root.getRight().getData(), _useExcepts);
        } else {
            h2 = (MatrixHistogram) root.getRight().getSynopsis();
        }

        double sparsity = estimIntern(h1, h2, root.getOp());
        MatrixHistogram out = MatrixHistogram.deriveOutputHistogram(h1, h2, sparsity, root.getOp());
        root.setSynopsis(out);

        MatrixCharacteristics mc =
            new MatrixCharacteristics(out.getRows(), out.getCols(), out.getNonZeros());
        return root.setMatrixCharacteristics(mc);
    }

    /**
     * Estimates the output sparsity of the matrix product of {@code m1} and
     * {@code m2} by delegating to the three-argument overload with
     * {@code OpCode.MM}.
     *
     * @param m1 left input
     * @param m2 right input
     * @return estimated output sparsity
     */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2) {
        final SparsityEstimator.OpCode matMult = SparsityEstimator.OpCode.MM;
        return estim(m1, m2, matMult);
    }

    /**
     * Estimates the output sparsity of a binary operation.
     *
     * <p>Operations for which {@code isExactMetadataOp(op)} holds are answered
     * from the inputs' matrix characteristics alone; all others are estimated
     * from histograms of the two inputs. When both arguments are the same
     * object, a single histogram is built and used for both sides.
     *
     * @param m1 left input
     * @param m2 right input
     * @param op operation to estimate
     * @return estimated output sparsity
     */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2, SparsityEstimator.OpCode op) {
        if (!isExactMetadataOp(op)) {
            MatrixHistogram left = new MatrixHistogram(m1, _useExcepts);
            MatrixHistogram right = left;
            if (m1 != m2) {
                right = new MatrixHistogram(m2, _useExcepts);
            }
            return estimIntern(left, right, op);
        }
        return estimExactMetaData(m1.getMatrixCharacteristics(), m2.getMatrixCharacteristics(), op)
            .getSparsity();
    }

    /**
     * Estimates the output sparsity of a unary operation.
     *
     * <p>Operations for which {@code isExactMetadataOp(op)} holds are answered
     * from the input's matrix characteristics; all others are estimated from a
     * histogram of {@code m1}, with {@code null} passed as the second histogram.
     *
     * @param m1 input matrix
     * @param op operation to estimate
     * @return estimated output sparsity
     */
    @Override
    public double estim(MatrixBlock m1, SparsityEstimator.OpCode op) {
        if (!isExactMetadataOp(op)) {
            MatrixHistogram hist = new MatrixHistogram(m1, _useExcepts);
            return estimIntern(hist, null, op);
        }
        return estimExactMetaData(m1.getMatrixCharacteristics(), null, op).getSparsity();
    }

    /**
     * Estimates the output sparsity of {@code op} from the input histograms.
     *
     * <p>For MULT and PLUS the expected number of overlapping non-zeros in row
     * {@code i} is {@code h1.rNnz[i] * h2.rNnz[i] * scale / N1 / N2}, where
     * {@code scale} is the dot product of the two column-count vectors and
     * {@code N1}, {@code N2} are the inputs' total non-zeros. That product is
     * formed in double arithmetic so it cannot wrap around for large inputs.
     * If either input has no non-zeros there is no overlap: MULT yields
     * {@code 0} and PLUS yields the combined non-zeros divided by the cell
     * count (instead of dividing by zero and returning NaN).
     *
     * @param h1 histogram of the first (or only) input
     * @param h2 histogram of the second input; callers pass {@code null} for
     *           unary operations, and the binary cases dereference it
     * @param op operation to estimate
     * @return estimated output sparsity
     * @throws NotImplementedException if {@code op} is not handled here
     */
    private double estimIntern(MatrixHistogram h1, MatrixHistogram h2, SparsityEstimator.OpCode op) {
        double msize = (double)h1.getRows() * (double)h1.getCols();
        switch (op) {
            case MM: {
                return this.estimInternMM(h1, h2);
            }
            case MULT: {
                double N1 = h1.getNonZeros();
                double N2 = h2.getNonZeros();
                if (N1 == 0.0 || N2 == 0.0) {
                    // an all-zero operand cannot produce any non-zero product
                    return 0.0;
                }
                long scale = IntStream.range(0, h1.getCols()).mapToLong(j -> (long)h1.cNnz[j] * (long)h2.cNnz[j]).sum();
                return IntStream.range(0, h1.getRows())
                    .mapToDouble(i -> (double)h1.rNnz[i] * (double)h2.rNnz[i] * (double)scale / N1 / N2)
                    .sum() / msize;
            }
            case PLUS: {
                double N1 = h1.getNonZeros();
                double N2 = h2.getNonZeros();
                if (N1 == 0.0 || N2 == 0.0) {
                    // nothing can overlap, so the output holds the non-zeros of the other operand
                    return (N1 + N2) / msize;
                }
                long scale = IntStream.range(0, h1.getCols()).mapToLong(j -> (long)h1.cNnz[j] * (long)h2.cNnz[j]).sum();
                return IntStream.range(0, h1.getRows())
                    .mapToDouble(i -> (double)((long)h1.rNnz[i] + (long)h2.rNnz[i])
                        - (double)h1.rNnz[i] * (double)h2.rNnz[i] * (double)scale / N1 / N2)
                    .sum() / msize;
            }
            case EQZERO: {
                return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), (long)h1.getRows() * (long)h1.getCols() - h1.getNonZeros());
            }
            case DIAG: {
                // column vector -> square matrix with the same non-zeros;
                // otherwise -> column vector with at most one non-zero per row
                return h1.getCols() == 1 ? OptimizerUtils.getSparsity(h1.getRows(), h1.getRows(), h1.getNonZeros()) : OptimizerUtils.getSparsity(h1.getRows(), 1L, Math.min((long)h1.getRows(), h1.getNonZeros()));
            }
            case CBIND: {
                return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols() + h2.getCols(), h1.getNonZeros() + h2.getNonZeros());
            }
            case RBIND: {
                return OptimizerUtils.getSparsity(h1.getRows() + h2.getRows(), h1.getCols(), h1.getNonZeros() + h2.getNonZeros());
            }
            case NEQZERO:
            case TRANS:
            case RESHAPE: {
                // these are estimated as keeping the input's sparsity
                return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros());
            }
        }
        throw new NotImplementedException();
    }

    /**
     * Estimates the output sparsity of the matrix product described by
     * {@code h1} (left) and {@code h2} (right).
     *
     * <p>Three cases are distinguished:
     * <ol>
     *   <li>every row of {@code h1} or every column of {@code h2} holds at most
     *       one non-zero: the dot product of {@code h1.cNnz} and
     *       {@code h2.rNnz} is used directly;</li>
     *   <li>both histograms carry their {@code *1e} counts: the part covered by
     *       those counts is added up directly and the remainder is combined
     *       probabilistically over a reduced output area;</li>
     *   <li>otherwise: all per-index contributions are combined
     *       probabilistically over the full output area.</li>
     * </ol>
     * The result is then capped by the product of non-empty rows/columns and
     * raised to the product of the {@code rNdiv2}/{@code cNdiv2} counts,
     * whenever those statistics are available (non-negative).
     *
     * <p>All count products are computed in {@code long}; an {@code int}
     * product would wrap around for large inputs (e.g. 100k x 100k outputs).
     *
     * @param h1 histogram of the left input
     * @param h2 histogram of the right input
     * @return estimated output sparsity
     */
    private double estimInternMM(MatrixHistogram h1, MatrixHistogram h2) {
        final int common = h1.getCols();
        long nnz = 0L;
        if (h1.rMaxNnz <= 1 || h2.cMaxNnz <= 1) {
            for (int j = 0; j < common; ++j) {
                nnz += (long)h1.cNnz[j] * (long)h2.rNnz[j];
            }
        } else if (h1.cNnz1e != null && h2.rNnz1e != null) {
            // output area left over after removing rows/columns with exactly one non-zero
            long mnOut = (long)(h1.rNonEmpty - h1.rN1) * (long)(h2.cNonEmpty - h2.cN1);
            double spOutRest = 0.0;
            for (int j = 0; j < common; ++j) {
                long c1Rest = (long)h1.cNnz[j] - (long)h1.cNnz1e[j];
                long r2Rest = (long)h2.rNnz[j] - (long)h2.rNnz1e[j];
                // directly counted fractions, without double counting
                nnz += (long)h1.cNnz1e[j] * (long)h2.rNnz[j];
                nnz += c1Rest * (long)h2.rNnz1e[j];
                // probabilistically combined remainder
                double lsp = (double)c1Rest * (double)r2Rest / (double)mnOut;
                spOutRest = spOutRest + lsp - spOutRest * lsp;
            }
            nnz += (long)(spOutRest * (double)mnOut);
        } else {
            long mnOut = (long)h1.getRows() * (long)h2.getCols();
            double spOut = 0.0;
            for (int j = 0; j < common; ++j) {
                double lsp = (double)h1.cNnz[j] * (double)h2.rNnz[j] / (double)mnOut;
                spOut = spOut + lsp - spOut * lsp;
            }
            nnz = (long)(spOut * (double)mnOut);
        }
        // upper bound: only non-empty rows of h1 and non-empty columns of h2 can hold output
        if (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) {
            nnz = Math.min((long)h1.rNonEmpty * (long)h2.cNonEmpty, nnz);
        }
        // lower bound from rows/columns that are more than half full
        if (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) {
            nnz = Math.max((long)h1.rNdiv2 * (long)h2.cNdiv2, nnz);
        }
        return OptimizerUtils.getSparsity(h1.getRows(), h2.getCols(), nnz);
    }

    private static class MatrixHistogram {
        /** Non-zero count per row; its length is the number of rows. */
        private final int[] rNnz;
        /** Per row: number of its non-zeros lying in columns with at most one non-zero; {@code null} when not computed. */
        private int[] rNnz1e;
        /** Non-zero count per column; its length is the number of columns. */
        private final int[] cNnz;
        /** Per column: number of its non-zeros lying in rows with at most one non-zero; {@code null} when not computed. */
        private int[] cNnz1e;
        /** Largest entry of {@link #rNnz}. */
        private final int rMaxNnz;
        /** Largest entry of {@link #cNnz}. */
        private final int cMaxNnz;
        /** Number of rows with exactly one non-zero; -1 for histograms created from raw count arrays. */
        private final int rN1;
        /** Number of columns with exactly one non-zero; -1 for histograms created from raw count arrays. */
        private final int cN1;
        /** Number of rows with at least one non-zero; -1 for histograms created from raw count arrays. */
        private final int rNonEmpty;
        /** Number of columns with at least one non-zero; -1 for histograms created from raw count arrays. */
        private final int cNonEmpty;
        /** Number of rows whose count exceeds half the column count (integer division); -1 for histograms created from raw count arrays. */
        private final int rNdiv2;
        /** Number of columns whose count exceeds half the row count (integer division); -1 for histograms created from raw count arrays. */
        private final int cNdiv2;
        /** Diagonal-matrix marker set only by the matrix-block constructor (see there for the exact test); stays {@code false} otherwise. */
        private boolean fullDiag;

        /**
         * Builds the histogram of a matrix block.
         *
         * <p>Step 1 counts non-zeros per row and per column and decides
         * {@code fullDiag}: the block must be square, report as many non-zeros
         * as it has rows, and every row must hold exactly one non-zero, located
         * on the diagonal. Step 2 derives the summary statistics. Step 3, only
         * if {@code useExcepts} is set, the block is non-empty and some row or
         * column holds more than one non-zero, fills {@code rNnz1e} and
         * {@code cNnz1e}.
         *
         * @param in         matrix block to summarize
         * @param useExcepts whether to compute the {@code *1e} counts
         */
        public MatrixHistogram(MatrixBlock in, boolean useExcepts) {
            final int m = in.getNumRows();
            final int n = in.getNumColumns();
            this.rNnz = new int[m];
            this.cNnz = new int[n];
            this.fullDiag = (long)m == in.getNonZeros() && m == n;

            // 1) per-row / per-column counts and diagonal check
            if (!in.isEmpty()) {
                if (in.isInSparseFormat()) {
                    SparseBlock a = in.getSparseBlock();
                    for (int i = 0; i < m; ++i) {
                        if (a.isEmpty(i)) {
                            continue;
                        }
                        int apos = a.pos(i);
                        int alen = a.size(i);
                        int[] aix = a.indexes(i);
                        this.rNnz[i] = alen;
                        LibMatrixAgg.countAgg(a.values(i), this.cNnz, aix, apos, alen);
                        // Require a single entry per row, as the dense path does. Checking
                        // only the first index would accept e.g. [[1,1],[0,0]]. With
                        // nnz == m (assuming the reported non-zeros match the row sizes)
                        // this also rules out empty rows.
                        this.fullDiag &= alen == 1 && aix[apos] == i;
                    }
                } else {
                    DenseBlock a = in.getDenseBlock();
                    for (int i = 0; i < m; ++i) {
                        this.rNnz[i] = a.countNonZeros(i);
                        LibMatrixAgg.countAgg(a.values(i), this.cNnz, a.pos(i), n);
                        this.fullDiag &= this.rNnz[i] == 1 && n > i && a.get(i, i) != 0.0;
                    }
                }
            }

            // 2) summary statistics: {max, #==1, #!=0, #>N/2}
            int[] rSummary = MatrixHistogram.deriveSummaryStatistics(this.rNnz, this.getCols());
            int[] cSummary = MatrixHistogram.deriveSummaryStatistics(this.cNnz, this.getRows());
            this.rMaxNnz = rSummary[0];
            this.cMaxNnz = cSummary[0];
            this.rN1 = rSummary[1];
            this.cN1 = cSummary[1];
            this.rNonEmpty = rSummary[2];
            this.cNonEmpty = cSummary[2];
            this.rNdiv2 = rSummary[3];
            this.cNdiv2 = cSummary[3];

            // 3) optional extension counts
            if (useExcepts && !in.isEmpty() && (this.rMaxNnz > 1 || this.cMaxNnz > 1)) {
                this.rNnz1e = new int[m];
                this.cNnz1e = new int[n];
                if (in.isInSparseFormat()) {
                    SparseBlock a = in.getSparseBlock();
                    for (int i = 0; i < m; ++i) {
                        if (a.isEmpty(i)) {
                            continue;
                        }
                        int alen = a.size(i);
                        int apos = a.pos(i);
                        int[] aix = a.indexes(i);
                        for (int k = apos; k < apos + alen; ++k) {
                            if (this.cNnz[aix[k]] <= 1) {
                                this.rNnz1e[i]++;
                            }
                        }
                        if (alen == 1) {
                            this.cNnz1e[aix[apos]]++;
                        }
                    }
                } else {
                    DenseBlock a = in.getDenseBlock();
                    for (int i = 0; i < m; ++i) {
                        double[] avals = a.values(i);
                        int aix = a.pos(i);
                        boolean rNnzlte1 = this.rNnz[i] <= 1;
                        for (int j = 0; j < n; ++j) {
                            if (avals[aix + j] == 0.0) {
                                continue;
                            }
                            if (this.cNnz[j] <= 1) {
                                this.rNnz1e[i]++;
                            }
                            if (rNnzlte1) {
                                this.cNnz1e[j]++;
                            }
                        }
                    }
                }
            }
        }

        /**
         * Wraps precomputed count arrays into a histogram. The arrays are
         * stored as given (not copied). The statistics that are not supplied
         * ({@code rN1}, {@code cN1}, {@code rNonEmpty}, {@code cNonEmpty},
         * {@code rNdiv2}, {@code cNdiv2}) are set to -1.
         *
         * @param r    non-zero count per row
         * @param r1e  extension counts per row, may be {@code null}
         * @param c    non-zero count per column
         * @param c1e  extension counts per column, may be {@code null}
         * @param rmax maximum row count
         * @param cmax maximum column count
         */
        public MatrixHistogram(int[] r, int[] r1e, int[] c, int[] c1e, int rmax, int cmax) {
            rNnz = r;
            cNnz = c;
            rNnz1e = r1e;
            cNnz1e = c1e;
            rMaxNnz = rmax;
            cMaxNnz = cmax;
            // statistics unknown for derived histograms
            rN1 = cN1 = -1;
            rNonEmpty = cNonEmpty = -1;
            rNdiv2 = cNdiv2 = -1;
        }

        /** @return number of rows, i.e. the length of the row-count array */
        public int getRows() {
            return rNnz.length;
        }

        /** @return number of columns, i.e. the length of the column-count array */
        public int getCols() {
            return cNnz.length;
        }

        /**
         * Sums the shorter of the two count arrays: the row counts if there
         * are fewer rows than columns, the column counts otherwise.
         *
         * @return total number of non-zeros according to that array
         */
        public long getNonZeros() {
            final int[] counts = getRows() < getCols() ? rNnz : cNnz;
            long total = 0L;
            for (int count : counts) {
                total += count;
            }
            return total;
        }

        /**
         * Derives the histogram of an operation's output from its input
         * histograms by dispatching to the operation-specific helper.
         *
         * @param h1    histogram of the left input
         * @param h2    histogram of the right input
         * @param spOut estimated output sparsity (used for MM only)
         * @param op    operation; MM, MULT, PLUS, RBIND and CBIND are supported
         * @return histogram describing the output
         * @throws NotImplementedException for any other operation
         */
        public static MatrixHistogram deriveOutputHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut, SparsityEstimator.OpCode op) {
            switch (op) {
                case MM:
                    return deriveMMHistogram(h1, h2, spOut);
                case MULT:
                    return deriveMultHistogram(h1, h2);
                case PLUS:
                    return derivePlusHistogram(h1, h2);
                case RBIND:
                    return deriveRbindHistogram(h1, h2);
                case CBIND:
                    return deriveCbindHistogram(h1, h2);
                default:
                    throw new NotImplementedException();
            }
        }

        /**
         * Derives the output histogram of a matrix product.
         *
         * <p>If one input is flagged {@code fullDiag}, the other input's
         * histogram is returned unchanged. Otherwise the row counts of
         * {@code h1} and the column counts of {@code h2} are scaled so that
         * they add up to roughly the estimated output non-zeros, and each
         * scaled value is rounded randomly via {@code probRound}.
         *
         * @param h1    histogram of the left input
         * @param h2    histogram of the right input
         * @param spOut estimated output sparsity
         * @return histogram of the product (without extension counts)
         */
        private static MatrixHistogram deriveMMHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut) {
            if (h1.fullDiag) {
                return h2;
            }
            if (h2.fullDiag) {
                return h1;
            }

            final int outRows = h1.getRows();
            final int outCols = h2.getCols();
            final double nnzOut = spOut * (double)outRows * (double)outCols;
            // factors that map input counts onto the estimated output non-zeros
            final double rowFactor = nnzOut / (double)h1.getNonZeros();
            final double colFactor = nnzOut / (double)h2.getNonZeros();
            final Random rn = new Random();

            int[] rowCounts = new int[outRows];
            int rowMax = 0;
            for (int i = 0; i < outRows; ++i) {
                int rounded = probRound(rowFactor * (double)h1.rNnz[i], rn);
                rowCounts[i] = rounded;
                rowMax = Math.max(rowMax, rounded);
            }

            int[] colCounts = new int[outCols];
            int colMax = 0;
            for (int j = 0; j < outCols; ++j) {
                int rounded = probRound(colFactor * (double)h2.cNnz[j], rn);
                colCounts[j] = rounded;
                colMax = Math.max(colMax, rounded);
            }
            return new MatrixHistogram(rowCounts, null, colCounts, null, rowMax, colMax);
        }

        /**
         * Derives the output histogram of an element-wise multiplication.
         *
         * <p>The estimate for row {@code i} is
         * {@code h1.rNnz[i] * h2.rNnz[i] * scaler / N1 / N2}, where
         * {@code scaler} is the dot product of the two column-count vectors.
         * Columns are treated symmetrically using the dot product of the
         * row-count vectors. Each estimate is rounded randomly via
         * {@code probRound}.
         *
         * <p>The count products are formed in {@code double}; multiplying the
         * two {@code int} counts directly would wrap around once both exceed
         * roughly 46k. If an input has no non-zeros the division yields NaN,
         * which {@code probRound} turns into 0.
         *
         * @param h1 histogram of the left input
         * @param h2 histogram of the right input
         * @return histogram of the element-wise product (without extension counts)
         */
        private static MatrixHistogram deriveMultHistogram(MatrixHistogram h1, MatrixHistogram h2) {
            double N1 = h1.getNonZeros();
            double N2 = h2.getNonZeros();
            double scaler = IntStream.range(0, h1.getCols()).mapToDouble(j -> (long)h1.cNnz[j] * (long)h2.cNnz[j]).sum();
            double scalec = IntStream.range(0, h1.getRows()).mapToDouble(j -> (long)h1.rNnz[j] * (long)h2.rNnz[j]).sum();
            int rMaxNnz = 0;
            int cMaxNnz = 0;
            Random rn = new Random();
            int[] rNnz = new int[h1.getRows()];
            for (int i = 0; i < h1.getRows(); ++i) {
                rNnz[i] = MatrixHistogram.probRound((double)h1.rNnz[i] * (double)h2.rNnz[i] * scaler / N1 / N2, rn);
                rMaxNnz = Math.max(rMaxNnz, rNnz[i]);
            }
            int[] cNnz = new int[h1.getCols()];
            for (int i = 0; i < h1.getCols(); ++i) {
                cNnz[i] = MatrixHistogram.probRound((double)h1.cNnz[i] * (double)h2.cNnz[i] * scalec / N1 / N2, rn);
                cMaxNnz = Math.max(cMaxNnz, cNnz[i]);
            }
            return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz);
        }

        /**
         * Derives the output histogram of an element-wise addition.
         *
         * <p>The estimate for row {@code i} is the sum of both inputs' row
         * counts minus the expected number of positions where both inputs are
         * non-zero, {@code h1.rNnz[i] * h2.rNnz[i] * scaler}. Here
         * {@code scaler} is the dot product of the two column-count vectors
         * divided by {@code N1 * N2}, the same collision term the sparsity
         * estimate for PLUS and the MULT histogram use. Columns are treated
         * symmetrically. Each estimate is rounded randomly via
         * {@code probRound} and clamped to {@code [0, cols]} (rows) or
         * {@code [0, rows]} (columns), since a row cannot hold more non-zeros
         * than there are columns and vice versa.
         *
         * <p>The previous formula divided every count by the number of cells
         * before rounding, which produced a value far below one and therefore
         * an almost all-zero histogram.
         *
         * @param h1 histogram of the left input
         * @param h2 histogram of the right input
         * @return histogram of the element-wise sum (without extension counts)
         */
        private static MatrixHistogram derivePlusHistogram(MatrixHistogram h1, MatrixHistogram h2) {
            final int rows = h1.getRows();
            final int cols = h1.getCols();
            final double N1 = h1.getNonZeros();
            final double N2 = h2.getNonZeros();
            // with an all-zero operand nothing can overlap (and N1*N2 would be a zero divisor)
            final boolean mayOverlap = N1 > 0.0 && N2 > 0.0;
            final double scaler = !mayOverlap ? 0.0
                : IntStream.range(0, cols).mapToDouble(j -> (double)h1.cNnz[j] * (double)h2.cNnz[j]).sum() / N1 / N2;
            final double scalec = !mayOverlap ? 0.0
                : IntStream.range(0, rows).mapToDouble(i -> (double)h1.rNnz[i] * (double)h2.rNnz[i]).sum() / N1 / N2;
            int rMaxNnz = 0;
            int cMaxNnz = 0;
            Random rn = new Random();
            int[] rNnz = new int[rows];
            for (int i = 0; i < rows; ++i) {
                double r1 = h1.rNnz[i];
                double r2 = h2.rNnz[i];
                int est = MatrixHistogram.probRound(r1 + r2 - r1 * r2 * scaler, rn);
                rNnz[i] = Math.max(0, Math.min(cols, est));
                rMaxNnz = Math.max(rMaxNnz, rNnz[i]);
            }
            int[] cNnz = new int[cols];
            for (int j = 0; j < cols; ++j) {
                double c1 = h1.cNnz[j];
                double c2 = h2.cNnz[j];
                int est = MatrixHistogram.probRound(c1 + c2 - c1 * c2 * scalec, rn);
                cNnz[j] = Math.max(0, Math.min(rows, est));
                cMaxNnz = Math.max(cMaxNnz, cNnz[j]);
            }
            return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz);
        }

        /**
         * Derives the output histogram of a row-wise concatenation: row counts
         * are concatenated ({@code h1} first), column counts are added
         * index by index over {@code h1}'s columns.
         *
         * @param h1 histogram of the upper input
         * @param h2 histogram of the lower input
         * @return histogram of the concatenation (without extension counts)
         */
        private static MatrixHistogram deriveRbindHistogram(MatrixHistogram h1, MatrixHistogram h2) {
            final int cols = h1.getCols();
            int[] colCounts = new int[cols];
            int colMax = 0;
            for (int j = 0; j < cols; ++j) {
                int sum = h1.cNnz[j] + h2.cNnz[j];
                colCounts[j] = sum;
                colMax = Math.max(colMax, sum);
            }
            int[] rowCounts = ArrayUtils.addAll((int[])h1.rNnz, (int[])h2.rNnz);
            int rowMax = Math.max(h1.rMaxNnz, h2.rMaxNnz);
            return new MatrixHistogram(rowCounts, null, colCounts, null, rowMax, colMax);
        }

        /**
         * Derives the output histogram of a column-wise concatenation: row
         * counts are added index by index over {@code h1}'s rows, column counts
         * are concatenated ({@code h1} first).
         *
         * @param h1 histogram of the left input
         * @param h2 histogram of the right input
         * @return histogram of the concatenation (without extension counts)
         */
        private static MatrixHistogram deriveCbindHistogram(MatrixHistogram h1, MatrixHistogram h2) {
            final int rows = h1.getRows();
            int[] rowCounts = new int[rows];
            int rowMax = 0;
            for (int i = 0; i < rows; ++i) {
                int sum = h1.rNnz[i] + h2.rNnz[i];
                rowCounts[i] = sum;
                rowMax = Math.max(rowMax, sum);
            }
            int[] colCounts = ArrayUtils.addAll((int[])h1.cNnz, (int[])h2.cNnz);
            int colMax = Math.max(h1.cMaxNnz, h2.cMaxNnz);
            return new MatrixHistogram(rowCounts, null, colCounts, null, rowMax, colMax);
        }

        /**
         * Rounds {@code inNnz} randomly: up if its fractional part is greater
         * than one draw from {@code rand.nextDouble()}, down otherwise.
         * Exactly one random number is drawn per call.
         *
         * @param inNnz value to round
         * @param rand  source of the random draw
         * @return {@code floor(inNnz)} or {@code floor(inNnz) + 1}, cast to int
         */
        private static int probRound(double inNnz, Random rand) {
            final double whole = Math.floor(inNnz);
            final double fraction = inNnz - whole;
            final double draw = rand.nextDouble();
            if (fraction > draw) {
                return (int)(whole + 1.0);
            }
            return (int)whole;
        }

        /**
         * Computes four summary statistics of a count array in one pass.
         *
         * @param counts per-row or per-column non-zero counts
         * @param N      size of the other dimension; entries greater than
         *               {@code N / 2} (integer division) count as more than
         *               half full
         * @return {@code {max, #entries == 1, #entries != 0, #entries > N/2}};
         *         {@code max} is {@code Integer.MIN_VALUE} for an empty array
         */
        private static int[] deriveSummaryStatistics(int[] counts, int N) {
            final int half = N / 2;
            int largest = Integer.MIN_VALUE;
            int singles = 0;
            int nonEmpty = 0;
            int aboveHalf = 0;
            for (int count : counts) {
                if (count > largest) {
                    largest = count;
                }
                if (count == 1) {
                    singles++;
                }
                if (count != 0) {
                    nonEmpty++;
                }
                if (count > half) {
                    aboveHalf++;
                }
            }
            return new int[]{largest, singles, nonEmpty, aboveHalf};
        }
    }
}

