/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.cp;

import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;

/**
 * Control-program (CP) instruction for the aggregate-binary opcode {@code ba+*}: a matrix
 * multiplication of two matrix inputs, where either input may be used in transposed form.
 *
 * <p>If either input is a {@link CompressedMatrixBlock}, the multiplication is delegated to that
 * block's compressed kernel together with the transpose flags. Otherwise, any requested transposes
 * are materialized first and the product is computed on the uncompressed blocks.
 */
public class AggregateBinaryCPInstruction
extends BinaryCPInstruction {
    /** If true, the left input is used in transposed form. */
    public boolean transposeLeft;
    /** If true, the right input is used in transposed form. */
    public boolean transposeRight;

    private AggregateBinaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        // The instruction carried no transpose information, so both inputs are used as-is.
        this(op, in1, in2, out, opcode, istr, false, false);
    }

    private AggregateBinaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, boolean transposeLeft, boolean transposeRight) {
        super(CPInstruction.CPType.AggregateBinary, op, in1, in2, out, opcode, istr);
        this.transposeLeft = transposeLeft;
        this.transposeRight = transposeRight;
    }

    /**
     * Parses a serialized {@code ba+*} instruction.
     *
     * <p>The operand fields after the opcode are {@code in1, in2, out, k}, optionally followed by
     * {@code transposeLeft, transposeRight}. The field count is validated by
     * {@code InstructionUtils.checkNumFields}, which accepts 4 or 6 fields here.
     * {@code k} is parsed as an int and passed to {@code InstructionUtils.getMatMultOperator}.
     * The transpose flags are parsed with {@link Boolean#parseBoolean(String)}, so any value
     * other than a case-insensitive "true" yields false.
     *
     * @param str the serialized instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code ba+*}
     * @throws NumberFormatException if the {@code k} field is not an integer
     */
    public static AggregateBinaryCPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ba+*")) {
            throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        int numFields = InstructionUtils.checkNumFields(parts, 4, 6);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        int k = Integer.parseInt(parts[4]);
        AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(k);
        if (numFields == 6) {
            boolean isLeftTransposed = Boolean.parseBoolean(parts[5]);
            boolean isRightTransposed = Boolean.parseBoolean(parts[6]);
            return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str, isLeftTransposed, isRightTransposed);
        }
        return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str);
    }

    /**
     * Multiplies the two matrix inputs, honoring the transpose flags, and binds the product to the
     * output variable.
     *
     * <p>Each input obtained via {@code getMatrixInput} is handed back via
     * {@code releaseMatrixInput} even if the multiplication throws. Without this, a failed
     * instruction would leave its inputs un-released in the execution context.
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        AggregateBinaryOperator ab_op = (AggregateBinaryOperator)this._optr;
        MatrixBlock ret;
        MatrixBlock matBlock1 = ec.getMatrixInput(this.input1.getName());
        try {
            // Acquire the second input inside the first try, so input1 is still released
            // if acquiring input2 fails.
            MatrixBlock matBlock2 = ec.getMatrixInput(this.input2.getName());
            try {
                ret = this.computeProduct(matBlock1, matBlock2, ab_op);
            }
            finally {
                ec.releaseMatrixInput(this.input2.getName());
            }
        }
        finally {
            ec.releaseMatrixInput(this.input1.getName());
        }
        ec.setMatrixOutput(this.output.getName(), ret);
    }

    /**
     * Computes the product of {@code left} and {@code right}, applying the transpose flags.
     *
     * @param left  the left input block
     * @param right the right input block
     * @param ab_op the matrix-multiplication operator
     * @return a new block holding the product
     */
    private MatrixBlock computeProduct(MatrixBlock left, MatrixBlock right, AggregateBinaryOperator ab_op) {
        if (left instanceof CompressedMatrixBlock || right instanceof CompressedMatrixBlock) {
            // The left input takes precedence when both are compressed. The compressed kernel
            // receives the transpose flags directly instead of pre-transposed inputs.
            CompressedMatrixBlock compressed = left instanceof CompressedMatrixBlock
                ? (CompressedMatrixBlock)left
                : (CompressedMatrixBlock)right;
            return compressed.aggregateBinaryOperations(left, right, new MatrixBlock(), ab_op, this.transposeLeft, this.transposeRight);
        }
        // Uncompressed path: materialize any requested transposes, then multiply.
        MatrixBlock lhs = this.transposeLeft ? transposeBlock(left, ab_op.getNumThreads()) : left;
        MatrixBlock rhs = this.transposeRight ? transposeBlock(right, ab_op.getNumThreads()) : right;
        return lhs.aggregateBinaryOperations(lhs, rhs, new MatrixBlock(), ab_op);
    }

    /**
     * Reorganizes {@code mb} with the swap-index (transpose) function into a new block.
     *
     * @param mb the block to transpose
     * @param k  the thread count handed to the {@link ReorgOperator}
     * @return the transposed block
     */
    private static MatrixBlock transposeBlock(MatrixBlock mb, int k) {
        ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k);
        return mb.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
    }
}

