/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.APreAgg;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupIO;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSizes;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.utils.MemoryEstimates;

public class ColGroupLinearFunctional
extends AColGroupCompressed {
    private static final long serialVersionUID = -2811822570758221975L;
    // Tolerance (1e-6) for the approximate comparisons performed in colContainsValue.
    private static final double CONTAINS_VALUE_THRESHOLD = 1.0E-6;
    // Line parameters for all columns in one array: the first half holds the
    // intercepts and the second half the slopes (see getInterceptForColumn /
    // getSlopeForColumn). create() enforces length == 2 * number of columns.
    protected double[] _coefficents;
    // Number of rows represented by this group; row r (0-based) evaluates the
    // line at position r + 1 (see getIdx).
    protected int _numRows;

    /**
     * Stores the given arguments without validation or copying; use
     * {@link #create(IColIndex, double[], int)} to get the length check and the
     * fallback to simpler group types.
     *
     * @param colIndices  column indexes covered by this group
     * @param coefficents intercepts followed by slopes (array is kept by reference)
     * @param numRows     number of rows represented
     */
    private ColGroupLinearFunctional(IColIndex colIndices, double[] coefficents, int numRows) {
        super(colIndices);
        this._coefficents = coefficents;
        this._numRows = numRows;
    }

    /**
     * Factory for linear functional column groups that degrades to a simpler
     * group when the coefficients allow it.
     *
     * <ul>
     * <li>any non-zero slope: a {@code ColGroupLinearFunctional}</li>
     * <li>all slopes zero, some intercept non-zero: a {@code ColGroupConst} built from a copy of the intercepts</li>
     * <li>all coefficients zero: a {@code ColGroupEmpty}</li>
     * </ul>
     *
     * @param colIndices  column indexes covered by the group
     * @param coefficents intercepts followed by slopes; must have length {@code 2 * colIndices.size()}
     * @param numRows     number of rows represented
     * @return the column group chosen as described above
     * @throws DMLCompressionException if the coefficient array has the wrong length
     */
    public static AColGroup create(IColIndex colIndices, double[] coefficents, int numRows) {
        final int nCol = colIndices.size();
        if (coefficents.length != 2 * nCol) {
            throw new DMLCompressionException("Invalid size of values compared to columns");
        }

        // Slopes occupy the second half of the array.
        boolean anySlopeNonZero = false;
        for (int j = nCol; j < 2 * nCol && !anySlopeNonZero; ++j) {
            anySlopeNonZero = coefficents[j] != 0.0;
        }
        if (anySlopeNonZero) {
            return new ColGroupLinearFunctional(colIndices, coefficents, numRows);
        }

        // Every column is constant; check whether it is constant zero.
        boolean anyInterceptNonZero = false;
        for (int j = 0; j < nCol && !anyInterceptNonZero; ++j) {
            anyInterceptNonZero = coefficents[j] != 0.0;
        }
        if (!anyInterceptNonZero) {
            return new ColGroupEmpty(colIndices);
        }
        return ColGroupConst.create(colIndices, Arrays.copyOf(coefficents, nCol));
    }

    /**
     * @param colIdx position of the column inside this group (not the global column id)
     * @return the intercept stored for that column, i.e. {@code _coefficents[colIdx]}
     */
    public double getInterceptForColumn(int colIdx) {
        return this._coefficents[colIdx];
    }

    /**
     * @param colIdx position of the column inside this group (not the global column id)
     * @return the slope stored for that column; slopes follow the intercepts in the coefficient array
     */
    public double getSlopeForColumn(int colIdx) {
        return this._coefficents[this._colIndexes.size() + colIdx];
    }

    /** @return the number of rows this group was created with */
    public int getNumRows() {
        return this._numRows;
    }

    /** Row-wise builtin aggregation is not supported; always throws {@link NotImplementedException}. */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        throw new NotImplementedException();
    }

    /** @return {@code CompressionType.LinearFunctional} */
    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.LinearFunctional;
    }

    /** @return {@code ColGroupType.LinearFunctional} */
    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.LinearFunctional;
    }

    /**
     * Minimum over all columns. Each column is a line, so its minimum sits at
     * the first row (position 1) for a non-negative slope and at the last row
     * (position {@code _numRows}) for a negative slope.
     *
     * @return the smallest end-point value, or {@code Double.POSITIVE_INFINITY} if no column qualifies
     */
    @Override
    public double getMin() {
        final int nCol = this.getNumCols();
        double min = Double.POSITIVE_INFINITY;
        for (int col = 0; col < nCol; ++col) {
            final double intercept = this.getInterceptForColumn(col);
            final double slope = this.getSlopeForColumn(col);
            final double candidate;
            if (slope >= 0.0) {
                candidate = intercept + slope;
            } else if (slope < 0.0) {
                candidate = intercept + (double)this._numRows * slope;
            } else {
                // slope is NaN: neither branch applies, column is skipped
                continue;
            }
            if (candidate < min) {
                min = candidate;
            }
        }
        return min;
    }

    /**
     * Maximum over all columns. Each column is a line, so its maximum sits at
     * the last row (position {@code _numRows}) for a non-negative slope and at
     * the first row (position 1) for a negative slope.
     *
     * @return the largest end-point value, or {@code Double.NEGATIVE_INFINITY} if no column qualifies
     */
    @Override
    public double getMax() {
        final int nCol = this.getNumCols();
        double max = Double.NEGATIVE_INFINITY;
        for (int col = 0; col < nCol; ++col) {
            final double intercept = this.getInterceptForColumn(col);
            final double slope = this.getSlopeForColumn(col);
            final double candidate;
            if (slope >= 0.0) {
                candidate = intercept + (double)this._numRows * slope;
            } else if (slope < 0.0) {
                candidate = intercept + slope;
            } else {
                // slope is NaN: neither branch applies, column is skipped
                continue;
            }
            if (candidate > max) {
                max = candidate;
            }
        }
        return max;
    }

    /**
     * Adds the values of rows {@code [rl, ru)} of this group to a dense block.
     *
     * <p>Each cell is evaluated through {@link #getIdx(int, int)} so that the
     * value written for row {@code r} is the line evaluated at {@code r + 1},
     * independent of where the requested range starts. (A running accumulator
     * seeded with the bare intercept is only correct when {@code rl == 0}.)
     *
     * @param db   target dense block; values are added to existing content
     * @param rl   first row of this group to decompress (inclusive)
     * @param ru   last row of this group to decompress (exclusive)
     * @param offR row offset applied when addressing the target
     * @param offC column offset applied when addressing the target
     */
    @Override
    public void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) {
        final int nCol = this.getNumCols();
        int offT = rl + offR;
        for (int row = rl; row < ru; ++row, ++offT) {
            final double[] c = db.values(offT);
            final int off = db.pos(offT) + offC;
            for (int j = 0; j < nCol; ++j) {
                c[off + this._colIndexes.get(j)] += this.getIdx(row, j);
            }
        }
    }

    /**
     * Appends the values of rows {@code [rl, ru)} of this group to a sparse block.
     *
     * @param ret  target sparse block
     * @param rl   first row of this group to decompress (inclusive)
     * @param ru   last row of this group to decompress (exclusive)
     * @param offR row offset applied when addressing the target
     * @param offC column offset applied when addressing the target
     */
    @Override
    public void decompressToSparseBlock(SparseBlock ret, int rl, int ru, int offR, int offC) {
        final int nCol = this._colIndexes.size();
        for (int r = rl; r < ru; ++r) {
            final int targetRow = r + offR;
            for (int j = 0; j < nCol; ++j) {
                final int targetCol = this._colIndexes.get(j) + offC;
                ret.append(targetRow, targetCol, this.getIdx(r, j));
            }
        }
    }

    /**
     * Evaluates the line of one column at one row.
     *
     * @param r      0-based row index; the line is evaluated at position {@code r + 1}
     * @param colIdx position of the column inside this group
     * @return {@code intercept + slope * (r + 1)}
     */
    @Override
    public double getIdx(int r, int colIdx) {
        final double intercept = this.getInterceptForColumn(colIdx);
        final double slope = this.getSlopeForColumn(colIdx);
        final double position = (double)(r + 1);
        return intercept + slope * position;
    }

    /**
     * Applies a scalar operation directly to the line coefficients.
     *
     * <ul>
     * <li>Plus / Minus: only the intercepts are transformed, the slopes are carried over unchanged.</li>
     * <li>Multiply / Divide: every coefficient is transformed.</li>
     * </ul>
     *
     * <p>NOTE(review): this assumes the group is the left operand of the scalar
     * operation. If {@code op} can place the constant on the left (for example
     * {@code c - X} or {@code c / X}), the slope sign (Minus) or linearity
     * (Divide) would not be handled by this code. Confirm against callers.
     *
     * @param op the scalar operator to apply
     * @return a column group created from the transformed coefficients
     * @throws NotImplementedException for any other operator function
     */
    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        final int nCol = this.getNumCols();
        final double[] coefficients_new = new double[this._coefficents.length];
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            // Slopes live in the second half of the array: copy them from there
            // (source offset nCol), not from the intercept half.
            System.arraycopy(this._coefficents, nCol, coefficients_new, nCol, nCol);
            for (int col = 0; col < nCol; ++col) {
                coefficients_new[col] = op.executeScalar(this._coefficents[col]);
            }
            return ColGroupLinearFunctional.create(this._colIndexes, coefficients_new, this._numRows);
        }
        if (op.fn instanceof Multiply || op.fn instanceof Divide) {
            for (int j = 0; j < this._coefficents.length; ++j) {
                coefficients_new[j] = op.executeScalar(this._coefficents[j]);
            }
            return ColGroupLinearFunctional.create(this._colIndexes, coefficients_new, this._numRows);
        }
        throw new NotImplementedException();
    }

    /** Unary operations are not supported; always throws {@link NotImplementedException}. */
    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        throw new NotImplementedException();
    }

    /** Row-vector operation with {@code v} as the left operand; delegates to {@code binaryRowOp} with {@code left = true}. */
    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        return this.binaryRowOp(op, v, isRowSafe, true);
    }

    /** Row-vector operation with {@code v} as the right operand; delegates to {@code binaryRowOp} with {@code left = false}. */
    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        return this.binaryRowOp(op, v, isRowSafe, false);
    }

    /**
     * Applies a row-vector binary operation to the line coefficients.
     *
     * <p>With a column described by {@code a + b * r} and row value {@code x}:
     * <ul>
     * <li>{@code (a + b r) + x}, {@code x + (a + b r)}, {@code (a + b r) - x}: intercept is combined with {@code x}, slope unchanged.</li>
     * <li>{@code x - (a + b r)}: intercept becomes {@code x - a} and the slope is negated.</li>
     * <li>Multiply (either side) and right-hand Divide: both coefficients are combined with {@code x}.</li>
     * <li>{@code x / (a + b r)} is not a linear function of {@code r} and cannot be represented; it is rejected.</li>
     * </ul>
     *
     * @param op        binary operator
     * @param v         row vector, addressed with the global column index of each column
     * @param isRowSafe unused by this implementation
     * @param left      {@code true} if {@code v} is the left operand
     * @return a column group created from the transformed coefficients
     * @throws NotImplementedException for unsupported operator functions and for left-hand division
     */
    private AColGroup binaryRowOp(BinaryOperator op, double[] v, boolean isRowSafe, boolean left) {
        final int nCol = this.getNumCols();
        final double[] coefficients_new = new double[this._coefficents.length];
        if (op.fn instanceof Plus || op.fn instanceof Minus) {
            // Only "v - line" flips the direction of the line.
            final boolean negateSlope = left && op.fn instanceof Minus;
            for (int col = 0; col < nCol; ++col) {
                final double rowValue = v[this._colIndexes.get(col)];
                final double intercept = this._coefficents[col];
                final double slope = this._coefficents[nCol + col];
                coefficients_new[col] = left ? op.fn.execute(rowValue, intercept) : op.fn.execute(intercept, rowValue);
                coefficients_new[nCol + col] = negateSlope ? -slope : slope;
            }
            return ColGroupLinearFunctional.create(this._colIndexes, coefficients_new, this._numRows);
        }
        if (op.fn instanceof Multiply || op.fn instanceof Divide) {
            if (left && op.fn instanceof Divide) {
                throw new NotImplementedException("Left division by a linear functional column group is not linear");
            }
            for (int col = 0; col < nCol; ++col) {
                final double rowValue = v[this._colIndexes.get(col)];
                final double intercept = this._coefficents[col];
                final double slope = this._coefficents[nCol + col];
                if (left) {
                    coefficients_new[col] = op.fn.execute(rowValue, intercept);
                    coefficients_new[nCol + col] = op.fn.execute(rowValue, slope);
                } else {
                    coefficients_new[col] = op.fn.execute(intercept, rowValue);
                    coefficients_new[nCol + col] = op.fn.execute(slope, rowValue);
                }
            }
            return ColGroupLinearFunctional.create(this._colIndexes, coefficients_new, this._numRows);
        }
        throw new NotImplementedException();
    }

    /** Builtin aggregation over the whole group is not supported; always throws {@link NotImplementedException}. */
    @Override
    protected double computeMxx(double c, Builtin builtin) {
        throw new NotImplementedException();
    }

    /** Column-wise builtin aggregation is not supported; always throws {@link NotImplementedException}. */
    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        throw new NotImplementedException();
    }

    /**
     * Adds the sum of all cells of this group to {@code c[0]}, using the closed
     * form {@code n * (a + (n + 1) * b / 2)} per column for the sum of
     * {@code a + b * r} over {@code r = 1..n}.
     *
     * @param c     accumulator; only {@code c[0]} is updated
     * @param nRows number of rows {@code n} to sum over
     */
    @Override
    protected void computeSum(double[] c, int nRows) {
        final int nCol = this.getNumCols();
        for (int col = 0; col < nCol; ++col) {
            final double a = this.getInterceptForColumn(col);
            final double b = this.getSlopeForColumn(col);
            final double colSum = (double)nRows * (a + (double)(nRows + 1) * b / 2.0);
            c[0] += colSum;
        }
    }

    /**
     * Adds each column's sum to {@code c} at the column's global index, using
     * the closed form {@code n * (a + (n + 1) * b / 2)}.
     *
     * @param c     per-column accumulator indexed by global column id
     * @param nRows number of rows {@code n} to sum over
     */
    @Override
    public void computeColSums(double[] c, int nRows) {
        final int nCol = this.getNumCols();
        for (int col = 0; col < nCol; ++col) {
            final double a = this.getInterceptForColumn(col);
            final double b = this.getSlopeForColumn(col);
            final int target = this._colIndexes.get(col);
            final double colSum = (double)nRows * (a + (double)(nRows + 1) * b / 2.0);
            c[target] += colSum;
        }
    }

    /**
     * Adds the sum of squares of all cells of this group to {@code c[0]}.
     *
     * <p>Per column, the sum of {@code (a + b r)^2} over {@code r = 1..n} is
     * {@code n * (a^2 + (n + 1) a b + (n + 1)(2n + 1) b^2 / 6)}. The factor
     * {@code (n + 1)(2n + 1)} is evaluated in double precision: as an
     * {@code int} product it overflows once {@code n} exceeds about 32767.
     *
     * @param c     accumulator; only {@code c[0]} is updated
     * @param nRows number of rows {@code n} to sum over
     */
    @Override
    protected void computeSumSq(double[] c, int nRows) {
        final double n = nRows;
        final double squareWeight = (n + 1.0) * (2.0 * n + 1.0);
        for (int col = 0; col < this.getNumCols(); ++col) {
            double intercept = this.getInterceptForColumn(col);
            double slope = this.getSlopeForColumn(col);
            c[0] = c[0] + n * (Math.pow(intercept, 2.0) + (n + 1.0) * slope * intercept + squareWeight * Math.pow(slope, 2.0) / 6.0);
        }
    }

    /**
     * Adds each column's sum of squares to {@code c} at the column's global index.
     *
     * <p>Per column, the sum of {@code (a + b r)^2} over {@code r = 1..n} is
     * {@code n * (a^2 + (n + 1) a b + (n + 1)(2n + 1) b^2 / 6)}. The factor
     * {@code (n + 1)(2n + 1)} is evaluated in double precision: as an
     * {@code int} product it overflows once {@code n} exceeds about 32767.
     *
     * @param c     per-column accumulator indexed by global column id
     * @param nRows number of rows {@code n} to sum over
     */
    @Override
    protected void computeColSumsSq(double[] c, int nRows) {
        final double n = nRows;
        final double squareWeight = (n + 1.0) * (2.0 * n + 1.0);
        for (int col = 0; col < this.getNumCols(); ++col) {
            double intercept = this.getInterceptForColumn(col);
            double slope = this.getSlopeForColumn(col);
            int target = this._colIndexes.get(col);
            c[target] = c[target] + n * (Math.pow(intercept, 2.0) + (n + 1.0) * slope * intercept + squareWeight * Math.pow(slope, 2.0) / 6.0);
        }
    }

    /**
     * Adds the row sums of rows {@code [rl, ru)} to {@code c}. Because every
     * column is linear in the row position, a row sum is
     * {@code sum(intercepts) + sum(slopes) * (row + 1)}.
     *
     * @param c      per-row accumulator
     * @param rl     first row (inclusive)
     * @param ru     last row (exclusive)
     * @param preAgg {@code [sum of intercepts, sum of slopes]} as produced by {@code preAggSumRows}
     */
    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        final double interceptTotal = preAgg[0];
        final double slopeTotal = preAgg[1];
        int row = rl;
        while (row < ru) {
            final double position = (double)(row + 1);
            c[row] += interceptTotal + slopeTotal * position;
            ++row;
        }
    }

    /** @return always 0 for this group type */
    @Override
    public int getNumValues() {
        return 0;
    }

    /**
     * Multiplies this group with {@code right} from the right and returns the
     * result as an uncompressed column group.
     *
     * <p>For each output column the weighted intercepts and weighted slopes are
     * accumulated once; each output cell is then
     * {@code bias + (row + 1) * slope}.
     *
     * @param right   right-hand matrix; rows are addressed by this group's global column ids
     * @param allCols column indexes for the result, or {@code null} to use {@code 0..right.getNumColumns()-1}
     * @return an uncompressed column group holding the {@code _numRows x right.getNumColumns()} result
     */
    @Override
    public AColGroup rightMultByMatrix(MatrixBlock right, IColIndex allCols) {
        final int nColR = right.getNumColumns();
        final IColIndex outputCols;
        if (allCols == null) {
            outputCols = ColIndexFactory.create(nColR);
        } else {
            outputCols = allCols;
        }
        final MatrixBlock result = new MatrixBlock(this._numRows, nColR, false);
        final int nCol = this._colIndexes.size();
        for (int outCol = 0; outCol < nColR; ++outCol) {
            double biasAcc = 0.0;
            double slopeAcc = 0.0;
            for (int k = 0; k < nCol; ++k) {
                final int rightRow = this._colIndexes.get(k);
                biasAcc += right.getValue(rightRow, outCol) * this.getInterceptForColumn(k);
                slopeAcc += right.getValue(rightRow, outCol) * this.getSlopeForColumn(k);
            }
            for (int row = 0; row < this._numRows; ++row) {
                final double cell = biasAcc + (double)(row + 1) * slopeAcc;
                result.setValue(row, outCol, cell);
            }
        }
        return ColGroupUncompressed.create(result, outputCols);
    }

    /**
     * Adds this group's contribution to the transpose-self matrix product
     * {@code t(X) %*% X} into {@code ret}; only entries with
     * {@code col >= row} (within this group's column order) are written.
     *
     * <p>For columns {@code i} and {@code j} with lines {@code a + b r}, the
     * dot product over {@code r = 1..n} is
     * {@code n a_i a_j + S1 (a_i b_j + b_i a_j) + S2 b_i b_j} with
     * {@code S1 = n(n+1)/2} and {@code S2 = n(n+1)(2n+1)/6}. Both sums are
     * computed in double precision; the {@code int} products overflow for
     * moderately large {@code n} ({@code S2} already for {@code n} slightly
     * above 1000).
     *
     * @param ret        row-major output of width {@code numColumns}; values are added
     * @param numColumns number of columns of the full output matrix
     * @param nRows      number of rows {@code n}
     */
    @Override
    public void tsmm(double[] ret, int numColumns, int nRows) {
        final int tCol = this._colIndexes.size();
        final double n = nRows;
        final double sumIndices = n * (n + 1.0) / 2.0;
        final double sumSquaredIndices = n * (n + 1.0) * (2.0 * n + 1.0) / 6.0;
        for (int row = 0; row < tCol; ++row) {
            final double alpha1 = n * this.getInterceptForColumn(row) + sumIndices * this.getSlopeForColumn(row);
            final double alpha2 = sumIndices * this.getInterceptForColumn(row) + sumSquaredIndices * this.getSlopeForColumn(row);
            final int offRet = this._colIndexes.get(row) * numColumns;
            for (int col = row; col < tCol; ++col) {
                ret[offRet + this._colIndexes.get(col)] += alpha1 * this.getInterceptForColumn(col) + alpha2 * this.getSlopeForColumn(col);
            }
        }
    }

    /** Not expected to be invoked for this group type; always throws {@link DMLCompressionException}. */
    @Override
    public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        throw new DMLCompressionException("This method should never be called");
    }

    /**
     * Computes {@code t(lhs) %*% this} and copies the result into {@code result}.
     *
     * <ul>
     * <li>{@code ColGroupEmpty}: nothing to add, returns immediately.</li>
     * <li>{@code ColGroupUncompressed}: per lhs column, the plain sum and the
     * position-weighted sum of its values are computed and multiplied with this
     * group's {@code 2 x nCols} coefficient matrix.</li>
     * <li>{@code ColGroupLinearFunctional}: computed purely from coefficients as
     * {@code t(C_lhs) %*% W %*% C_rhs} with
     * {@code W = [[n, S1], [S1, S2]]}, {@code S1 = n(n+1)/2},
     * {@code S2 = n(n+1)(2n+1)/6}. The sums are computed in double precision
     * because the {@code int} products overflow for moderately large {@code n}.</li>
     * <li>any other type: not implemented.</li>
     * </ul>
     *
     * @param lhs    left-hand column group
     * @param result output matrix receiving the product
     * @param nRows  not used by this implementation; {@code _numRows} and the lhs data determine the row count
     * @throws NotImplementedException for unsupported left-hand group types
     */
    @Override
    public void leftMultByAColGroup(AColGroup lhs, MatrixBlock result, int nRows) {
        if (lhs instanceof ColGroupEmpty) {
            return;
        }
        final int nColLhs = lhs.getNumCols();
        final int nColRhs = this._colIndexes.size();
        MatrixBlock tmpRet = new MatrixBlock(nColLhs, nColRhs, 0L);
        if (lhs instanceof ColGroupUncompressed) {
            ColGroupUncompressed lhsUC = (ColGroupUncompressed)lhs;
            final int numRowsLeft = lhsUC.getData().getNumRows();
            // Row-major (nColLhs x 2): [column sum, position-weighted column sum].
            final double[] colSumsAndWeightedColSums = new double[2 * nColLhs];
            for (int j = 0, offTmp = 0; j < nColLhs; ++j, offTmp += 2) {
                for (int i = 0; i < numRowsLeft; ++i) {
                    final double value = lhs.getIdx(i, j);
                    colSumsAndWeightedColSums[offTmp] += value;
                    colSumsAndWeightedColSums[offTmp + 1] += (double)(i + 1) * value;
                }
            }
            MatrixBlock sumMatrix = new MatrixBlock(nColLhs, 2, colSumsAndWeightedColSums);
            MatrixBlock coefficientMatrix = new MatrixBlock(2, nColRhs, this._coefficents);
            LibMatrixMult.matrixMult(sumMatrix, coefficientMatrix, tmpRet);
        } else if (lhs instanceof ColGroupLinearFunctional) {
            ColGroupLinearFunctional lhsLF = (ColGroupLinearFunctional)lhs;
            final double n = this._numRows;
            final double sumIndices = n * (n + 1.0) / 2.0;
            final double sumSquaredIndices = n * (n + 1.0) * (2.0 * n + 1.0) / 6.0;
            MatrixBlock weightMatrix = new MatrixBlock(2, 2, new double[]{n, sumIndices, sumIndices, sumSquaredIndices});
            // Defensive copy: the block is transposed in place below, so it must
            // not share storage with lhs's coefficient array (which is also this
            // group's array when lhs == this).
            // NOTE(review): assumes this MatrixBlock constructor wraps the array rather than copying it - confirm.
            MatrixBlock coefficientMatrixLhs = new MatrixBlock(2, lhsLF._colIndexes.size(), lhsLF._coefficents.clone());
            MatrixBlock coefficientMatrixRhs = new MatrixBlock(2, nColRhs, this._coefficents);
            coefficientMatrixLhs = LibMatrixReorg.transposeInPlace(coefficientMatrixLhs, InfrastructureAnalyzer.getLocalParallelism());
            MatrixBlock tmp = new MatrixBlock(nColLhs, 2, false);
            LibMatrixMult.matrixMult(coefficientMatrixLhs, weightMatrix, tmp);
            LibMatrixMult.matrixMult(tmp, coefficientMatrixRhs, tmpRet);
        } else {
            // Includes APreAgg-based groups, which are not supported yet either.
            throw new NotImplementedException();
        }
        ColGroupUtils.copyValuesColGroupMatrixBlocks(lhs, this, tmpRet, result);
    }

    /** Not expected to be invoked for this group type; always throws {@link DMLCompressionException}. */
    @Override
    public void tsmmAColGroup(AColGroup other, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called");
    }

    /** Column slicing is not supported; always throws {@link NotImplementedException}. */
    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        throw new NotImplementedException();
    }

    /** Column slicing is not supported; always throws {@link NotImplementedException}. */
    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, IColIndex outputCols) {
        throw new NotImplementedException();
    }

    /**
     * Checks whether any column of this group contains {@code pattern},
     * according to {@link #colContainsValue(int, double)}.
     *
     * @param pattern value to look for
     * @return {@code true} as soon as one column reports the value
     */
    @Override
    public boolean containsValue(double pattern) {
        int col = 0;
        while (col < this.getNumCols()) {
            if (this.colContainsValue(col, pattern)) {
                return true;
            }
            ++col;
        }
        return false;
    }

    /**
     * Checks whether one column takes the value {@code pattern} at one of its
     * rows, i.e. whether {@code intercept + slope * p == pattern} has a solution
     * {@code p} that is (within {@code CONTAINS_VALUE_THRESHOLD}) an integer
     * position in {@code [1, _numRows]}.
     *
     * <p>The range check matters: a value that lies on the line but outside the
     * represented rows (for example at position 0 or a negative position) is
     * not contained in the column.
     *
     * @param col     position of the column inside this group
     * @param pattern value to look for
     * @return {@code true} if the column contains the value
     */
    public boolean colContainsValue(int col, double pattern) {
        final double intercept = this.getInterceptForColumn(col);
        final double slope = this.getSlopeForColumn(col);
        if (pattern == intercept) {
            // Positions start at 1, so the intercept itself is only hit by an (almost) flat line.
            return Math.abs(slope) < CONTAINS_VALUE_THRESHOLD;
        }
        final double div = (pattern - intercept) / slope;
        final double nearestPosition = Math.rint(div);
        // Written as a negated conjunction so that NaN / infinite quotients are rejected too.
        if (!(nearestPosition >= 1.0 && nearestPosition <= (double)this._numRows)) {
            return false;
        }
        return Math.abs(div - nearestPosition) < CONTAINS_VALUE_THRESHOLD;
    }

    /** Non-zero counting is not supported; always throws {@link NotImplementedException}. */
    @Override
    public long getNumberNonZeros(int nRows) {
        throw new NotImplementedException();
    }

    /** Value replacement is not supported; always throws {@link NotImplementedException}. */
    @Override
    public AColGroup replace(double pattern, double replace) {
        throw new NotImplementedException();
    }

    /**
     * Deserializes a group: first the column indexes, then two coefficients per
     * column. No simplification to other group types is attempted.
     *
     * @param in    input to read from
     * @param nRows number of rows the resulting group represents
     * @return the deserialized group
     * @throws IOException if reading from {@code in} fails
     */
    public static ColGroupLinearFunctional read(DataInput in, int nRows) throws IOException {
        final IColIndex cols = ColIndexFactory.read(in);
        final int nCoefficients = 2 * cols.size();
        final double[] coefficients = ColGroupIO.readDoubleArray(nCoefficients, in);
        return new ColGroupLinearFunctional(cols, coefficients, nRows);
    }

    /**
     * Serializes the group: the superclass payload followed by every
     * coefficient as a double, in array order.
     *
     * @param out output to write to
     * @throws IOException if writing to {@code out} fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        final double[] coefficients = this._coefficents;
        for (int i = 0; i < coefficients.length; ++i) {
            out.writeDouble(coefficients[i]);
        }
    }

    /**
     * Size in bytes of what {@code write} emits: the superclass payload plus
     * eight bytes per coefficient. Nothing else is written (in particular not
     * the row count), so no further terms are added.
     *
     * @return the serialized size in bytes
     */
    @Override
    public long getExactSizeOnDisk() {
        return super.getExactSizeOnDisk() + (long)Double.BYTES * this._coefficents.length;
    }

    /**
     * Multiplies {@code c[0]} by every cell of this group. If any column is
     * reported to contain zero, {@code c[0]} is set to zero directly.
     *
     * @param c     accumulator; only {@code c[0]} is updated
     * @param nRows number of rows to multiply over
     */
    @Override
    protected void computeProduct(double[] c, int nRows) {
        final boolean hasZero = this.containsValue(0.0);
        if (hasZero) {
            c[0] = 0.0;
            return;
        }
        for (int col = 0; col < this.getNumCols(); ++col) {
            final double a = this.getInterceptForColumn(col);
            final double b = this.getSlopeForColumn(col);
            for (int position = 1; position <= nRows; ++position) {
                c[0] *= a + b * (double)position;
            }
        }
    }

    /**
     * Multiplies each {@code c[row]} for rows {@code [rl, ru)} by that row's
     * cells in this group, column by column.
     *
     * @param c      per-row accumulator
     * @param rl     first row (inclusive)
     * @param ru     last row (exclusive)
     * @param preAgg unused by this implementation
     */
    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        for (int row = rl; row < ru; ++row) {
            final double position = (double)(row + 1);
            for (int col = 0; col < this.getNumCols(); ++col) {
                final double a = this.getInterceptForColumn(col);
                final double b = this.getSlopeForColumn(col);
                c[row] *= a + b * position;
            }
        }
    }

    /**
     * Multiplies each column's entry in {@code c} (at its global index) by all
     * cells of that column. A column reported to contain zero gets its entry
     * set to zero directly.
     *
     * @param c     per-column accumulator indexed by global column id
     * @param nRows number of rows to multiply over
     */
    @Override
    protected void computeColProduct(double[] c, int nRows) {
        for (int col = 0; col < this.getNumCols(); ++col) {
            if (!this.colContainsValue(col, 0.0)) {
                final double a = this.getInterceptForColumn(col);
                final double b = this.getSlopeForColumn(col);
                for (int position = 1; position <= nRows; ++position) {
                    c[this._colIndexes.get(col)] *= a + b * (double)position;
                }
            } else {
                c[this._colIndexes.get(col)] = 0.0;
            }
        }
    }

    /**
     * Pre-aggregates what {@code computeRowSums} needs.
     *
     * @return a two-element array {@code [sum of intercepts, sum of slopes]}
     */
    @Override
    protected double[] preAggSumRows() {
        final int nCol = this.getNumCols();
        double interceptTotal = 0.0;
        double slopeTotal = 0.0;
        for (int col = 0; col < nCol; ++col) {
            interceptTotal += this.getInterceptForColumn(col);
            slopeTotal += this.getSlopeForColumn(col);
        }
        return new double[]{interceptTotal, slopeTotal};
    }

    /** @return always {@code null}; no pre-aggregate is produced for row-wise sum of squares */
    @Override
    protected double[] preAggSumSqRows() {
        return null;
    }

    /** @return always {@code null}; {@code computeRowProduct} does not read its pre-aggregate */
    @Override
    protected double[] preAggProductRows() {
        return null;
    }

    /** Builtin pre-aggregation is not supported; always throws {@link NotImplementedException}. */
    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        throw new NotImplementedException();
    }

    /**
     * @return the in-memory size estimate from {@code ColGroupSizes} for a
     *         linear functional group with this column count and index contiguity
     */
    @Override
    public long estimateInMemorySize() {
        final int nCol = this.getNumCols();
        final boolean contiguous = this._colIndexes.isContiguous();
        return ColGroupSizes.estimateInMemorySizeLinearFunctional(nCol, contiguous);
    }

    /** Central moments are not supported; always throws {@link NotImplementedException}. */
    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        throw new NotImplementedException();
    }

    /** Column re-expansion is not supported; always throws {@link NotImplementedException}. */
    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        throw new NotImplementedException();
    }

    /**
     * Rough cost estimate; logs a warning on every call because the estimate is
     * not tuned for this group type.
     *
     * @param e     cost estimator
     * @param nRows number of rows
     * @return {@code e.getCost(nRows, nRows, numCols, 2, 1.0)}
     */
    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        LOG.warn("Cost calculation for LinearFunctional ColGroup is not precise");
        return e.getCost(nRows, nRows, this.getNumCols(), 2, 1.0);
    }

    /** @return the superclass description followed by one line each for the intercepts and the slopes */
    @Override
    public String toString() {
        final String interceptLine = String.format("\n%15s", " Intercepts: " + Arrays.toString(this.getIntercepts()));
        final String slopeLine = String.format("\n%15s", " Slopes: " + Arrays.toString(this.getSlopes()));
        return super.toString() + interceptLine + slopeLine;
    }

    /** @return a fresh array with the intercept of every column, in group column order */
    public double[] getIntercepts() {
        // Intercepts are the leading getNumCols() entries of the coefficient array.
        return Arrays.copyOf(this._coefficents, this.getNumCols());
    }

    /** @return a fresh array with the slope of every column, in group column order */
    public double[] getSlopes() {
        // Slopes start right after the intercepts, at offset _colIndexes.size().
        final int firstSlope = this._colIndexes.size();
        return Arrays.copyOfRange(this._coefficents, firstSlope, firstSlope + this.getNumCols());
    }

    /** Row slicing is not supported; always throws {@link NotImplementedException}. */
    @Override
    public AColGroup sliceRows(int rl, int ru) {
        throw new NotImplementedException();
    }

    /**
     * Creates a group over {@code colIndexes} with this group's coefficients and
     * row count. The coefficient array is passed on by reference, not copied.
     */
    @Override
    public AColGroup copyAndSet(IColIndex colIndexes) {
        return ColGroupLinearFunctional.create(colIndexes, this._coefficents, this._numRows);
    }

    /** @return always {@code null}; appending is not implemented for this group type */
    @Override
    public AColGroup append(AColGroup g) {
        return null;
    }

    /** Appending multiple groups is not supported; always throws {@link NotImplementedException}. */
    @Override
    public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) {
        throw new NotImplementedException();
    }

    /** @return always {@code null}; no compression scheme is provided for this group type */
    @Override
    public ICLAScheme getCompressionScheme() {
        return null;
    }

    /** @return this instance unchanged */
    @Override
    public AColGroup recompress() {
        return this;
    }

    /** Compression size info is not supported; always throws {@link NotImplementedException}. */
    @Override
    public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
        throw new NotImplementedException("Not Implemented Compressed SizeInfo for Linear col group");
    }

    /** Index-structure comparison is not supported; always throws {@link NotImplementedException}. */
    @Override
    public boolean sameIndexStructure(AColGroupCompressed that) {
        throw new NotImplementedException();
    }

    /** Column index reordering is not supported; always throws {@link NotImplementedException}. */
    @Override
    protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) {
        throw new NotImplementedException();
    }
}

