/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark.functions;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.meta.TensorCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

public class ReblockTensorFunction
implements PairFlatMapFunction<Tuple2<TensorIndexes, TensorBlock>, TensorIndexes, TensorBlock> {
    private static final long serialVersionUID = 9118830682358813489L;
    private int _numDims;
    private long _newBlen;

    public ReblockTensorFunction(int numDims, long newBlen) {
        this._numDims = numDims;
        this._newBlen = newBlen;
    }

    public Iterator<Tuple2<TensorIndexes, TensorBlock>> call(Tuple2<TensorIndexes, TensorBlock> arg0) throws Exception {
        TensorIndexes ti = (TensorIndexes)arg0._1();
        TensorBlock tb = (TensorBlock)arg0._2();
        TensorCharacteristics tc = new TensorCharacteristics(tb.getLongDims(), (int)this._newBlen);
        long[] tensorIndexes = new long[this._numDims];
        for (int i = 0; i < tb.getNumDims(); ++i) {
            tensorIndexes[i] = 1L + (ti.getIndex(i) - 1L) * tc.getNumBlocks(i);
        }
        Arrays.fill(tensorIndexes, tb.getNumDims(), tensorIndexes.length, 1L);
        long[] zeroBasedTensorIndexes = new long[tb.getNumDims()];
        Arrays.fill(zeroBasedTensorIndexes, 1L);
        ArrayList<Tuple2> retVal = new ArrayList<Tuple2>();
        long numBlocks = tc.getNumBlocks();
        int[] offsets = new int[tb.getNumDims()];
        int i = 0;
        while ((long)i < numBlocks) {
            TensorBlock outBlock;
            int[] dims = new int[tb.getNumDims()];
            UtilFunctions.computeSliceInfo(tc, zeroBasedTensorIndexes, dims, offsets);
            if (tb.isBasic()) {
                outBlock = new TensorBlock(tb.getValueType(), dims);
            } else {
                Types.ValueType[] schema = new Types.ValueType[dims[1]];
                System.arraycopy(tb.getSchema(), offsets[1], schema, 0, dims[1]);
                outBlock = new TensorBlock(schema, dims);
            }
            tb.slice(offsets, outBlock);
            retVal.add(new Tuple2((Object)new TensorIndexes(tensorIndexes), (Object)outBlock));
            UtilFunctions.computeNextTensorIndexes(tc, tensorIndexes);
            UtilFunctions.computeNextTensorIndexes(tc, zeroBasedTensorIndexes);
            ++i;
        }
        return retVal.iterator();
    }
}

