/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.parfor;

import java.util.Arrays;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.JobConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.ResultMergeMatrix;
import org.apache.sysds.runtime.controlprogram.parfor.ResultMergeRemoteSparkWCompare;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.instructions.spark.functions.CopyMatrixBlockPairFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.io.InputOutputInfo;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.utils.Statistics;

/**
 * Result merge for matrices that builds the merged output as a Spark RDD.
 *
 * <p>The files of all input matrices are read as a single binary-format RDD and
 * combined per block index. If the existing output is known to contain non-zeros
 * it is joined in as the compare matrix; otherwise the inputs are merged (or summed,
 * in accumulation mode) without a compare.
 */
public class ResultMergeRemoteSpark
extends ResultMergeMatrix {
    private static final long serialVersionUID = -6924566953903424820L;
    private final ExecutionContext _ec;
    private final int _numMappers;
    private final int _numReducers;

    /**
     * Creates a Spark-based result merge.
     *
     * @param out output matrix; also used as compare matrix when it has non-zeros
     * @param in input matrices to merge
     * @param outputFilename file name assigned to the merged matrix object
     * @param accum whether the inputs are accumulated (summed) instead of merged
     * @param ec execution context; cast to {@link SparkExecutionContext} when the merge runs
     * @param numMappers parallelism passed on by {@link #executeSerialMerge()}
     * @param numReducers upper bound for the number of partitions of the compare-merge group-by
     */
    public ResultMergeRemoteSpark(MatrixObject out, MatrixObject[] in, String outputFilename, boolean accum, ExecutionContext ec, int numMappers, int numReducers) {
        super(out, in, outputFilename, accum);
        this._ec = ec;
        this._numMappers = numMappers;
        this._numReducers = numReducers;
    }

    /**
     * Delegates to {@link #executeParallelMerge(int)} using the configured number of mappers.
     *
     * @return the merged matrix object
     */
    @Override
    public MatrixObject executeSerialMerge() {
        return this.executeParallelMerge(this._numMappers);
    }

    /**
     * Merges all inputs into a new matrix object backed by an RDD handle.
     *
     * @param par requested degree of parallelism (not read by this implementation)
     * @return a new matrix object carrying the merged RDD, or the unchanged output
     *         object if there are no inputs
     * @throws DMLRuntimeException wrapping any exception raised during the merge
     */
    @Override
    public MatrixObject executeParallelMerge(int par) {
        MatrixObject moNew = null;
        if (LOG.isTraceEnabled()) {
            LOG.trace("ResultMerge (remote, spark): Execute serial merge for output " + this._output.hashCode() + " (fname=" + this._output.getFileName() + ")");
        }
        try {
            if (this._inputs != null && this._inputs.length > 0) {
                MetaDataFormat metadata = (MetaDataFormat)this._output.getMetaData();
                DataCharacteristics mcOld = metadata.getDataCharacteristics();

                // An output without non-zeros cannot contribute values, so the compare is skipped.
                MatrixObject compare = mcOld.getNonZeros() == 0L ? null : this._output;
                RDDObject ro = this.executeMerge(compare, this._inputs, mcOld.getRows(), mcOld.getCols(), mcOld.getBlocksize());

                // Wrap the merged RDD in a fresh matrix object that reuses the old dimensions and format.
                moNew = new MatrixObject(this._output.getValueType(), this._outputFName);
                MatrixCharacteristics mc = new MatrixCharacteristics(mcOld);
                // In accumulation mode the non-zero count is set to -1 instead of being derived from the inputs.
                mc.setNonZeros(this._isAccum ? -1L : this.computeNonZeros(this._output, Arrays.asList(this._inputs)));
                moNew.setMetaData(new MetaDataFormat(mc, metadata.getFileFormat()));
                moNew.setRDDHandle(ro);
            } else {
                // Nothing to merge: hand back the existing output.
                moNew = this._output;
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        return moNew;
    }

    /**
     * Builds the RDD that represents the merge of all inputs.
     *
     * @param compare compare matrix, or {@code null} to merge without compare
     * @param inputs input matrices; must be non-empty
     * @param rlen number of rows of the output
     * @param clen number of columns of the output
     * @param blen block size of the output
     * @return RDD handle of the merged result, with the inputs (and compare) as lineage children
     * @throws DMLRuntimeException if {@code inputs} is null or empty, or wrapping any failure
     *         while constructing the RDD
     */
    @SuppressWarnings("unchecked")
    protected RDDObject executeMerge(MatrixObject compare, MatrixObject[] inputs, long rlen, long clen, int blen) {
        String jobname = "ParFor-RMSP";
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        SparkExecutionContext sec = (SparkExecutionContext)this._ec;
        boolean withCompare = compare != null;
        RDDObject ret = null;

        // Only the compare path below consumes this partition count.
        int numRed = ResultMergeRemoteSpark.determineNumReducers(rlen, clen, blen, this._numReducers);
        if (inputs == null || inputs.length == 0) {
            throw new DMLRuntimeException("Execute merge should never be called with no inputs.");
        }
        try {
            // Step 1: one job configuration whose input paths are the files of all inputs.
            InputOutputInfo ii = InputOutputInfo.get(Types.DataType.MATRIX, Types.FileFormat.BINARY);
            // NOTE(review): JobConf(String) takes a configuration resource, so "test" looks like a
            // placeholder - confirm whether JobConf(ResultMergeRemoteSpark.class) was intended.
            JobConf job = new JobConf("test");
            job.setJobName(jobname);
            job.setInputFormat(ii.inputFormatClass);
            Path[] paths = new Path[inputs.length];
            for (int i = 0; i < paths.length; ++i) {
                // Export first so that the file referenced by the path exists before it is read.
                inputs[i].exportData();
                paths[i] = new Path(inputs[i].getFileName());
                ResultMergeRemoteSpark.setRDDHandleForMerge(inputs[i], sec);
            }
            FileInputFormat.setInputPaths(job, paths);

            // Step 2: read all input files as a single pair RDD.
            JavaPairRDD<MatrixIndexes, MatrixBlock> rdd = sec.getSparkContext().hadoopRDD(job, ii.inputFormatClass, ii.keyClass, ii.valueClass).mapPartitionsToPair(new CopyMatrixBlockPairFunction(true), true);

            // Step 3: merge per key, with or without the compare matrix.
            JavaPairRDD<MatrixIndexes, MatrixBlock> out;
            if (withCompare) {
                // The handle is returned with wildcard type parameters; the cast makes it joinable.
                JavaPairRDD<MatrixIndexes, MatrixBlock> compareRdd = (JavaPairRDD<MatrixIndexes, MatrixBlock>)sec.getRDDHandleForMatrixObject(compare, Types.FileFormat.BINARY);
                ResultMergeRemoteSparkWCompare cfun = new ResultMergeRemoteSparkWCompare(this._isAccum);
                out = rdd.groupByKey(numRed).join(compareRdd).mapToPair(cfun);
            } else {
                out = this._isAccum ? RDDAggregateUtils.sumByKeyStable(rdd, false) : RDDAggregateUtils.mergeByKey(rdd, false);
            }

            // Step 4: register every source as a lineage child of the result handle.
            ret = new RDDObject(out);
            for (MatrixObject input : inputs) {
                ret.addLineageChild(input.getRDDHandle());
            }
            if (withCompare) {
                ret.addLineageChild(compare.getRDDHandle());
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        Statistics.incrementNoOfCompiledSPInst();
        Statistics.incrementNoOfExecutedSPInst();
        if (DMLScript.STATISTICS) {
            Statistics.maintainCPHeavyHitters(jobname, System.nanoTime() - t0);
        }
        return ret;
    }

    /**
     * Determines the number of partitions for the group-by of the compare merge:
     * the requested number of reducers, capped by the number of blocks of the
     * output and never less than one.
     *
     * @param rlen number of rows
     * @param clen number of columns
     * @param blen block size
     * @param numRed requested number of reducers
     * @return partition count in the range {@code [1, max(1, numRed)]}
     */
    private static int determineNumReducers(long rlen, long clen, int blen, long numRed) {
        long rowBlocks = ResultMergeRemoteSpark.numBlocksCeil(rlen, blen);
        long colBlocks = ResultMergeRemoteSpark.numBlocksCeil(clen, blen);
        // Both factors are >= 1; saturate rather than overflow for extreme dimensions.
        long reducerGroups = rowBlocks > Long.MAX_VALUE / colBlocks ? Long.MAX_VALUE : rowBlocks * colBlocks;
        // Clamp to at least one partition so a non-positive request cannot reach the group-by.
        return (int)Math.max(1L, Math.min(numRed, reducerGroups));
    }

    /**
     * Number of blocks of size {@code blen} needed to cover {@code len} cells,
     * counting a partial trailing block. Non-positive lengths or block sizes yield one block.
     */
    private static long numBlocksCeil(long len, int blen) {
        if (len <= 0L || blen <= 0) {
            return 1L;
        }
        return (len - 1L) / (long)blen + 1L;
    }

    /**
     * Attaches an RDD handle to the given matrix object that reads its file in
     * binary format and is flagged as an HDFS file handle.
     *
     * @param mo matrix object whose file is read
     * @param sec Spark execution context providing the Spark context
     */
    private static void setRDDHandleForMerge(MatrixObject mo, SparkExecutionContext sec) {
        InputOutputInfo iinfo = InputOutputInfo.get(Types.DataType.MATRIX, Types.FileFormat.BINARY);
        JavaPairRDD<? extends Writable, ? extends Writable> rdd = sec.getSparkContext().hadoopFile(mo.getFileName(), iinfo.inputFormatClass, iinfo.keyClass, iinfo.valueClass);
        RDDObject rddhandle = new RDDObject(rdd);
        rddhandle.setHDFSFile(true);
        mo.setRDDHandle(rddhandle);
    }
}

