/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;

/**
 * Column group in which every row holds the same tuple of values. The dictionary stores a single
 * tuple with one value per entry of {@code _colIndexes}, and that tuple is logically repeated for
 * every row of the matrix (hence {@link #getNumValues()} is always 1).
 */
public class ColGroupConst
extends ColGroupCompressed {
    private static final long serialVersionUID = -7387793538322386611L;

    /** Dictionary holding the single constant tuple; entry {@code j} belongs to column {@code _colIndexes[j]}. */
    protected ADictionary _dict;

    /** Empty constructor; presumably used before {@link #readFields(DataInput)} during deserialization. */
    protected ColGroupConst() {
    }

    /**
     * Creates a constant column group.
     *
     * @param colIndexes the column indexes covered by this group
     * @param dict       dictionary holding the single constant tuple for these columns
     */
    public ColGroupConst(int[] colIndices, ADictionary dict) {
        super(colIndices);
        this._dict = dict;
    }

    @Override
    protected void computeRowSums(double[] c, boolean square, int rl, int ru) {
        // All rows are identical, so the (optionally squared) row sum is computed once and added to every row.
        final double rowSum = this._dict.sumAllRowsToDouble(square, this._colIndexes.length)[0];
        for (int rix = rl; rix < ru; rix++) {
            c[rix] += rowSum;
        }
    }

    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru) {
        // The aggregate of the single tuple is the same for every row; fold it into each row's running value.
        final double value = this._dict.aggregateTuples(builtin, this._colIndexes.length)[0];
        for (int i = rl; i < ru; i++) {
            c[i] = builtin.execute(c[i], value);
        }
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.CONST;
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.CONST;
    }

    /**
     * Adds the constant tuple into rows {@code [rl, ru)} of the dense target block, writing target rows
     * starting at {@code offT}.
     */
    @Override
    public void decompressToBlock(MatrixBlock target, int rl, int ru, int offT) {
        final DenseBlock db = target.getDenseBlock();
        for (int i = rl; i < ru; i++, offT++) {
            final double[] c = db.values(offT);
            final int off = db.pos(offT);
            for (int j = 0; j < this._colIndexes.length; j++) {
                c[off + this._colIndexes[j]] += this._dict.getValue(j);
            }
        }
    }

    /**
     * Returns the value at column {@code c}; the row {@code r} is irrelevant since all rows are equal.
     *
     * @throws DMLCompressionException if {@code c} is not one of this group's columns
     */
    @Override
    public double get(int r, int c) {
        // Arrays.binarySearch requires _colIndexes sorted ascending; a negative result means "not found".
        final int ix = Arrays.binarySearch(this._colIndexes, c);
        if (ix < 0) {
            throw new DMLCompressionException("Column index " + c + " not in group.");
        }
        return this._dict.getValue(ix);
    }

    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        return new ColGroupConst(this._colIndexes, this._dict.clone().apply(op));
    }

    @Override
    public AColGroup binaryRowOp(BinaryOperator op, double[] v, boolean sparseSafe, boolean left) {
        return new ColGroupConst(this._colIndexes, this._dict.clone().applyBinaryRowOp(op, v, true, this._colIndexes, left));
    }

    /**
     * Adds this group's non-zero count to {@code rnnz[0 .. ru-rl)}. The count is the same for every row,
     * so it is computed once from the constant tuple. The array is accumulated into rather than
     * overwritten, so counts from other column groups are preserved.
     */
    @Override
    public void countNonZerosPerRow(int[] rnnz, int rl, int ru) {
        final double[] values = this._dict.getValues();
        int nnzPerRow = 0;
        for (double v : values) {
            if (v != 0.0) {
                nnzPerRow++;
            }
        }
        for (int i = 0; i < ru - rl; i++) {
            rnnz[i] += nnzPerRow;
        }
    }

    /**
     * Adds the constant tuple into {@code constV} at this group's column positions. Does nothing if
     * either the dictionary values or {@code constV} is null.
     */
    public void addToCommon(double[] constV) {
        final double[] values = this._dict.getValues();
        if (values != null && constV != null) {
            for (int i = 0; i < this._colIndexes.length; i++) {
                constV[this._colIndexes[i]] += values[i];
            }
        }
    }

    @Override
    public double[] getValues() {
        return this._dict != null ? this._dict.getValues() : null;
    }

    @Override
    public final boolean isLossy() {
        return this._dict.isLossy();
    }

    @Override
    protected double computeMxx(double c, Builtin builtin) {
        return this._dict.aggregate(c, builtin);
    }

    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        this._dict.aggregateCols(c, builtin, this._colIndexes);
    }

    /** Adds the (optionally squared) total of all {@code nRows} repetitions of the tuple into {@code c[0]}. */
    @Override
    protected void computeSum(double[] c, int nRows, boolean square) {
        if (this._dict == null) {
            return;
        }
        // The tuple occurs once per row, so its count is nRows.
        final int[] counts = new int[]{nRows};
        final int nCol = this._colIndexes.length;
        c[0] += square ? this._dict.sumsq(counts, nCol) : this._dict.sum(counts, nCol);
    }

    @Override
    protected void computeColSums(double[] c, int nRows, boolean square) {
        this._dict.colSum(c, new int[]{nRows}, this._colIndexes, square);
    }

    @Override
    public int getNumValues() {
        return 1;
    }

    /**
     * Returns the constant tuple as a MatrixBlock. Note: this replaces {@code _dict} with its
     * MatrixBlockDictionary form as a side effect.
     */
    @Override
    public MatrixBlock getValuesAsBlock() {
        this._dict = this._dict.getAsMatrixBlockDictionary(this._colIndexes.length);
        return ((MatrixBlockDictionary) this._dict).getMatrixBlock();
    }

    /**
     * Multiplies the constant tuple (as a 1 x nCol row) by {@code right}, yielding a new constant group.
     *
     * @return null if {@code right} is empty, otherwise a constant group over {@code right}'s columns
     * @throws NotImplementedException if {@code right}'s row count differs from this group's column count
     */
    @Override
    public AColGroup rightMultByMatrix(MatrixBlock right) {
        if (right.isEmpty()) {
            return null;
        }
        final int rr = right.getNumRows();
        final int cr = right.getNumColumns();
        if (this._colIndexes.length == rr) {
            final MatrixBlock left = this.getValuesAsBlock();
            final MatrixBlock ret = new MatrixBlock(1, cr, false);
            LibMatrixMult.matrixMult(left, right, ret);
            final MatrixBlockDictionary d = new MatrixBlockDictionary(ret);
            return ColGroupFactory.getColGroupConst(cr, d);
        }
        throw new NotImplementedException();
    }

    @Override
    public void tsmm(double[] result, int numColumns, int nRows) {
        // The single tuple is counted nRows times.
        ColGroupConst.tsmm(result, numColumns, new int[]{nRows}, this._dict, this._colIndexes);
    }

    @Override
    public void leftMultByMatrix(MatrixBlock matrix, MatrixBlock result, int rl, int ru) {
        throw new NotImplementedException();
    }

    @Override
    public void leftMultByAColGroup(AColGroup lhs, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called");
    }

    @Override
    public void tsmmAColGroup(AColGroup other, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called");
    }

    @Override
    public boolean isDense() {
        return true;
    }

    /** Slices out the column at position {@code idx}; a zero value yields an empty group. */
    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        final int[] colIndexes = new int[]{0};
        final double v = this._dict.getValue(idx);
        if (v == 0.0) {
            return new ColGroupEmpty(colIndexes);
        }
        return new ColGroupConst(colIndexes, new Dictionary(new double[]{v}));
    }

    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) {
        final ADictionary retD = this._dict.sliceOutColumnRange(idStart, idEnd, this._colIndexes.length);
        return new ColGroupConst(outputCols, retD);
    }

    @Override
    public AColGroup copy() {
        return new ColGroupConst(this._colIndexes, this._dict.clone());
    }

    @Override
    public boolean containsValue(double pattern) {
        return this._dict.containsValue(pattern);
    }

    @Override
    public long getNumberNonZeros(int nRows) {
        return this._dict.getNumberNonZeros(new int[]{nRows}, this._colIndexes.length);
    }

    @Override
    public AColGroup replace(double pattern, double replace) {
        final ADictionary replaced = this._dict.replace(pattern, replace, this._colIndexes.length);
        return new ColGroupConst(this._colIndexes, replaced);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        this._dict = DictionaryFactory.read(in);
    }

    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        this._dict.write(out);
    }

    @Override
    public long getExactSizeOnDisk() {
        long ret = super.getExactSizeOnDisk();
        if (this._dict != null) {
            ret += this._dict.getExactSizeOnDisk();
        }
        return ret;
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(String.format("\n%15s ", "Values: " + this._dict.getClass().getSimpleName()));
        sb.append(this._dict.getString(this._colIndexes.length));
        return sb.toString();
    }

    /** Multiplies {@code c[0]} by each column value raised to {@code nRows}; any zero value forces 0. */
    @Override
    protected void computeProduct(double[] c, int nRows) {
        final double[] vals = this.getValues();
        for (int i = 0; i < this._colIndexes.length; i++) {
            final double v = vals[i];
            c[0] = v != 0.0 ? c[0] * Math.pow(v, nRows) : 0.0;
        }
    }

    @Override
    protected void computeRowProduct(double[] c, int rl, int ru) {
        throw new NotImplementedException();
    }

    @Override
    protected void computeColProduct(double[] c, int nRows) {
        throw new NotImplementedException();
    }
}

