/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.colgroup.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.utils.BitmapLossy;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.utils.MemoryEstimates;

/**
 * Dictionary that stores the distinct tuples of a column group as 8-bit quantized values.
 *
 * <p>Every stored byte {@code b} represents the double {@code b * _scale}, so the encoding is
 * lossy. Tuples are laid out contiguously: with {@code nCol} columns, tuple {@code k} occupies
 * {@code _values[k * nCol]} through {@code _values[k * nCol + nCol - 1]}.
 */
public class QDictionary extends ADictionary {
    protected static final Log LOG = LogFactory.getLog(QDictionary.class.getName());

    /** Factor that converts a quantized byte back into its double value. */
    protected double _scale;

    /** Quantized tuple values, stored tuple after tuple. */
    protected byte[] _values;

    /**
     * Creates a dictionary from the quantized values and scale of a lossy bitmap.
     *
     * @param bm bitmap providing the quantized values and their scale; its value array is
     *     referenced, not copied
     */
    public QDictionary(BitmapLossy bm) {
        _values = bm.getValues();
        _scale = bm.getScale();
    }

    /**
     * Creates a dictionary that takes ownership of the given quantized values.
     *
     * @param values quantized values (referenced, not copied)
     * @param scale factor mapping a quantized byte to its double value
     */
    private QDictionary(byte[] values, double scale) {
        _values = values;
        _scale = scale;
    }

    /**
     * Quantizes a double against a scale, rounding to the nearest step.
     *
     * <p>A zero scale only occurs when every value is zero; {@code 0.0 / 0.0} is NaN and
     * {@code Math.round(NaN)} is 0, so such values still quantize to 0.
     */
    private static byte quantize(double value, double scale) {
        return (byte) Math.round(value / scale);
    }

    /**
     * Dequantizes all stored values.
     *
     * @return a new array holding {@code _values[i] * _scale} for every index
     */
    @Override
    public double[] getValues() {
        double[] res = new double[_values.length];
        for (int i = 0; i < _values.length; ++i) {
            res[i] = getValue(i);
        }
        return res;
    }

    /**
     * Returns the dequantized value at index {@code i}.
     *
     * <p>The index one past the end yields {@code 0.0} instead of failing; presumably this serves
     * callers that address an implicit zero tuple (NOTE(review): confirm against callers).
     *
     * @param i value index in {@code [0, getValuesLength()]}
     * @return {@code _values[i] * _scale}, or {@code 0.0} when {@code i == getValuesLength()}
     */
    @Override
    public double getValue(int i) {
        return i == _values.length ? 0.0 : _values[i] * _scale;
    }

    /**
     * @param i value index
     * @return the raw quantized byte at index {@code i}
     */
    public byte getValueByte(int i) {
        return _values[i];
    }

    /** @return the backing quantized array (not a copy; mutations affect this dictionary) */
    public byte[] getValuesByte() {
        return _values;
    }

    /** @return the factor that converts a quantized byte into its double value */
    public double getScale() {
        return _scale;
    }

    @Override
    public long getInMemorySize() {
        return getInMemorySize(_values.length);
    }

    /**
     * Estimates the heap footprint of a dictionary holding {@code valuesCount} quantized values.
     *
     * @param valuesCount number of stored bytes
     * @return estimated size in bytes: the byte array cost plus fixed overheads (the 16 and 8
     *     constants presumably cover the object header and the {@code double} scale — TODO
     *     confirm)
     */
    public static long getInMemorySize(int valuesCount) {
        return 16L + MemoryEstimates.byteArrayCost(valuesCount) + 8L;
    }

    /**
     * Finds the first tuple whose values are all zero.
     *
     * @param ncol number of columns per tuple
     * @return index of the first all-zero tuple, or {@code -1} if there is none
     */
    @Override
    public int hasZeroTuple(int ncol) {
        final int numTuples = _values.length / ncol;
        for (int k = 0, off = 0; k < numTuples; ++k, off += ncol) {
            int j = 0;
            while (j < ncol && _values[off + j] == 0) {
                ++j;
            }
            if (j == ncol) {
                return k;
            }
        }
        return -1;
    }

    /**
     * Folds {@code fn} over every dequantized value, starting from {@code init}.
     *
     * @param init initial accumulator value
     * @param fn binary builtin applied as {@code acc = fn(acc, value)}
     * @return the final accumulator
     */
    @Override
    public double aggregate(double init, Builtin fn) {
        double ret = init;
        for (int i = 0; i < _values.length; ++i) {
            ret = fn.execute(ret, getValue(i));
        }
        return ret;
    }

    /**
     * Applies a scalar operator to every value in place and re-quantizes.
     *
     * <p>Multiplication only rescales. Addition derives the new scale from the shifted extremes
     * of the representable range. Any other operator re-quantizes against the largest absolute
     * result (NaN results are ignored when computing that bound).
     *
     * @param op scalar operator to apply
     * @return this dictionary, modified in place
     */
    @Override
    public QDictionary apply(ScalarOperator op) {
        if (op.fn instanceof Multiply) {
            // Multiplying every value by a constant is the same as multiplying the scale.
            _scale = op.executeScalar(_scale);
            return this;
        }
        if (op.fn instanceof Plus) {
            // Bound the new magnitudes by shifting both ends of the quantized range.
            double max = Math.max(Math.abs(op.executeScalar(-127.0 * _scale)),
                Math.abs(op.executeScalar(127.0 * _scale)));
            double oldScale = _scale;
            _scale = max / 127.0;
            for (int i = 0; i < _values.length; ++i) {
                _values[i] = quantize(op.executeScalar(_values[i] * oldScale), _scale);
            }
        } else {
            double[] temp = new double[_values.length];
            // Seed with 0 and compare absolute values only; a signed seed could undershoot
            // the true magnitude and overflow the byte range on re-quantization.
            double max = 0.0;
            for (int i = 0; i < _values.length; ++i) {
                temp[i] = op.executeScalar(getValue(i));
                double absTemp = Math.abs(temp[i]);
                if (absTemp > max) {
                    max = absTemp;
                }
            }
            _scale = max / 127.0;
            for (int i = 0; i < _values.length; ++i) {
                _values[i] = quantize(temp[i], _scale);
            }
        }
        return this;
    }

    /**
     * Applies a scalar operator to a copy of the values and appends one extra tuple.
     *
     * @param op scalar operator applied to every existing value
     * @param newVal value stored for each column of the appended tuple (used as given; the
     *     operator is not applied to it)
     * @param numCols number of columns, i.e. bytes in the appended tuple
     * @return a new dictionary; this one is left unchanged
     */
    @Override
    public QDictionary applyScalarOp(ScalarOperator op, double newVal, int numCols) {
        double[] temp = getValues();
        // The appended value must be representable too, so its magnitude seeds the bound.
        double max = Math.abs(newVal);
        for (int i = 0; i < _values.length; ++i) {
            temp[i] = op.executeScalar(temp[i]);
            double absTemp = Math.abs(temp[i]);
            if (absTemp > max) {
                max = absTemp;
            }
        }
        double scale = max / 127.0;
        byte[] res = new byte[_values.length + numCols];
        for (int i = 0; i < _values.length; ++i) {
            res[i] = quantize(temp[i], scale);
        }
        Arrays.fill(res, _values.length, res.length, quantize(newVal, scale));
        return new QDictionary(res, scale);
    }

    /** @return the number of stored quantized values (tuples times columns) */
    @Override
    public int getValuesLength() {
        return _values.length;
    }

    /** @return a deep copy with its own value array */
    @Override
    public QDictionary clone() {
        return new QDictionary(_values.clone(), _scale);
    }

    /**
     * Deserializes a dictionary in the format produced by {@link #write(DataOutput)}.
     *
     * @param in input positioned at a serialized dictionary
     * @return the dictionary read
     * @throws IOException if reading fails or the input ends early
     */
    public static QDictionary read(DataInput in) throws IOException {
        double scale = in.readDouble();
        int numVals = in.readInt();
        byte[] values = new byte[numVals];
        in.readFully(values);
        return new QDictionary(values, scale);
    }

    /**
     * Serializes the scale, the value count and then the raw quantized bytes.
     *
     * @param out destination
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeDouble(_scale);
        out.writeInt(_values.length);
        out.write(_values);
    }

    /** @return serialized size: 8-byte scale + 4-byte count + one byte per value */
    @Override
    public long getExactSizeOnDisk() {
        return (long) Double.BYTES + Integer.BYTES + _values.length;
    }

    /**
     * @param nCol number of columns per tuple
     * @return the number of tuples stored
     */
    @Override
    public int getNumberOfValues(int nCol) {
        return _values.length / nCol;
    }

    /**
     * Computes, per tuple, the aggregate of its values under {@code kplus}.
     *
     * @param kplus aggregation function
     * @param kbuff scratch Kahan buffer (overwritten)
     * @param nrColumns number of columns per tuple
     * @return one aggregate per tuple
     */
    @Override
    protected double[] sumAllRowsToDouble(KahanFunction kplus, KahanObject kbuff, int nrColumns) {
        if (nrColumns == 1 && kplus instanceof KahanPlus) {
            // With one column, each tuple's sum is just its value.
            return getValues();
        }
        final int numVals = _values.length / nrColumns;
        double[] ret = ColGroupValue.allocDVector(numVals, false);
        for (int k = 0; k < numVals; ++k) {
            ret[k] = sumRow(k, kplus, kbuff, nrColumns);
        }
        return ret;
    }

    /**
     * Aggregates the values of tuple {@code k} under {@code kplus}.
     *
     * @param k tuple index
     * @param kplus aggregation function
     * @param kbuff scratch Kahan buffer (overwritten unless {@code kplus} is {@link KahanPlus})
     * @param nrColumns number of columns per tuple
     * @return the aggregate of the tuple's dequantized values
     */
    @Override
    protected double sumRow(int k, KahanFunction kplus, KahanObject kbuff, int nrColumns) {
        final int valOff = k * nrColumns;
        if (kplus instanceof KahanPlus) {
            // Sum the raw bytes exactly and scale once. An int is needed: a short wraps once a
            // tuple has more than ~258 columns of large values.
            int res = 0;
            for (int i = 0; i < nrColumns; ++i) {
                res += _values[valOff + i];
            }
            return res * _scale;
        }
        kbuff.set(0.0, 0.0);
        for (int i = 0; i < nrColumns; ++i) {
            kplus.execute2(kbuff, _values[valOff + i] * _scale);
        }
        return kbuff._sum;
    }

    /**
     * Adds the count-weighted column aggregates of all tuples into {@code c}.
     *
     * <p>For {@link KahanPlusSq}, {@code c[colIndexes[j]]} holds the running sum and
     * {@code c[colIndexes[j] + colIndexes.length]} the matching Kahan correction.
     *
     * @param c output/accumulator array
     * @param counts occurrence count per tuple (at least one entry per stored tuple)
     * @param colIndexes output column index of each dictionary column
     * @param kplus aggregation function
     */
    @Override
    protected void colSum(double[] c, int[] counts, int[] colIndexes, KahanFunction kplus) {
        final int nCol = colIndexes.length;
        // Iterate over tuples, not over individual bytes.
        final int numVals = _values.length / nCol;
        if (!(kplus instanceof KahanPlusSq)) {
            // Exact integer accumulation; long because count * byte summed over tuples can
            // exceed the int range.
            long[] sum = new long[nCol];
            for (int k = 0, valOff = 0; k < numVals; ++k, valOff += nCol) {
                final long cntk = counts[k];
                for (int j = 0; j < nCol; ++j) {
                    sum[j] += cntk * _values[valOff + j];
                }
            }
            for (int j = 0; j < nCol; ++j) {
                c[colIndexes[j]] += sum[j] * _scale;
            }
        } else {
            KahanObject kbuff = new KahanObject(0.0, 0.0);
            for (int k = 0, valOff = 0; k < numVals; ++k, valOff += nCol) {
                final int cntk = counts[k];
                for (int j = 0; j < nCol; ++j) {
                    kbuff.set(c[colIndexes[j]], c[colIndexes[j] + nCol]);
                    kplus.execute3(kbuff, getValue(valOff + j), cntk);
                    c[colIndexes[j]] = kbuff._sum;
                    c[colIndexes[j] + nCol] = kbuff._correction;
                }
            }
        }
    }

    /**
     * Computes the count-weighted aggregate over all values of all tuples.
     *
     * @param counts occurrence count per tuple (at least one entry per stored tuple)
     * @param ncol number of columns per tuple
     * @param kplus aggregation function
     * @return the total aggregate
     */
    @Override
    protected double sum(int[] counts, int ncol, KahanFunction kplus) {
        // Iterate over tuples, not over individual bytes.
        final int numVals = _values.length / ncol;
        if (!(kplus instanceof KahanPlusSq)) {
            // Exact integer accumulation in a long to avoid int overflow on large counts.
            long sum = 0L;
            for (int k = 0, valOff = 0; k < numVals; ++k, valOff += ncol) {
                final long cntk = counts[k];
                for (int j = 0; j < ncol; ++j) {
                    sum += cntk * _values[valOff + j];
                }
            }
            return sum * _scale;
        }
        KahanObject kbuff = new KahanObject(0.0, 0.0);
        for (int k = 0, valOff = 0; k < numVals; ++k, valOff += ncol) {
            final int cntk = counts[k];
            for (int j = 0; j < ncol; ++j) {
                kplus.execute3(kbuff, getValue(valOff + j), cntk);
            }
        }
        return kbuff._sum;
    }
}

