/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.BasicTensorBlock;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Aggregation kernels over {@link BasicTensorBlock}s.
 * <p>
 * The only aggregation function recognised here is {@link Plus}, which maps to
 * {@code AggType.SUM}; every other function maps to {@code AggType.INVALID}.
 * Sparse inputs are rejected by all kernels in this class.
 */
public class LibTensorAgg {
    /**
     * Decides whether a unary aggregate over {@code in} may be split across threads.
     *
     * @param in the input tensor
     * @param k  the requested degree of parallelism
     * @return {@code true} iff {@code k > 1} and the value type of {@code in} is not BOOLEAN
     */
    public static boolean satisfiesMultiThreadingConstraints(BasicTensorBlock in, int k) {
        // The BOOLEAN branch of sum() counts the non-zeros of the whole dense block and
        // ignores the row range, so row-partitioned tasks would each count everything.
        return k > 1 && in._vt != Types.ValueType.BOOLEAN;
    }

    /**
     * Aggregates {@code in} into {@code out} according to {@code uaop}.
     * <p>
     * An empty input writes the neutral value directly (0.0 for a sum, NaN otherwise).
     * Otherwise the work is either done in the calling thread or, when
     * {@link #satisfiesMultiThreadingConstraints} holds for the operator's thread count,
     * split by ranges of the first dimension, computed in parallel and merged into {@code out}.
     *
     * @param in   the input tensor
     * @param out  the tensor receiving the aggregate
     * @param uaop the unary aggregate operator (supplies the function and the thread count)
     * @throws NotImplementedException if {@code in} is non-empty and sparse
     * @throws DMLRuntimeException     if the parallel execution fails, is interrupted, or any
     *                                 partial task throws
     */
    public static void aggregateUnaryTensor(BasicTensorBlock in, BasicTensorBlock out, AggregateUnaryOperator uaop) {
        AggType aggType = getAggType(uaop);
        if (in.isEmpty(false)) {
            aggregateUnaryTensorEmpty(in, out, aggType);
            return;
        }
        // Reject sparse input before choosing a code path so that the single- and the
        // multi-threaded path fail in the same way.
        if (in.isSparse()) {
            throw new NotImplementedException("Tensor aggregation not supported for sparse tensors.");
        }
        int numThreads = uaop.getNumThreads();
        if (satisfiesMultiThreadingConstraints(in, numThreads)) {
            aggregateUnaryTensorParallel(in, out, aggType, uaop, numThreads);
        } else {
            aggregateUnaryTensorPartial(in, out, aggType, uaop.aggOp.increOp.fn, 0, in.getDim(0));
        }
    }

    /**
     * Runs the aggregation as one {@link PartialAggTask} per block of the first dimension
     * and merges the partial results into {@code out}.
     */
    private static void aggregateUnaryTensorParallel(BasicTensorBlock in, BasicTensorBlock out, AggType aggType, AggregateUnaryOperator uaop, int numThreads) {
        try {
            ExecutorService pool = CommonThreadPool.get(numThreads);
            try {
                ArrayList<Integer> blklens = UtilFunctions.getBalancedBlockSizesDefault(in.getDim(0), numThreads, false);
                List<PartialAggTask> tasks = new ArrayList<PartialAggTask>(blklens.size());
                int lb = 0;
                for (int len : blklens) {
                    tasks.add(new PartialAggTask(in, out, aggType, uaop, lb, lb + len));
                    lb += len;
                }
                // invokeAll() only waits; a task's exception is stored in its future.
                // get() re-throws it so a failed partial aggregate is never merged.
                for (Future<Object> future : pool.invokeAll(tasks)) {
                    future.get();
                }
                // The first partial result initialises out, the rest are folded in.
                out.copy(tasks.get(0).getResult());
                for (int i = 1; i < tasks.size(); ++i) {
                    aggregateFinalResult(uaop.aggOp, out, tasks.get(i).getResult());
                }
            }
            finally {
                // Release the pool on every exit path, including failures.
                pool.shutdown();
            }
        }
        catch (InterruptedException ex) {
            // Restore the interrupt flag for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(ex);
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /**
     * Writes the aggregate of an empty tensor into cell (0, 0) of {@code out}:
     * 0.0 for {@code AggType.SUM}, NaN for any other type.
     */
    private static void aggregateUnaryTensorEmpty(BasicTensorBlock in, BasicTensorBlock out, AggType optype) {
        double val = optype == AggType.SUM ? 0.0 : Double.NaN;
        out.set(new int[]{0, 0}, val);
    }

    /**
     * Adds the first cell of {@code in} onto the first cell of {@code aggVal}.
     * <p>
     * Only the cell at the all-zero index is combined. INT64 and INT32 tensors are
     * added as {@code Long} / {@code Integer} values, all other value types as doubles.
     *
     * @param in     the tensor to add
     * @param aggVal the running aggregate, updated in place
     * @param aop    the aggregate operator; must use {@link Plus} and no correction
     * @throws DMLRuntimeException if the two tensors differ in length, if the operator
     *                             carries a correction, or if its function is not {@link Plus}
     */
    public static void aggregateBinaryTensor(BasicTensorBlock in, BasicTensorBlock aggVal, AggregateOperator aop) {
        if (in.getLength() != aggVal.getLength()) {
            throw new DMLRuntimeException("Binary tensor aggregation requires consistent numbers of cells (" + Arrays.toString(in._dims) + ", " + Arrays.toString(aggVal._dims) + ").");
        }
        if (aop.existsCorrection()) {
            throw new DMLRuntimeException("Corrections not supported for tensors yet");
        }
        if (!(aop.increOp.fn instanceof Plus)) {
            throw new DMLRuntimeException("Binary aggregation of this type not supported for tensors yet");
        }
        // A freshly allocated int[] is all zeros, i.e. the index of the first cell.
        int[] first = new int[in.getNumDims()];
        switch (in.getValueType()) {
            case INT64: {
                aggVal.set(first, (Long)in.get(first) + (Long)aggVal.get(first));
                break;
            }
            case INT32: {
                aggVal.set(first, (Integer)in.get(first) + (Integer)aggVal.get(first));
                break;
            }
            default: {
                aggVal.set(0, 0, in.get(0, 0) + aggVal.get(0, 0));
                break;
            }
        }
    }

    /**
     * Maps the operator's increment function to an {@link AggType}:
     * {@link Plus} becomes SUM, anything else INVALID.
     */
    private static AggType getAggType(AggregateUnaryOperator op) {
        ValueFunction vfn = op.aggOp.increOp.fn;
        if (vfn instanceof Plus) {
            return AggType.SUM;
        }
        return AggType.INVALID;
    }

    /**
     * Tells whether {@code op} can be handled by {@link #aggregateUnaryTensor}.
     *
     * @param op the unary aggregate operator to check
     * @return {@code true} iff the operator does not map to {@code AggType.INVALID}
     */
    public static boolean isSupportedUnaryAggregateOperator(AggregateUnaryOperator op) {
        return getAggType(op) != AggType.INVALID;
    }

    /**
     * Aggregates the rows {@code [rl, ru)} of {@code in} into {@code out}.
     * Only {@code AggType.SUM} is acted upon; any other type leaves {@code out} untouched.
     */
    private static void aggregateUnaryTensorPartial(BasicTensorBlock in, BasicTensorBlock out, AggType aggtype, ValueFunction fn, int rl, int ru) {
        if (aggtype == AggType.SUM) {
            sum(in, out, (Plus)fn, rl, ru);
        }
    }

    /**
     * Folds one partial result into {@code out} via {@code out.incrementalAggregate}.
     *
     * @throws NotImplementedException if the operator carries a correction
     */
    private static void aggregateFinalResult(AggregateOperator aop, BasicTensorBlock out, BasicTensorBlock partout) {
        if (aop.existsCorrection()) {
            throw new NotImplementedException();
        }
        out.incrementalAggregate(aop, partout);
    }

    /**
     * Sums the rows {@code [rl, ru)} of the dense tensor {@code in} and stores the result
     * in the first cell of {@code out}.
     * <p>
     * BOOLEAN tensors are handled by counting the non-zeros of the whole dense block
     * (the row range is not used), floating-point types accumulate in a double through
     * {@code plus}, integer types accumulate in a long.
     *
     * @throws DMLRuntimeException     if {@code in} is sparse or holds strings
     * @throws NotImplementedException for any other value type
     */
    private static void sum(BasicTensorBlock in, BasicTensorBlock out, Plus plus, int rl, int ru) {
        if (in.isSparse()) {
            throw new DMLRuntimeException("Sparse aggregation not implemented for Tensor");
        }
        switch (in.getValueType()) {
            case BOOLEAN: {
                out.set(0, 0, in.getDenseBlock().countNonZeros());
                break;
            }
            case STRING: {
                throw new DMLRuntimeException("Sum over string tensor is not supported.");
            }
            case FP64:
            case FP32: {
                DenseBlock a = in.getDenseBlock();
                // Loop bound is invariant, so read it once.
                final int rowLen = a.getCumODims(0);
                double total = 0.0;
                for (int r = rl; r < ru; ++r) {
                    for (int c = 0; c < rowLen; ++c) {
                        total = plus.execute(total, a.get(r, c));
                    }
                }
                out.set(0, 0, total);
                break;
            }
            case INT64:
            case INT32:
            case UINT4:
            case UINT8: {
                DenseBlock a = in.getDenseBlock();
                final int rowLen = a.getCumODims(0);
                final int[] ix = new int[a.numDims()];
                final int last = ix.length - 1;
                long total = 0L;
                for (int r = rl; r < ru; ++r) {
                    ix[0] = r;
                    // The last index component walks over all cells of row r.
                    for (int c = 0; c < rowLen; ++c) {
                        ix[last] = c;
                        total += a.getLong(ix);
                    }
                }
                out.set(new int[out.getNumDims()], total);
                break;
            }
            case CHARACTER:
            case UNKNOWN:
            default: {
                // Fail loudly instead of silently leaving out unwritten for a
                // value type that has no summation rule here.
                throw new NotImplementedException();
            }
        }
    }

    /**
     * Task computing the aggregate of one row range into its own result block.
     * The result is only available through {@link #getResult()} after {@link #call()}
     * has completed successfully.
     */
    private static class PartialAggTask
    extends AggTask {
        private final BasicTensorBlock _in;
        /** Output tensor of the overall operation; only its value type and dims are read. */
        private final BasicTensorBlock _template;
        private final AggType _aggtype;
        private final AggregateUnaryOperator _uaop;
        private final int _rl;
        private final int _ru;
        /** Partial result, {@code null} until {@link #call()} has succeeded. */
        private BasicTensorBlock _ret;

        protected PartialAggTask(BasicTensorBlock in, BasicTensorBlock ret, AggType aggtype, AggregateUnaryOperator uaop, int rl, int ru) {
            this._in = in;
            this._template = ret;
            this._aggtype = aggtype;
            this._uaop = uaop;
            this._rl = rl;
            this._ru = ru;
        }

        @Override
        public Object call() {
            // Each task aggregates into a private block shaped like the overall output.
            BasicTensorBlock part = new BasicTensorBlock(this._template._vt, new int[]{this._template.getDim(0), this._template.getDim(1)});
            part.allocateDenseBlock();
            LibTensorAgg.aggregateUnaryTensorPartial(this._in, part, this._aggtype, this._uaop.aggOp.increOp.fn, this._rl, this._ru);
            // Publish only after the aggregation succeeded.
            this._ret = part;
            return null;
        }

        /**
         * @return the partial aggregate, or {@code null} if {@link #call()} has not
         *         completed successfully
         */
        public BasicTensorBlock getResult() {
            return this._ret;
        }
    }

    /** Common base type of the aggregation tasks submitted to the thread pool. */
    private static abstract class AggTask
    implements Callable<Object> {
        private AggTask() {
        }
    }

    /** Aggregation kinds known to this library. */
    private static enum AggType {
        SUM,
        INVALID;

    }
}

