/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.UnaryFEDInstruction;
import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

public class ReorgFEDInstruction
extends UnaryFEDInstruction {
    /**
     * Creates a federated reorg instruction with an explicit federated-output flag.
     * All arguments are forwarded unchanged to the {@link UnaryFEDInstruction}
     * constructor together with the fixed instruction type {@code FEDType.Reorg}.
     *
     * @param op     operator to apply (cast to {@link ReorgOperator} in {@code processInstruction})
     * @param in1    input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   full instruction string
     * @param fedOut federated output flag
     */
    public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(FEDInstruction.FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
    }

    /**
     * Creates a federated reorg instruction without specifying a federated-output flag.
     * All arguments are forwarded unchanged to the {@link UnaryFEDInstruction}
     * constructor together with the fixed instruction type {@code FEDType.Reorg}.
     *
     * @param op     operator to apply (cast to {@link ReorgOperator} in {@code processInstruction})
     * @param in1    input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   full instruction string
     */
    public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.Reorg, op, in1, out, opcode, istr);
    }

    /**
     * Builds a federated reorg instruction from an existing CP reorg instruction.
     * Operator, operands, opcode and instruction string are taken over as-is;
     * the federated-output flag is set to {@code NONE}.
     *
     * @param rinst CP reorg instruction to convert
     * @return the equivalent federated reorg instruction
     */
    public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction rinst) {
        final Operator operator = rinst.getOperator();
        final String opcode = rinst.getOpcode();
        final String instruction = rinst.getInstructionString();
        return new ReorgFEDInstruction(
            operator,
            rinst.input1,
            rinst.output,
            opcode,
            instruction,
            FEDInstruction.FederatedOutput.NONE);
    }

    /**
     * Builds a federated reorg instruction from an existing Spark reorg instruction.
     * Operator, operands, opcode and instruction string are taken over as-is;
     * the federated-output flag is set to {@code NONE}.
     *
     * @param rinst Spark reorg instruction to convert
     * @return the equivalent federated reorg instruction
     */
    public static ReorgFEDInstruction parseInstruction(ReorgSPInstruction rinst) {
        final Operator operator = rinst.getOperator();
        final String opcode = rinst.getOpcode();
        final String instruction = rinst.getInstructionString();
        return new ReorgFEDInstruction(
            operator,
            rinst.input1,
            rinst.output,
            opcode,
            instruction,
            FEDInstruction.FederatedOutput.NONE);
    }

    /**
     * Parses a federated reorg instruction from its string form.
     * <p>
     * Supported opcodes (matched case-insensitively) are {@code r'}, {@code rdiag}
     * and {@code rev}. For {@code r'} the layout depends on whether the string
     * starts with the SPARK execution mode name: if so, field 3 holds the
     * federated-output flag and {@code k} is 0; otherwise field 3 holds the
     * integer {@code k} and field 4 the federated-output flag. For {@code rdiag}
     * and {@code rev} the operands are parsed via {@code parseUnaryInstruction}
     * and the flag via {@code parseFedOutFlag(str, 3)}.
     *
     * @param str instruction string
     * @return the parsed federated reorg instruction
     * @throws DMLRuntimeException if the opcode is not one of the supported ones
     */
    public static ReorgFEDInstruction parseInstruction(String str) {
        final CPOperand input = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        final CPOperand result = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
        final String opcode = fields[0];

        final ReorgOperator operator;
        final FEDInstruction.FederatedOutput fedOut;

        if (opcode.equalsIgnoreCase("r'")) {
            InstructionUtils.checkNumFields(str, 2, 3, 4);
            input.split(fields[1]);
            result.split(fields[2]);
            final int kArg;
            if (str.startsWith(Types.ExecMode.SPARK.name())) {
                // Spark layout carries no k field; the flag sits directly after the output.
                kArg = 0;
                fedOut = FEDInstruction.FederatedOutput.valueOf(fields[3]);
            } else {
                kArg = Integer.parseInt(fields[3]);
                fedOut = FEDInstruction.FederatedOutput.valueOf(fields[4]);
            }
            operator = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), kArg);
        } else {
            final boolean diag = opcode.equalsIgnoreCase("rdiag");
            if (!diag && !opcode.equalsIgnoreCase("rev")) {
                throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + opcode);
            }
            parseUnaryInstruction(str, input, result);
            fedOut = parseFedOutFlag(str, 3);
            operator = diag
                ? new ReorgOperator(DiagIndex.getDiagIndexFnObject())
                : new ReorgOperator(RevIndex.getRevIndexFnObject());
        }
        return new ReorgFEDInstruction(operator, input, result, opcode, str, fedOut);
    }

    /**
     * Executes the reorg operation ({@code r'}, {@code rev} or {@code rdiag}) on a
     * federated matrix input and registers the result under the output operand.
     * <p>
     * Opcodes are matched case-insensitively, consistent with
     * {@link #parseInstruction(String)}. An opcode that is none of the three
     * supported ones is rejected with an exception instead of silently leaving
     * the output unset.
     *
     * @param ec execution context holding the input and output matrix objects
     * @throws DMLRuntimeException if the input is not federated, its federation
     *         type is neither ROW nor COL, a non-transpose operation is invoked
     *         on a PART-federated input, or the opcode is not supported
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(this.input1);
        ReorgOperator r_op = (ReorgOperator)this._optr;
        boolean isSpark = this.instString.startsWith("SPARK");
        if (!mo1.isFederated()) {
            throw new DMLRuntimeException("Federated Reorg: Federated input expected, but invoked w/ " + mo1.isFederated());
        }
        if (!mo1.isFederated(FTypes.FType.COL) && !mo1.isFederated(FTypes.FType.ROW)) {
            throw new DMLRuntimeException("Federation type " + mo1.getFedMapping().getType() + " is not supported for Reorg processing");
        }
        if (this.instOpcode.equalsIgnoreCase("r'")) {
            this.transposeFederated(ec, mo1, isSpark);
            return;
        }
        if (mo1.isFederated(FTypes.FType.PART)) {
            throw new DMLRuntimeException("Operation with opcode " + this.instOpcode + " is not supported with PART input");
        }
        if (this.instOpcode.equalsIgnoreCase("rev")) {
            this.reverseFederated(ec, mo1, isSpark);
        } else if (this.instOpcode.equalsIgnoreCase("rdiag")) {
            this.diagFederated(ec, mo1, r_op);
        } else {
            // Previously an unknown opcode fell through without producing any output.
            throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + this.instOpcode);
        }
    }

    /** Runs the transpose instruction at the federated sites and sets either a federated or a locally bound output. */
    private void transposeFederated(ExecutionContext ec, MatrixObject mo1, boolean isSpark) {
        long id = FederationUtils.getNextFedDataID();
        FederatedRequest fr = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new Object[]{new MatrixCharacteristics(-1L, -1L), mo1.getDataType()});
        FederatedRequest fr1 = FederationUtils.callInstruction(this.instString, this.output, id, new CPOperand[]{this.input1}, new long[]{mo1.getFedMapping().getID()}, isSpark ? Types.ExecType.SPARK : Types.ExecType.CP, true);
        Future<FederatedResponse>[] ffr = mo1.getFedMapping().execute(this.getTID(), true, fr, fr1);
        if (this._fedOut != null && !this._fedOut.isForcedLocal()) {
            // Keep the result federated: swap the dimensions and transpose the copied mapping.
            MatrixObject out = ec.getMatrixObject(this.output);
            long nnz = mo1.getNnz() != -1L ? mo1.getNnz() : FederationUtils.sumNonZeros(ffr);
            out.getDataCharacteristics().setDimension(mo1.getNumColumns(), mo1.getNumRows()).setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
            out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
        } else {
            // NOTE(review): fr1 was already executed above and is sent again here together
            // with the GET_VAR request; confirm whether the second execution is required.
            FederatedRequest getRequest = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
            Future<FederatedResponse>[] execResponse = mo1.getFedMapping().execute(this.getTID(), true, fr1, getRequest);
            // The second argument of bind is true for a row-federated input.
            ec.setMatrixOutput(this.output.getName(), FederationUtils.bind(execResponse, mo1.isFederated(FTypes.FType.ROW)));
        }
    }

    /** Runs the rev instruction at the federated sites and derives the federated output from the input mapping. */
    private void reverseFederated(ExecutionContext ec, MatrixObject mo1, boolean isSpark) {
        long id = FederationUtils.getNextFedDataID();
        FederatedRequest fr = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new Object[]{new MatrixCharacteristics(-1L, -1L), mo1.getDataType()});
        FederatedRequest fr1 = FederationUtils.callInstruction(this.instString, this.output, id, new CPOperand[]{this.input1}, new long[]{mo1.getFedMapping().getID()}, isSpark ? Types.ExecType.SPARK : Types.ExecType.CP, true);
        Future<FederatedResponse>[] ffr = mo1.getFedMapping().execute(this.getTID(), true, fr, fr1);
        if (mo1.isFederated(FTypes.FType.ROW)) {
            // Called on the input's own mapping, which is copied for the output below.
            mo1.getFedMapping().reverseFedMap();
        }
        MatrixObject out = ec.getMatrixObject(this.output);
        long nnz = mo1.getNnz() != -1L ? mo1.getNnz() : FederationUtils.sumNonZeros(ffr);
        out.getDataCharacteristics().setDimension(mo1.getNumRows(), mo1.getNumColumns()).setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
        out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
        this.optionalForceLocal(out);
    }

    /** Runs the rdiag UDF at the federated sites and sets the federated output with the updated ranges. */
    private void diagFederated(ExecutionContext ec, MatrixObject mo1, ReorgOperator r_op) {
        // A single-column input with more than one row takes the rdiagV2M path; everything else rdiagM2V.
        RdiagResult result = mo1.getNumColumns() == 1L && mo1.getNumRows() != 1L ? this.rdiagV2M(mo1, r_op) : this.rdiagM2V(mo1, r_op);
        FederationMap diagFedMap = this.updateFedRanges(result);
        MatrixObject rdiag = ec.getMatrixObject(this.output);
        rdiag.getDataCharacteristics().set(diagFedMap.getMaxIndexInRange(0), diagFedMap.getMaxIndexInRange(1), mo1.getBlocksize());
        rdiag.setFedMapping(diagFedMap);
        this.optionalForceLocal(rdiag);
    }

    /**
     * Rewrites the federated ranges of an rdiag result so that they reflect the
     * dimensions reported back by the federated sites.
     * <p>
     * For every range and each of the two dimensions, the begin index becomes 0
     * when this is the first range or the current begin is already 0, otherwise
     * the end index of the preceding range. The end index is the new begin plus
     * the size reported for that range. Ranges are modified in place.
     *
     * @param result federation map plus per-range output dimensions
     * @return the federation map contained in {@code result}, with updated ranges
     */
    private FederationMap updateFedRanges(RdiagResult result) {
        final FederationMap map = result.getFedMap();
        final Map<FederatedRange, int[]> reportedDims = result.getDcs();
        final FederatedRange[] ranges = map.getFederatedRanges();

        for (int idx = 0; idx < ranges.length; idx++) {
            final FederatedRange current = ranges[idx];
            // Look up the sizes before mutating the range used as the lookup key.
            final int[] size = reportedDims.get(current);
            for (int dim = 0; dim < 2; dim++) {
                final long begin;
                if (idx == 0 || current.getBeginDims()[dim] == 0L) {
                    begin = 0L;
                } else {
                    begin = ranges[idx - 1].getEndDims()[dim];
                }
                current.setBeginDim(dim, begin);
                current.setEndDim(dim, current.getBeginDims()[dim] + (long) size[dim]);
            }
        }
        return map;
    }

    /**
     * If the federated-output flag is set and forces a local result, reads the
     * output matrix object once via {@code acquireReadAndRelease()} and then
     * calls {@code cleanup} on its federation map for that map's ID.
     * Does nothing when the flag is absent or does not force local output.
     *
     * @param outputMatrixObject output of the instruction
     */
    private void optionalForceLocal(MatrixObject outputMatrixObject) {
        if (this._fedOut == null || !this._fedOut.isForcedLocal()) {
            return;
        }
        outputMatrixObject.acquireReadAndRelease();
        final FederationMap mapping = outputMatrixObject.getFedMapping();
        mapping.cleanup(this.getTID(), mapping.getID());
    }

    /**
     * Executes the {@link DiagMatrix} UDF on every federated partition of the input
     * (the path {@code processInstruction} takes for a single-column input with
     * more than one row).
     * <p>
     * Each partition receives the begin/end index of its range along the
     * federated dimension (rows if row-federated, columns otherwise) and the
     * total number of input rows. The dimensions returned by each site are
     * collected per range.
     *
     * @param mo1  federated input matrix object
     * @param r_op reorg operator forwarded to the UDF
     * @return the new federation map and the dimensions reported per range
     * @throws DMLRuntimeException wrapping any failure of a federated call
     */
    private RdiagResult rdiagV2M(MatrixObject mo1, ReorgOperator r_op) {
        final FederationMap inputMap = mo1.getFedMapping();
        final boolean rowPartitioned = mo1.isFederated(FTypes.FType.ROW);
        final long outputVarID = FederationUtils.getNextFedDataID();
        final Map<FederatedRange, int[]> dimsPerRange = new HashMap<>();

        final FederationMap outputMap = inputMap.mapParallel(outputVarID, (range, data) -> {
            try {
                final long inputVarID = data.getVarID();
                final int fedDim = rowPartitioned ? 0 : 1;
                final int[] bounds = {range.getBeginDimsInt()[fedDim], range.getEndDimsInt()[fedDim]};
                final FederatedUDF udf =
                    new DiagMatrix(inputVarID, outputVarID, r_op, bounds, rowPartitioned, (int) mo1.getNumRows());
                final FederatedRequest request =
                    new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new Object[] {udf});
                final FederatedResponse response =
                    data.executeFederatedOperation(new FederatedRequest[] {request}).get();
                if (!response.isSuccessful()) {
                    response.throwExceptionFromResponse();
                }
                final int[] dims = (int[]) response.getData()[0];
                // The map is a plain HashMap; writes from the mapping function are guarded by its monitor.
                synchronized (dimsPerRange) {
                    dimsPerRange.put(range, dims);
                }
                return null;
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        });
        return new RdiagResult(outputMap, dimsPerRange);
    }

    /**
     * Executes the {@link Rdiag} UDF on every federated partition of the input
     * (the path {@code processInstruction} takes whenever the input is not a
     * single-column matrix with more than one row).
     * <p>
     * Each partition receives the begin/end index of its range along the
     * federated dimension (rows if row-federated, columns otherwise). The
     * dimensions returned by each site are collected per range.
     *
     * @param mo1  federated input matrix object
     * @param r_op reorg operator forwarded to the UDF
     * @return the new federation map and the dimensions reported per range
     * @throws DMLRuntimeException wrapping any failure of a federated call
     */
    private RdiagResult rdiagM2V(MatrixObject mo1, ReorgOperator r_op) {
        final FederationMap inputMap = mo1.getFedMapping();
        final boolean rowPartitioned = mo1.isFederated(FTypes.FType.ROW);
        final long outputVarID = FederationUtils.getNextFedDataID();
        final Map<FederatedRange, int[]> dimsPerRange = new HashMap<>();

        final FederationMap outputMap = inputMap.mapParallel(outputVarID, (range, data) -> {
            try {
                final long inputVarID = data.getVarID();
                final int fedDim = rowPartitioned ? 0 : 1;
                final int[] bounds = {range.getBeginDimsInt()[fedDim], range.getEndDimsInt()[fedDim]};
                final FederatedUDF udf = new Rdiag(inputVarID, outputVarID, r_op, bounds, rowPartitioned);
                final FederatedRequest request =
                    new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new Object[] {udf});
                final FederatedResponse response =
                    data.executeFederatedOperation(new FederatedRequest[] {request}).get();
                if (!response.isSuccessful()) {
                    response.throwExceptionFromResponse();
                }
                final int[] dims = (int[]) response.getData()[0];
                // The map is a plain HashMap; writes from the mapping function are guarded by its monitor.
                synchronized (dimsPerRange) {
                    dimsPerRange.put(range, dims);
                }
                return null;
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        });
        return new RdiagResult(outputMap, dimsPerRange);
    }

    /**
     * Federated UDF created by {@code rdiagV2M}. It applies the reorg operator to
     * the site-local block, copies the result into a zero-initialised block and
     * stores that block as a new variable named after the output ID.
     */
    public static class DiagMatrix
    extends FederatedUDF {
        // Kept unchanged so already-serialized instances stay compatible.
        private static final long serialVersionUID = -3466926635958851402L;
        private final long _outputID;
        private final ReorgOperator _r_op;
        private final int _len;
        private final int[] _slice;
        private final boolean _rowFed;

        /**
         * @param input    variable ID of the input at the federated site
         * @param outputID variable ID under which the result is stored
         * @param r_op     reorg operator to apply
         * @param slice    two-element array {begin, end} of this site's range along the federated dimension
         * @param rowFed   whether the input is row-federated
         * @param len      length used for the output block (the caller passes the input's total row count)
         */
        private DiagMatrix(long input, long outputID, ReorgOperator r_op, int[] slice, boolean rowFed, int len) {
            super(new long[]{input});
            this._outputID = outputID;
            this._r_op = r_op;
            this._len = len;
            this._rowFed = rowFed;
            this._slice = slice;
        }

        /**
         * Applies the operator to the local block and places the result inside a
         * zero-filled output block at the position given by the slice bounds.
         *
         * @param ec   execution context in which the output variable is registered
         * @param data UDF inputs; {@code data[0]} is the local matrix object
         * @return a SUCCESS response carrying {@code int[]{rows, cols}} of the output block
         */
        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            MatrixBlock res;
            MatrixBlock mb = (MatrixBlock)((MatrixObject)data[0]).acquireReadAndRelease();
            MatrixBlock tmp = mb.reorgOperations(this._r_op, new MatrixBlock(), 0, 0, 0);
            if (this._rowFed) {
                // Output keeps the local row count; the result goes into columns [slice[0], slice[1]).
                res = new MatrixBlock(mb.getNumRows(), this._len, 0.0);
                res.copy(0, res.getNumRows() - 1, this._slice[0], this._slice[1] - 1, tmp, false);
            } else {
                res = new MatrixBlock(this._len, this._slice[1], 0.0);
                res.copy(this._slice[0], this._slice[1] - 1, 0, mb.getNumColumns() - 1, tmp, false);
            }
            MatrixObject mout = ExecutionContext.createMatrixObject(res);
            mout.setDiag(true);
            ec.setVariable(String.valueOf(this._outputID), mout);
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, (Object)new int[]{res.getNumRows(), res.getNumColumns()});
        }

        /** @return a single-element list containing the output variable ID */
        @Override
        public List<Long> getOutputIds() {
            return new ArrayList<Long>(Arrays.asList(this._outputID));
        }

        /**
         * Builds the lineage item of the UDF output from the lineage of the UDF
         * inputs followed by literals for the operator function name, the
         * length, the slice and the row-federation flag.
         *
         * @param ec execution context providing the lineage
         * @return pair of output variable name and its lineage item
         */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            LineageItem[] liUdfInputs = Arrays.stream(this.getInputIDs()).mapToObj(id -> ec.getLineage().get(String.valueOf(id))).toArray(LineageItem[]::new);
            CPOperand r_op = new CPOperand(this._r_op.fn.getClass().getSimpleName(), Types.ValueType.STRING, Types.DataType.SCALAR, true);
            CPOperand len = new CPOperand(String.valueOf(this._len), Types.ValueType.INT32, Types.DataType.SCALAR, true);
            CPOperand slice = new CPOperand(Arrays.toString(this._slice), Types.ValueType.STRING, Types.DataType.SCALAR, true);
            CPOperand rowFed = new CPOperand(String.valueOf(this._rowFed), Types.ValueType.BOOLEAN, Types.DataType.SCALAR, true);
            LineageItem[] otherInputs = LineageItemUtils.getLineage(ec, r_op, len, slice, rowFed);
            LineageItem[] liInputs = Stream.concat(Arrays.stream(liUdfInputs), Arrays.stream(otherInputs)).toArray(LineageItem[]::new);
            // No (Object) casts here: they would make Pair.of infer Pair<Object, Object>,
            // which is not assignable to the declared return type.
            return Pair.of(String.valueOf(this._outputID), new LineageItem(this.getClass().getSimpleName(), liInputs));
        }
    }

    /**
     * Federated UDF created by {@code rdiagM2V}. It slices the site-local block
     * to this site's range, applies the reorg operator to the slice and stores
     * the result as a new variable named after the output ID.
     */
    public static class Rdiag
    extends FederatedUDF {
        // Kept unchanged so already-serialized instances stay compatible.
        private static final long serialVersionUID = -3466926635958851402L;
        private final long _outputID;
        private final ReorgOperator _r_op;
        private final int[] _slice;
        private final boolean _rowFed;

        /**
         * @param input    variable ID of the input at the federated site
         * @param outputID variable ID under which the result is stored
         * @param r_op     reorg operator to apply
         * @param slice    two-element array {begin, end} of this site's range along the federated dimension
         * @param rowFed   whether the input is row-federated
         */
        private Rdiag(long input, long outputID, ReorgOperator r_op, int[] slice, boolean rowFed) {
            super(new long[]{input});
            this._outputID = outputID;
            this._r_op = r_op;
            this._slice = slice;
            this._rowFed = rowFed;
        }

        /**
         * Slices the local block and applies the operator to the slice.
         * When row-federated, all local rows and columns {@code [slice[0], slice[1])}
         * are taken; otherwise the two-argument slice over {@code [slice[0], slice[1])} is used.
         *
         * @param ec   execution context in which the output variable is registered
         * @param data UDF inputs; {@code data[0]} is the local matrix object
         * @return a SUCCESS response carrying {@code int[]{rows, cols}} of the result block
         */
        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            MatrixBlock mb = (MatrixBlock)((MatrixObject)data[0]).acquireReadAndRelease();
            MatrixBlock soresBlock = this._rowFed ? mb.slice(0, mb.getNumRows() - 1, this._slice[0], this._slice[1] - 1, new MatrixBlock()) : mb.slice(this._slice[0], this._slice[1] - 1);
            MatrixBlock res = soresBlock.reorgOperations(this._r_op, new MatrixBlock(), 0, 0, 0);
            MatrixObject mout = ExecutionContext.createMatrixObject(res);
            mout.setDiag(true);
            ec.setVariable(String.valueOf(this._outputID), mout);
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, (Object)new int[]{res.getNumRows(), res.getNumColumns()});
        }

        /** @return a single-element list containing the output variable ID */
        @Override
        public List<Long> getOutputIds() {
            return new ArrayList<Long>(Arrays.asList(this._outputID));
        }

        /**
         * Builds the lineage item of the UDF output from the lineage of the UDF
         * inputs followed by literals for the operator function name, the slice
         * and the row-federation flag.
         *
         * @param ec execution context providing the lineage
         * @return pair of output variable name and its lineage item
         */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            LineageItem[] liUdfInputs = Arrays.stream(this.getInputIDs()).mapToObj(id -> ec.getLineage().get(String.valueOf(id))).toArray(LineageItem[]::new);
            CPOperand r_op = new CPOperand(this._r_op.fn.getClass().getSimpleName(), Types.ValueType.STRING, Types.DataType.SCALAR, true);
            CPOperand slice = new CPOperand(Arrays.toString(this._slice), Types.ValueType.STRING, Types.DataType.SCALAR, true);
            CPOperand rowFed = new CPOperand(String.valueOf(this._rowFed), Types.ValueType.BOOLEAN, Types.DataType.SCALAR, true);
            LineageItem[] otherInputs = LineageItemUtils.getLineage(ec, r_op, slice, rowFed);
            LineageItem[] liInputs = Stream.concat(Arrays.stream(liUdfInputs), Arrays.stream(otherInputs)).toArray(LineageItem[]::new);
            // No (Object) casts here: they would make Pair.of infer Pair<Object, Object>,
            // which is not assignable to the declared return type.
            return Pair.of(String.valueOf(this._outputID), new LineageItem(this.getClass().getSimpleName(), liInputs));
        }
    }

    private class RdiagResult {
        FederationMap fedMap;
        Map<FederatedRange, int[]> dcs;

        public RdiagResult(FederationMap fedMap, Map<FederatedRange, int[]> dcs) {
            this.fedMap = fedMap;
            this.dcs = dcs;
        }

        public FederationMap getFedMap() {
            return this.fedMap;
        }

        public Map<FederatedRange, int[]> getDcs() {
            return this.dcs;
        }
    }
}

