/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.estim;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
import org.apache.sysds.hops.estim.MMNode;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Sparsity estimator that derives output sparsity from a random sample of
 * rows or columns of the inputs.
 *
 * <p>For a dimension of size {@code k}, {@code max(k * frac, 1)} sorted indexes
 * are sampled, where {@code frac} defaults to {@link #SAMPLE_FRACTION}.
 * <ul>
 *   <li>Matrix multiplication uses either a biased max-based estimate or, when
 *       {@code extended} is set, a probabilistic estimate extrapolated from the
 *       sampled inner indexes.</li>
 *   <li>Element-wise MULT/PLUS average per-sample sparsity combinations.</li>
 *   <li>Operations with exact metadata-based results are delegated to
 *       {@code estimExactMetaData}.</li>
 * </ul>
 * Recursive (DAG) estimation is not supported and falls back to {@link EstimatorBasicAvg}.
 */
public class EstimatorSample extends SparsityEstimator {
    /** Default fraction of rows/columns that is sampled. */
    private static final double SAMPLE_FRACTION = 0.1;

    /** Fraction of the relevant dimension to sample, in (0, 1]. */
    private final double _frac;
    /** Whether matrix multiplication uses the extended (extrapolating) estimator. */
    private final boolean _extended;

    /** Creates an estimator with the default sample fraction and the basic MM estimate. */
    public EstimatorSample() {
        this(SAMPLE_FRACTION, false);
    }

    /**
     * Creates an estimator with the given sample fraction and the basic MM estimate.
     *
     * @param sampleFrac fraction of rows/columns to sample, in (0, 1]
     */
    public EstimatorSample(double sampleFrac) {
        this(sampleFrac, false);
    }

    /**
     * Creates an estimator with the given sample fraction and MM estimation mode.
     *
     * @param sampleFrac fraction of rows/columns to sample, in (0, 1]
     * @param extended   if true, matrix multiplication uses the extrapolating estimator
     * @throws DMLRuntimeException if {@code sampleFrac} is not in (0, 1] (including NaN)
     */
    public EstimatorSample(double sampleFrac, boolean extended) {
        // Written as a negated range check so that NaN is rejected too
        // (every ordered comparison against NaN is false).
        if (!(sampleFrac > 0.0 && sampleFrac <= 1.0)) {
            throw new DMLRuntimeException("Invalid sample fraction: " + sampleFrac);
        }
        _frac = sampleFrac;
        _extended = extended;
    }

    @Override
    public DataCharacteristics estim(MMNode root) {
        LOG.warn("Recursive estimates not supported by EstimatorSample, falling back to EstimatorBasicAvg.");
        return new EstimatorBasicAvg().estim(root);
    }

    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2) {
        return estim(m1, m2, SparsityEstimator.OpCode.MM);
    }

    /**
     * Estimates the output sparsity of {@code op} applied to the given inputs.
     *
     * @param m1 first (or only) input
     * @param m2 second input; {@code null} for unary operations
     * @param op operation to estimate
     * @return estimated output sparsity
     * @throws NotImplementedException for operations this estimator does not support
     */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2, SparsityEstimator.OpCode op) {
        switch (op) {
            case MM:
                return estimMatMult(m1, m2);
            case MULT:
                return estimElementwise(m1, m2, false);
            case PLUS:
                return estimElementwise(m1, m2, true);
            case RBIND:
            case CBIND:
            case EQZERO:
            case NEQZERO:
            case TRANS:
            case DIAG:
            case RESHAPE: {
                DataCharacteristics mc1 = m1.getDataCharacteristics();
                // Unary ops arrive via estim(MatrixBlock, OpCode) with m2 == null;
                // dereferencing it unconditionally would throw an NPE.
                // NOTE(review): assumes estimExactMetaData ignores mc2 for unary ops.
                DataCharacteristics mc2 = (m2 != null) ? m2.getDataCharacteristics() : null;
                return OptimizerUtils.getSparsity(estimExactMetaData(mc1, mc2, op));
            }
            default:
                throw new NotImplementedException();
        }
    }

    @Override
    public double estim(MatrixBlock m, SparsityEstimator.OpCode op) {
        return estim(m, null, op);
    }

    /** Draws {@code max(k * frac, 1)} sorted sample indexes out of {@code [0, k)}. */
    private int[] sampleIndexes(int k) {
        return UtilFunctions.getSortedSampleIndexes(k, (int) Math.max(k * _frac, 1.0));
    }

    /**
     * Estimates the sparsity of {@code m1 %*% m2} from a sample of inner indexes.
     * For each sampled inner index i, the column nnz of m1 and the row nnz of m2
     * give the nnz of the outer product of column i of m1 and row i of m2.
     */
    private double estimMatMult(MatrixBlock m1, MatrixBlock m2) {
        int k = m1.getNumColumns();
        int[] ix = sampleIndexes(k);
        int p = ix.length;
        int[] cnnz = computeColumnNnz(m1, ix);

        if (_extended) {
            double ml = (long) m1.getNumRows() * (long) m2.getNumColumns();
            double sumS = 0.0;
            double prodS = 1.0;
            for (int i = 0; i < p; i++) {
                long rnnz = m2.recomputeNonZeros(ix[i], ix[i]);
                // Density of this inner index's outer product within the m x l output.
                double v = (double) cnnz[i] * (double) rnnz / ml;
                sumS += v;
                prodS *= 1.0 - v;
            }
            // Sampled indexes contribute their exact (1 - v) factors; the k - p
            // unsampled indexes are approximated by the average sampled density.
            return 1.0 - Math.pow(1.0 - 1.0 / (double) p * sumS, k - p) * prodS;
        }

        // Biased estimate: output nnz is taken as the nnz of the densest
        // sampled outer product.
        long nnzOut = 0L;
        for (int i = 0; i < p; i++) {
            nnzOut = Math.max(nnzOut, (long) cnnz[i] * m2.recomputeNonZeros(ix[i], ix[i]));
        }
        return OptimizerUtils.getSparsity(m1.getNumRows(), m2.getNumColumns(), nnzOut);
    }

    /**
     * Estimates the sparsity of an element-wise MULT ({@code plus == false}) or
     * PLUS ({@code plus == true}) by sampling along the larger dimension of m1.
     * Per sample, the two sparsities s1, s2 are combined as {@code s1 * s2}
     * (MULT) or {@code s1 + s2 - s1 * s2} (PLUS), then averaged.
     * Both inputs are assumed to share m1's dimensions.
     */
    private double estimElementwise(MatrixBlock m1, MatrixBlock m2, boolean plus) {
        int rows = m1.getNumRows();
        int cols = m1.getNumColumns();
        int[] ix = sampleIndexes(Math.max(cols, rows));
        boolean byColumn = cols > rows;
        int[] nnz1 = byColumn ? computeColumnNnz(m1, ix) : computeRowNnz(m1, ix);
        int[] nnz2 = byColumn ? computeColumnNnz(m2, ix) : computeRowNnz(m2, ix);
        // A sampled column has 'rows' cells, a sampled row has 'cols' cells.
        double len = byColumn ? (double) rows : (double) cols;

        double spOut = 0.0;
        for (int i = 0; i < ix.length; i++) {
            double sp1 = (double) nnz1[i] / len;
            double sp2 = (double) nnz2[i] / len;
            spOut += plus ? sp1 + sp2 - sp1 * sp2 : sp1 * sp2;
        }
        return spOut / (double) ix.length;
    }

    /**
     * Counts non-zeros per column of {@code in} and returns the counts of the
     * columns listed in {@code ix}, in the order of {@code ix}.
     */
    private static int[] computeColumnNnz(MatrixBlock in, int[] ix) {
        int rows = in.getNumRows();
        int cols = in.getNumColumns();
        int[] nnz = new int[cols];
        if (in.isInSparseFormat()) {
            SparseBlock sblock = in.getSparseBlock();
            // Guard against an unallocated backing block; treat it as all zeros.
            if (sblock != null) {
                for (int i = 0; i < rows; i++) {
                    if (sblock.isEmpty(i)) {
                        continue;
                    }
                    LibMatrixAgg.countAgg(sblock.values(i), nnz,
                        sblock.indexes(i), sblock.pos(i), sblock.size(i));
                }
            }
        } else {
            DenseBlock dblock = in.getDenseBlock();
            // Guard against an unallocated backing block; treat it as all zeros.
            if (dblock != null) {
                for (int i = 0; i < rows; i++) {
                    double[] avals = dblock.values(i);
                    int aix = dblock.pos(i);
                    for (int j = 0; j < cols; j++) {
                        if (avals[aix + j] != 0.0) {
                            nnz[j]++;
                        }
                    }
                }
            }
        }
        int[] ret = new int[ix.length];
        for (int i = 0; i < ix.length; i++) {
            ret[i] = nnz[ix[i]];
        }
        return ret;
    }

    /** Returns the non-zero count of each row of {@code in} listed in {@code ix}. */
    private static int[] computeRowNnz(MatrixBlock in, int[] ix) {
        int[] ret = new int[ix.length];
        for (int i = 0; i < ix.length; i++) {
            ret[i] = (int) in.recomputeNonZeros(ix[i], ix[i]);
        }
        return ret;
    }
}

