/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.estim;

import org.apache.commons.lang.NotImplementedException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.estim.EstimatorBasicAvg;
import org.apache.sysml.hops.estim.MMNode;
import org.apache.sysml.hops.estim.SparsityEstimator;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.DenseBlock;
import org.apache.sysml.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * Sparsity estimator that derives output sparsity from the non-zero counts of a
 * random sample of rows or columns of the input matrices.
 * <p>
 * Recursive (DAG) estimates are not supported; {@link #estim(MMNode)} logs a warning
 * and delegates to {@link EstimatorBasicAvg}.
 */
public class EstimatorSample
extends SparsityEstimator {
    /** Default fraction of the sampled dimension that is drawn. */
    private static final double SAMPLE_FRACTION = 0.1;
    private final double _frac;

    /** Creates an estimator that samples the default fraction ({@value #SAMPLE_FRACTION}). */
    public EstimatorSample() {
        this(SAMPLE_FRACTION);
    }

    /**
     * Creates an estimator that samples the given fraction of a dimension
     * (at least one index is always sampled).
     *
     * @param sampleFrac sample fraction in [0, 1]
     * @throws DMLRuntimeException if {@code sampleFrac} is NaN or outside [0, 1]
     */
    public EstimatorSample(double sampleFrac) {
        // Negated range check so that NaN is rejected as well: every comparison
        // with NaN is false, so a plain "< 0 || > 1" test would let it through.
        if (!(sampleFrac >= 0.0 && sampleFrac <= 1.0)) {
            throw new DMLRuntimeException("Invalid sample fraction: " + sampleFrac);
        }
        _frac = sampleFrac;
    }

    @Override
    public MatrixCharacteristics estim(MMNode root) {
        LOG.warn("Recursive estimates not supported by EstimatorSample, falling back to EstimatorBasicAvg.");
        return new EstimatorBasicAvg().estim(root);
    }

    /** Estimates the sparsity of the matrix product {@code m1 %*% m2}. */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2) {
        return estim(m1, m2, SparsityEstimator.OpCode.MM);
    }

    /**
     * Estimates the output sparsity of {@code op} applied to {@code m1} (and {@code m2}
     * for binary operations).
     *
     * @param m1 first (or only) input
     * @param m2 second input; {@code null} for unary operations
     * @param op operation
     * @return estimated output sparsity
     * @throws NotImplementedException if {@code op} is not supported by this estimator
     */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2, SparsityEstimator.OpCode op) {
        switch (op) {
            case MM:
                return estimMatMult(m1, m2);
            case MULT:
            case PLUS:
                return estimElementwise(m1, m2, op);
            case RBIND:
            case CBIND:
            case EQZERO:
            case NEQZERO:
            case TRANS:
            case DIAG:
            case RESHAPE: {
                MatrixCharacteristics mc1 = m1.getMatrixCharacteristics();
                // Unary ops arrive here via estim(m, op) with m2 == null; avoid the NPE.
                // NOTE(review): assumes estimExactMetaData ignores mc2 for unary ops — confirm.
                MatrixCharacteristics mc2 = (m2 != null) ? m2.getMatrixCharacteristics() : null;
                return OptimizerUtils.getSparsity(estimExactMetaData(mc1, mc2, op));
            }
            default:
                throw new NotImplementedException("EstimatorSample: unsupported operation " + op);
        }
    }

    /** Estimates the output sparsity of the unary operation {@code op} on {@code m}. */
    @Override
    public double estim(MatrixBlock m, SparsityEstimator.OpCode op) {
        return estim(m, null, op);
    }

    /**
     * Matrix-multiply estimate: for each sampled common index j, the outer product of
     * column j of m1 and row j of m2 has nnz(m1[:,j]) * nnz(m2[j,:]) non-zeros, all of which
     * land in the result (ignoring numerical cancellation). The maximum over the sample
     * is used as the output non-zero count.
     */
    private double estimMatMult(MatrixBlock m1, MatrixBlock m2) {
        int[] ix = sampleIndexes(m1.getNumColumns());
        int[] cnnz = computeColumnNnz(m1, ix);
        int[] rnnz = computeRowNnz(m2, ix);
        long nnzOut = 0L;
        for (int i = 0; i < ix.length; i++) {
            nnzOut = Math.max(nnzOut, (long) cnnz[i] * rnnz[i]);
        }
        return OptimizerUtils.getSparsity(m1.getNumRows(), m2.getNumColumns(), nnzOut);
    }

    /**
     * Element-wise MULT/PLUS estimate: samples columns if m1 is wider than tall, rows otherwise,
     * and averages per-slice combined sparsity assuming independent non-zero patterns
     * (MULT: s1*s2, PLUS: s1+s2-s1*s2). Both inputs are normalized by m1's dimensions;
     * presumably m1 and m2 have equal shape for element-wise ops — verify against callers.
     */
    private double estimElementwise(MatrixBlock m1, MatrixBlock m2, SparsityEstimator.OpCode op) {
        boolean byCol = m1.getNumColumns() > m1.getNumRows();
        int[] ix = sampleIndexes(Math.max(m1.getNumColumns(), m1.getNumRows()));
        int[] nnz1 = byCol ? computeColumnNnz(m1, ix) : computeRowNnz(m1, ix);
        int[] nnz2 = byCol ? computeColumnNnz(m2, ix) : computeRowNnz(m2, ix);
        // length of one sampled slice: a column has numRows cells, a row numColumns cells
        double len = byCol ? m1.getNumRows() : m1.getNumColumns();
        boolean mult = op == SparsityEstimator.OpCode.MULT;
        double spOut = 0.0;
        for (int i = 0; i < ix.length; i++) {
            double sp1 = nnz1[i] / len;
            double sp2 = nnz2[i] / len;
            spOut += mult ? sp1 * sp2 : sp1 + sp2 - sp1 * sp2;
        }
        return spOut / ix.length;
    }

    /** Draws a sorted random sample of max(k * frac, 1) indexes from [0, k). */
    private int[] sampleIndexes(int k) {
        return UtilFunctions.getSortedSampleIndexes(k, (int) Math.max(k * _frac, 1.0));
    }

    /**
     * Counts non-zeros per column of {@code in} and returns the counts of the columns
     * listed in {@code ix}, in the same order.
     */
    private static int[] computeColumnNnz(MatrixBlock in, int[] ix) {
        int[] nnz = new int[in.getNumColumns()];
        if (in.isInSparseFormat()) {
            SparseBlock sblock = in.getSparseBlock();
            // A null block is treated as all-zero (presumably unallocated empty storage).
            if (sblock != null) {
                for (int i = 0; i < in.getNumRows(); i++) {
                    if (!sblock.isEmpty(i)) {
                        // presumably increments nnz[col] per stored entry of row i — see LibMatrixAgg
                        LibMatrixAgg.countAgg(sblock.values(i), nnz, sblock.indexes(i), sblock.pos(i), sblock.size(i));
                    }
                }
            }
        } else {
            DenseBlock dblock = in.getDenseBlock();
            // A null block is treated as all-zero (presumably unallocated empty storage).
            if (dblock != null) {
                int ncol = in.getNumColumns();
                for (int i = 0; i < in.getNumRows(); i++) {
                    double[] avals = dblock.values(i);
                    int aix = dblock.pos(i);
                    for (int j = 0; j < ncol; j++) {
                        if (avals[aix + j] != 0.0) {
                            nnz[j]++;
                        }
                    }
                }
            }
        }
        int[] ret = new int[ix.length];
        for (int i = 0; i < ix.length; i++) {
            ret[i] = nnz[ix[i]];
        }
        return ret;
    }

    /** Returns the non-zero count of each row of {@code in} listed in {@code ix}, in order. */
    private static int[] computeRowNnz(MatrixBlock in, int[] ix) {
        int[] ret = new int[ix.length];
        for (int i = 0; i < ix.length; i++) {
            ret[i] = (int) in.recomputeNonZeros(ix[i], ix[i]);
        }
        return ret;
    }
}

