/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
import org.apache.sysds.runtime.compress.lib.CLALibRightMultBy;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.utils.stats.Timing;

public final class CLALibMMChain {
    static final Log LOG = LogFactory.getLog(CLALibMMChain.class.getName());

    /**
     * Per-thread scratch buffer reused when the intermediate of the chain has to be decompressed.
     *
     * <p>Initialized eagerly (instead of lazily on first use) so that concurrent callers cannot race
     * on creating the ThreadLocal and overwrite each other's cached buffers. The buffer only ever
     * grows, and nothing in this class removes it, so each thread keeps the largest buffer it has
     * requested so far alive for as long as the thread lives.
     */
    private static final ThreadLocal<double[]> cacheIntermediate = new ThreadLocal<>();

    private CLALibMMChain() {
        // static utility class, no instances
    }

    /**
     * Evaluates a matrix-multiplication chain with a compressed left-hand side {@code x}.
     *
     * <p>The steps are:
     * <ol>
     * <li>right-multiply {@code x} with {@code v};</li>
     * <li>for {@link MapMultChain.ChainType#XtwXv}, multiply that intermediate cell-wise with {@code w};</li>
     * <li>pass the intermediate to {@code CLALibLeftMultBy.leftMultByMatrixTransposed} together with {@code x};</li>
     * <li>transpose {@code out} into a single column if it is not one already.</li>
     * </ol>
     *
     * <p>Side effect: {@code x}'s column-group list may be replaced by a pre-filtered one (see
     * {@link #filterColGroups(CompressedMatrixBlock)}).
     *
     * @param x     compressed left-hand matrix
     * @param v     right-hand side of the first multiplication
     * @param w     cell-wise weights, used only when {@code ctype == XtwXv}
     * @param out   output block; may be null only when {@code x} is empty (NOTE(review): the non-empty
     *              path dereferences {@code out} after the left multiplication)
     * @param ctype chain type
     * @param k     degree of parallelism passed to the underlying kernels
     * @return the output block, an {@code x.getNumColumns()} x 1 all-zero block when {@code x} is empty
     */
    public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, MatrixBlock w, MatrixBlock out,
        MapMultChain.ChainType ctype, int k) {
        final Timing t = new Timing();
        if (x.isEmpty()) {
            return returnEmpty(x, out);
        }

        x = filterColGroups(x);
        final double preFilterTime = t.stop();

        // A compressed (possibly overlapping) intermediate is only kept when x consists of a single
        // column group and the configuration allows overlapping compressed results.
        final boolean allowOverlap = x.getColGroups().size() == 1 && isOverlappingAllowed();

        MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v, null, k, true);
        final double rmmTime = t.stop();

        // NOTE(review): only XtwXv consumes w; every other chain type is evaluated without w.
        // Confirm callers never route a chain type that needs w (other than XtwXv) here.
        if (ctype == MapMultChain.ChainType.XtwXv) {
            tmp = binaryMultW(tmp, w, k);
        }
        if (!allowOverlap && tmp instanceof CompressedMatrixBlock) {
            tmp = decompressIntermediate((CompressedMatrixBlock) tmp, k);
        }
        final double decompressTime = t.stop();

        // NOTE(review): the result of leftMultByMatrixTransposed is not captured; this relies on the
        // helper writing into the provided (non-null) out block.
        if (tmp instanceof CompressedMatrixBlock) {
            CLALibLeftMultBy.leftMultByMatrixTransposed(x, (CompressedMatrixBlock) tmp, out, k);
        } else {
            CLALibLeftMultBy.leftMultByMatrixTransposed(x, tmp, out, k);
        }
        final double lmmTime = t.stop();

        // Normalize the result to a single column.
        if (out.getNumColumns() != 1) {
            out = LibMatrixReorg.transposeInPlace(out, k);
        }

        if (LOG.isDebugEnabled()) {
            final StringBuilder sb = new StringBuilder("\n");
            sb.append("\nPreFilter Time      : " + preFilterTime);
            sb.append("\nChain RMM           : " + rmmTime);
            sb.append("\nChain RMM Decompress: " + decompressTime);
            sb.append("\nChain LMM           : " + lmmTime);
            sb.append("\nChain Transpose     : " + t.stop());
            LOG.debug(sb.toString());
        }
        return out;
    }

    /**
     * Decompresses the chain intermediate into a dense block backed by this thread's cached buffer.
     *
     * <p>The returned block shares its array with the per-thread cache. Its contents are therefore
     * only valid until the next call on the same thread. NOTE(review): a reused buffer may be longer
     * than rows * cols and may hold values from a previous call. This relies on the MatrixBlock
     * constructor accepting an oversized array, and on decompressTo (presumably via its last
     * argument) clearing the target. Confirm both.
     *
     * @param tmp compressed intermediate
     * @param k   degree of parallelism for decompression
     * @return a dense block with the same dimensions as {@code tmp}
     * @throws ArithmeticException if rows * cols does not fit in an int
     */
    private static MatrixBlock decompressIntermediate(CompressedMatrixBlock tmp, int k) {
        final int rows = tmp.getNumRows();
        final int cols = tmp.getNumColumns();
        // Fail explicitly instead of silently wrapping and allocating a wrongly sized buffer.
        final int nCells = Math.multiplyExact(rows, cols);

        double[] tmpArr = cacheIntermediate.get();
        if (tmpArr == null || tmpArr.length < nCells) {
            tmpArr = new double[nCells];
            cacheIntermediate.set(tmpArr);
        }

        final MatrixBlock tmpV = new MatrixBlock(rows, cols, tmpArr);
        CLALibDecompress.decompressTo(tmp, tmpV, 0, 0, k, false, true);
        return tmpV;
    }

    /** Reads the {@code sysds.compressed.overlapping} flag from the active DML configuration. */
    private static boolean isOverlappingAllowed() {
        return ConfigurationManager.getDMLConfig().getBooleanValue("sysds.compressed.overlapping");
    }

    /** Returns an {@code x.getNumColumns()} x 1 result block for an empty {@code x}. */
    private static MatrixBlock returnEmpty(CompressedMatrixBlock x, MatrixBlock out) {
        return prepareReturn(x, out);
    }

    /**
     * Resets {@code out} to {@code x.getNumColumns()} x 1 (sparse flag false), or allocates such a
     * block when {@code out} is null.
     */
    private static MatrixBlock prepareReturn(CompressedMatrixBlock x, MatrixBlock out) {
        final int clen = x.getNumColumns();
        if (out == null) {
            return new MatrixBlock(clen, 1, false);
        }
        out.reset(clen, 1, false);
        return out;
    }

    /**
     * Multiplies the intermediate cell-wise with {@code w}.
     *
     * <p>A compressed intermediate is replaced by the result of a compressed binary operation. An
     * uncompressed intermediate is modified in place and returned.
     */
    private static MatrixBlock binaryMultW(MatrixBlock tmp, MatrixBlock w, int k) {
        final BinaryOperator bop = new BinaryOperator(Multiply.getMultiplyFnObject(), k);
        if (tmp instanceof CompressedMatrixBlock) {
            return CLALibBinaryCellOp.binaryOperationsRight(bop, (CompressedMatrixBlock) tmp, w);
        }
        LibMatrixBincell.bincellOpInPlace(tmp, w, bop);
        return tmp;
    }

    /**
     * Pre-filters the column groups of {@code x} when CLALibUtils deems it worthwhile.
     *
     * <p>CLALibUtils.filterGroups fills a constant vector of length {@code x.getNumColumns()}. That
     * vector is appended to the filtered groups as a ColGroupConst. The new list replaces
     * {@code x}'s groups in place, so the caller's block is mutated. Nothing is done when filtering
     * is not recommended or the groups are already pre-filtered.
     *
     * @return {@code x} itself, possibly with a new column-group list
     */
    private static CompressedMatrixBlock filterColGroups(CompressedMatrixBlock x) {
        final List<AColGroup> groups = x.getColGroups();
        final int nCol = x.getNumColumns();
        if (!CLALibUtils.shouldPreFilter(groups) || CLALibUtils.alreadyPreFiltered(groups, nCol)) {
            return x;
        }
        final double[] constV = new double[nCol];
        final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(groups, constV);
        filteredGroups.add(ColGroupConst.create(constV));
        x.allocateColGroupList(filteredGroups);
        return x;
    }
}

