/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderDummycode;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderMVImpute;
import org.apache.sysds.runtime.transform.encode.EncoderOmit;
import org.apache.sysds.runtime.transform.encode.LegacyEncoder;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.IndexRange;

public class MultiColumnEncoder
implements Encoder {
    protected static final Log LOG = LogFactory.getLog((String)MultiColumnEncoder.class.getName());
    private static final boolean MULTI_THREADED = true;
    private List<ColumnEncoderComposite> _columnEncoders;
    private EncoderMVImpute _legacyMVImpute = null;
    private EncoderOmit _legacyOmit = null;
    private int _colOffset = 0;
    private FrameBlock _meta = null;
    private int APPLY_BLOCKSIZE = 0;
    public static int BUILD_BLOCKSIZE = 0;

    public void setApplyBlockSize(int blk) {
        this.APPLY_BLOCKSIZE = blk;
    }

    public void setBuildBlockSize(int blk) {
        BUILD_BLOCKSIZE = blk;
    }

    public MultiColumnEncoder(List<ColumnEncoderComposite> columnEncoders) {
        this._columnEncoders = columnEncoders;
    }

    public MultiColumnEncoder() {
        this._columnEncoders = new ArrayList<ColumnEncoderComposite>();
    }

    public MatrixBlock encode(FrameBlock in) {
        return this.encode(in, 1);
    }

    public MatrixBlock encode(FrameBlock in, int k) {
        MatrixBlock out;
        try {
            this.build(in, k);
            if (this._legacyMVImpute != null) {
                this._meta = this.getMetaData(new FrameBlock(in.getNumColumns(), Types.ValueType.STRING));
                this.initMetaData(this._meta);
            }
            out = this.apply(in, k);
        }
        catch (Exception ex) {
            LOG.error((Object)("Failed transform-encode frame with \n" + this));
            throw ex;
        }
        return out;
    }

    @Override
    public void build(FrameBlock in) {
        this.build(in, 1);
    }

    public void build(FrameBlock in, int k) {
        if (k > 1) {
            this.buildMT(in, k);
        } else {
            for (ColumnEncoderComposite columnEncoder : this._columnEncoders) {
                columnEncoder.build(in);
                columnEncoder.updateAllDCEncoders();
            }
        }
        this.legacyBuild(in);
    }

    private void buildMT(FrameBlock in, int k) {
        int blockSize = BUILD_BLOCKSIZE <= 0 ? in.getNumRows() : BUILD_BLOCKSIZE;
        ArrayList<Callable<Integer>> tasks = new ArrayList<Callable<Integer>>();
        ExecutorService pool = CommonThreadPool.get(k);
        try {
            if (blockSize != in.getNumRows()) {
                ArrayList partials = new ArrayList();
                for (ColumnEncoderComposite encoder : this._columnEncoders) {
                    List<Callable<Object>> partialBuildTasks = encoder.getPartialBuildTasks(in, blockSize);
                    if (partialBuildTasks == null) {
                        partials.add(null);
                        continue;
                    }
                    partials.add(partialBuildTasks.stream().map(pool::submit).collect(Collectors.toList()));
                }
                for (int e = 0; e < this._columnEncoders.size(); ++e) {
                    List partial = (List)partials.get(e);
                    if (partial == null) continue;
                    tasks.add(new ColumnMergeBuildPartialTask(this._columnEncoders.get(e), partial));
                }
            } else {
                for (ColumnEncoderComposite e : this._columnEncoders) {
                    tasks.add(new ColumnBuildTask(e, in));
                }
            }
            List rtasks = pool.invokeAll(tasks);
            pool.shutdown();
            for (Future t : rtasks) {
                t.get();
            }
        }
        catch (InterruptedException | ExecutionException e) {
            LOG.error((Object)"MT Column encode failed");
            e.printStackTrace();
        }
    }

    public void legacyBuild(FrameBlock in) {
        if (this._legacyOmit != null) {
            this._legacyOmit.build(in);
        }
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.build(in);
        }
    }

    public MatrixBlock apply(FrameBlock in) {
        return this.apply(in, 1);
    }

    public MatrixBlock apply(FrameBlock in, int k) {
        int numCols = in.getNumColumns() + this.getNumExtraCols();
        long estNNz = (long)in.getNumColumns() * (long)in.getNumRows();
        boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz);
        MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, sparse, estNNz);
        return this.apply(in, out, 0, k);
    }

    @Override
    public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol) {
        return this.apply(in, out, outputCol, 1);
    }

    public MatrixBlock apply(FrameBlock in, MatrixBlock out, int outputCol, int k) {
        int numEncoders = this.getFromAll(ColumnEncoderComposite.class, ColumnEncoder::getColID).size();
        if (in.getNumColumns() != numEncoders) {
            throw new DMLRuntimeException("Not every column in has a CompositeEncoder. Please make sure every column has a encoder or slice the input accordingly");
        }
        out.allocateBlock();
        if (out.isInSparseFormat()) {
            SparseBlock block = out.getSparseBlock();
            if (!(block instanceof SparseBlockMCSR)) {
                throw new RuntimeException("Transform apply currently only supported for MCSR sparse and dense output Matrices");
            }
            for (int r = 0; r < out.getNumRows(); ++r) {
                block.allocate(r, in.getNumColumns());
            }
        }
        if (k > 1) {
            this.applyMT(in, out, outputCol, k);
        } else {
            int offset = outputCol;
            for (ColumnEncoderComposite columnEncoder : this._columnEncoders) {
                columnEncoder.apply(in, out, columnEncoder._colID - 1 + offset);
                if (!columnEncoder.hasEncoder(ColumnEncoderDummycode.class)) continue;
                offset += columnEncoder.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
            }
        }
        out.recomputeNonZeros();
        if (this._legacyOmit != null) {
            out = this._legacyOmit.apply(in, out);
        }
        if (this._legacyMVImpute != null) {
            out = this._legacyMVImpute.apply(in, out);
        }
        return out;
    }

    private void applyMT(FrameBlock in, MatrixBlock out, int outputCol, int k) {
        try {
            ExecutorService pool = CommonThreadPool.get(k);
            ArrayList<ColumnApplyTask> tasks = new ArrayList<ColumnApplyTask>();
            int offset = outputCol;
            int blockSize = this.APPLY_BLOCKSIZE <= 0 ? in.getNumRows() : this.APPLY_BLOCKSIZE;
            for (ColumnEncoderComposite e : this._columnEncoders) {
                for (int i = 0; i < in.getNumRows(); i += blockSize) {
                    tasks.add(new ColumnApplyTask(e, in, out, e._colID - 1 + offset, i, blockSize));
                }
                if (in.getNumRows() % blockSize != 0) {
                    tasks.add(new ColumnApplyTask(e, in, out, e._colID - 1 + offset, in.getNumRows() - in.getNumRows() % blockSize, -1));
                }
                if (!e.hasEncoder(ColumnEncoderDummycode.class)) continue;
                offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
            }
            List rtasks = pool.invokeAll(tasks);
            pool.shutdown();
            for (Future t : rtasks) {
                t.get();
            }
        }
        catch (InterruptedException | ExecutionException e) {
            LOG.error((Object)"MT Column encode failed");
            e.printStackTrace();
        }
    }

    @Override
    public FrameBlock getMetaData(FrameBlock meta) {
        if (this._meta != null) {
            return this._meta;
        }
        for (ColumnEncoder columnEncoder : this._columnEncoders) {
            columnEncoder.getMetaData(meta);
        }
        if (this._legacyOmit != null) {
            this._legacyOmit.getMetaData(meta);
        }
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.getMetaData(meta);
        }
        return meta;
    }

    @Override
    public void initMetaData(FrameBlock meta) {
        for (ColumnEncoder columnEncoder : this._columnEncoders) {
            columnEncoder.initMetaData(meta);
        }
        if (this._legacyOmit != null) {
            this._legacyOmit.initMetaData(meta);
        }
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.initMetaData(meta);
        }
    }

    @Override
    public void prepareBuildPartial() {
        for (Encoder encoder : this._columnEncoders) {
            encoder.prepareBuildPartial();
        }
    }

    @Override
    public void buildPartial(FrameBlock in) {
        for (Encoder encoder : this._columnEncoders) {
            encoder.buildPartial(in);
        }
    }

    public MatrixBlock getColMapping(FrameBlock meta) {
        MatrixBlock out = new MatrixBlock(meta.getNumColumns(), 3, false);
        List<ColumnEncoderDummycode> dc = this.getColumnEncoders(ColumnEncoderDummycode.class);
        int ni = 0;
        for (int i = 0; i < out.getNumRows(); ++i) {
            int colID = i + 1;
            int nColID = ni + 1;
            List encoder = dc.stream().filter(e -> e.getColID() == colID).collect(Collectors.toList());
            assert (encoder.size() <= 1);
            ni = encoder.size() == 1 ? (int)((long)ni + meta.getColumnMetadata(i).getNumDistinct()) : ++ni;
            out.quickSetValue(i, 0, colID);
            out.quickSetValue(i, 1, nColID);
            out.quickSetValue(i, 2, ni);
        }
        return out;
    }

    @Override
    public void updateIndexRanges(long[] beginDims, long[] endDims, int offset) {
        this._columnEncoders.forEach(encoder -> encoder.updateIndexRanges(beginDims, endDims, offset));
        if (this._legacyOmit != null) {
            this._legacyOmit.updateIndexRanges(beginDims, endDims);
        }
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.updateIndexRanges(beginDims, endDims);
        }
    }

    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        out.writeBoolean(this._legacyMVImpute != null);
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.writeExternal(out);
        }
        out.writeBoolean(this._legacyOmit != null);
        if (this._legacyOmit != null) {
            this._legacyOmit.writeExternal(out);
        }
        out.writeInt(this._colOffset);
        out.writeInt(this._columnEncoders.size());
        for (ColumnEncoder columnEncoder : this._columnEncoders) {
            out.writeInt(columnEncoder._colID);
            columnEncoder.writeExternal(out);
        }
        out.writeBoolean(this._meta != null);
        if (this._meta != null) {
            this._meta.write(out);
        }
    }

    @Override
    public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        if (in.readBoolean()) {
            this._legacyMVImpute = new EncoderMVImpute();
            this._legacyMVImpute.readExternal(in);
        }
        if (in.readBoolean()) {
            this._legacyOmit = new EncoderOmit();
            this._legacyOmit.readExternal(in);
        }
        this._colOffset = in.readInt();
        int encodersSize = in.readInt();
        this._columnEncoders = new ArrayList<ColumnEncoderComposite>();
        for (int i = 0; i < encodersSize; ++i) {
            int colID = in.readInt();
            ColumnEncoderComposite columnEncoder = new ColumnEncoderComposite();
            columnEncoder.readExternal(in);
            columnEncoder.setColID(colID);
            this._columnEncoders.add(columnEncoder);
        }
        if (in.readBoolean()) {
            FrameBlock meta = new FrameBlock();
            meta.readFields(in);
            this._meta = meta;
        }
    }

    /*
     * WARNING - void declaration
     */
    public <T extends ColumnEncoder> List<T> getColumnEncoders(Class<T> type) {
        ArrayList<T> ret = new ArrayList<T>();
        for (ColumnEncoder columnEncoder : this._columnEncoders) {
            void var4_4;
            if (columnEncoder.getClass().equals(ColumnEncoderComposite.class) && type != ColumnEncoderComposite.class) {
                T t = ((ColumnEncoderComposite)columnEncoder).getEncoder(type);
            }
            if (var4_4 == null || !var4_4.getClass().equals(type)) continue;
            ret.add(type.cast(var4_4));
        }
        return ret;
    }

    public <T extends ColumnEncoder> T getColumnEncoder(int colID, Class<T> type) {
        for (ColumnEncoder encoder : this.getColumnEncoders(type)) {
            if (encoder._colID != colID) continue;
            return (T)encoder;
        }
        return null;
    }

    public <T extends ColumnEncoder, E> List<E> getFromAll(Class<T> type, Function<? super T, ? extends E> mapper) {
        return this.getColumnEncoders(type).stream().map(mapper).collect(Collectors.toList());
    }

    public <T extends ColumnEncoder> int[] getFromAllIntArray(Class<T> type, Function<? super T, ? extends Integer> mapper) {
        return this.getFromAll(type, mapper).stream().mapToInt(i -> i).toArray();
    }

    public <T extends ColumnEncoder> double[] getFromAllDoubleArray(Class<T> type, Function<? super T, ? extends Double> mapper) {
        return this.getFromAll(type, mapper).stream().mapToDouble(i -> i).toArray();
    }

    public List<ColumnEncoderComposite> getColumnEncoders() {
        return this._columnEncoders;
    }

    public List<ColumnEncoderComposite> getCompositeEncodersForID(int colID) {
        return this._columnEncoders.stream().filter(encoder -> encoder._colID == colID).collect(Collectors.toList());
    }

    public List<Class<? extends ColumnEncoder>> getEncoderTypes(int colID) {
        HashSet set = new HashSet();
        for (ColumnEncoderComposite encoderComp : this._columnEncoders) {
            if (encoderComp._colID != colID && colID != -1) continue;
            for (ColumnEncoder encoder : encoderComp.getEncoders()) {
                set.add(encoder.getClass());
            }
        }
        return new ArrayList<Class<? extends ColumnEncoder>>(set);
    }

    public List<Class<? extends ColumnEncoder>> getEncoderTypes() {
        return this.getEncoderTypes(-1);
    }

    public int getNumExtraCols() {
        List<ColumnEncoderDummycode> dc = this.getColumnEncoders(ColumnEncoderDummycode.class);
        if (dc.isEmpty()) {
            return 0;
        }
        return dc.stream().map(ColumnEncoderDummycode::getDomainSize).mapToInt(i -> i).sum() - dc.size();
    }

    public int getNumExtraCols(IndexRange ixRange) {
        List dc = this.getColumnEncoders(ColumnEncoderDummycode.class).stream().filter(dce -> ixRange.inColRange(dce._colID)).collect(Collectors.toList());
        if (dc.isEmpty()) {
            return 0;
        }
        return dc.stream().map(ColumnEncoderDummycode::getDomainSize).mapToInt(i -> i).sum() - dc.size();
    }

    public <T extends ColumnEncoder> boolean containsEncoderForID(int colID, Class<T> type) {
        return this.getColumnEncoders(type).stream().anyMatch(encoder -> encoder.getColID() == colID);
    }

    public <T extends ColumnEncoder, E> void applyToAll(Class<T> type, Consumer<? super T> function) {
        this.getColumnEncoders(type).forEach(function);
    }

    public <T extends ColumnEncoder, E> void applyToAll(Consumer<? super ColumnEncoderComposite> function) {
        this.getColumnEncoders().forEach(function);
    }

    public MultiColumnEncoder subRangeEncoder(IndexRange ixRange) {
        ArrayList<ColumnEncoderComposite> encoders = new ArrayList<ColumnEncoderComposite>();
        for (long i = ixRange.colStart; i < ixRange.colEnd; ++i) {
            encoders.addAll(this.getCompositeEncodersForID((int)i));
        }
        MultiColumnEncoder subRangeEncoder = new MultiColumnEncoder(encoders);
        subRangeEncoder._colOffset = (int)(-ixRange.colStart) + 1;
        if (this._legacyOmit != null) {
            subRangeEncoder.addReplaceLegacyEncoder(this._legacyOmit.subRangeEncoder(ixRange));
        }
        if (this._legacyMVImpute != null) {
            subRangeEncoder.addReplaceLegacyEncoder(this._legacyMVImpute.subRangeEncoder(ixRange));
        }
        return subRangeEncoder;
    }

    public <T extends ColumnEncoder> MultiColumnEncoder subRangeEncoder(IndexRange ixRange, Class<T> type) {
        ArrayList<T> encoders = new ArrayList<T>();
        for (long i = ixRange.colStart; i < ixRange.colEnd; ++i) {
            encoders.add(this.getColumnEncoder((int)i, type));
        }
        if (type.equals(ColumnEncoderComposite.class)) {
            return new MultiColumnEncoder(encoders.stream().map(e -> (ColumnEncoderComposite)e).collect(Collectors.toList()));
        }
        return new MultiColumnEncoder(encoders.stream().map(ColumnEncoderComposite::new).collect(Collectors.toList()));
    }

    public void mergeReplace(MultiColumnEncoder multiEncoder) {
        for (ColumnEncoderComposite otherEncoder : multiEncoder._columnEncoders) {
            ColumnEncoderComposite encoder = (ColumnEncoderComposite)this.getColumnEncoder(otherEncoder._colID, otherEncoder.getClass());
            if (encoder != null) {
                this._columnEncoders.remove(encoder);
            }
            this._columnEncoders.add(otherEncoder);
        }
    }

    public void mergeAt(Encoder other, int columnOffset, int row) {
        if (other instanceof MultiColumnEncoder) {
            for (ColumnEncoder columnEncoder : ((MultiColumnEncoder)other)._columnEncoders) {
                this.addEncoder(columnEncoder, columnOffset);
            }
            this.legacyMergeAt((MultiColumnEncoder)other, row, columnOffset + 1);
        } else {
            this.addEncoder((ColumnEncoder)other, columnOffset);
        }
    }

    private void legacyMergeAt(MultiColumnEncoder other, int row, int col) {
        if (other._legacyOmit != null) {
            other._legacyOmit.shiftCols(col - 1);
        }
        if (other._legacyOmit != null) {
            if (this._legacyOmit == null) {
                this._legacyOmit = new EncoderOmit();
            }
            this._legacyOmit.mergeAt(other._legacyOmit, row, col);
        }
        if (other._legacyMVImpute != null) {
            other._legacyMVImpute.shiftCols(col - 1);
        }
        if (this._legacyMVImpute != null && other._legacyMVImpute != null) {
            this._legacyMVImpute.mergeAt(other._legacyMVImpute, row, col);
        } else if (this._legacyMVImpute == null) {
            this._legacyMVImpute = other._legacyMVImpute;
        }
    }

    private void addEncoder(ColumnEncoder encoder, int columnOffset) {
        int colId = encoder._colID + columnOffset;
        Object presentEncoder = this.getColumnEncoder(colId, encoder.getClass());
        if (presentEncoder != null) {
            encoder.shiftCol(columnOffset);
            ((ColumnEncoder)presentEncoder).mergeAt(encoder);
        } else {
            ColumnEncoderComposite presentComposite = this.getColumnEncoder(colId, ColumnEncoderComposite.class);
            if (presentComposite != null) {
                encoder.shiftCol(columnOffset);
                presentComposite.mergeAt(encoder);
            } else {
                encoder.shiftCol(columnOffset);
                if (encoder instanceof ColumnEncoderComposite) {
                    this._columnEncoders.add((ColumnEncoderComposite)encoder);
                } else {
                    this._columnEncoders.add(new ColumnEncoderComposite(encoder));
                }
            }
        }
    }

    public <T extends LegacyEncoder> void addReplaceLegacyEncoder(T encoder) {
        if (encoder.getClass() == EncoderMVImpute.class) {
            this._legacyMVImpute = (EncoderMVImpute)encoder;
        } else if (encoder.getClass().equals(EncoderOmit.class)) {
            this._legacyOmit = (EncoderOmit)encoder;
        } else {
            throw new DMLRuntimeException("Tried to add non legacy Encoder");
        }
    }

    public <T extends LegacyEncoder> boolean hasLegacyEncoder(Class<T> type) {
        if (type.equals(EncoderMVImpute.class)) {
            return this._legacyMVImpute != null;
        }
        if (type.equals(EncoderOmit.class)) {
            return this._legacyOmit != null;
        }
        assert (false);
        return false;
    }

    public <T extends LegacyEncoder> T getLegacyEncoder(Class<T> type) {
        if (type.equals(EncoderMVImpute.class)) {
            return (T)((LegacyEncoder)type.cast(this._legacyMVImpute));
        }
        if (type.equals(EncoderOmit.class)) {
            return (T)((LegacyEncoder)type.cast(this._legacyOmit));
        }
        assert (false);
        return null;
    }

    public void applyColumnOffset() {
        this.applyToAll(e -> e.shiftCol(this._colOffset));
        if (this._legacyOmit != null) {
            this._legacyOmit.shiftCols(this._colOffset);
        }
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.shiftCols(this._colOffset);
        }
    }

    private static class ColumnMergeBuildPartialTask
    implements Callable<Integer> {
        private final ColumnEncoderComposite _encoder;
        private final List<Future<Object>> _partials;

        protected ColumnMergeBuildPartialTask(ColumnEncoderComposite encoder, List<Future<Object>> partials) {
            this._encoder = encoder;
            this._partials = partials;
        }

        @Override
        public Integer call() throws Exception {
            this._encoder.mergeBuildPartial(this._partials, 0, this._partials.size());
            this._encoder.updateAllDCEncoders();
            return 1;
        }
    }

    private static class ColumnBuildTask
    implements Callable<Integer> {
        private final ColumnEncoder _encoder;
        private final FrameBlock _input;

        protected ColumnBuildTask(ColumnEncoder encoder, FrameBlock input) {
            this._encoder = encoder;
            this._input = input;
        }

        @Override
        public Integer call() throws Exception {
            this._encoder.build(this._input);
            if (this._encoder instanceof ColumnEncoderComposite) {
                ((ColumnEncoderComposite)this._encoder).updateAllDCEncoders();
            }
            return 1;
        }
    }

    private static class ColumnApplyTask
    implements Callable<Integer> {
        private final ColumnEncoder _encoder;
        private final FrameBlock _input;
        private final MatrixBlock _out;
        private final int _columnOut;
        private int _rowStart = 0;
        private int _blk = -1;

        protected ColumnApplyTask(ColumnEncoder encoder, FrameBlock input, MatrixBlock out, int columnOut) {
            this._encoder = encoder;
            this._input = input;
            this._out = out;
            this._columnOut = columnOut;
        }

        protected ColumnApplyTask(ColumnEncoder encoder, FrameBlock input, MatrixBlock out, int columnOut, int rowStart, int blk) {
            this(encoder, input, out, columnOut);
            this._rowStart = rowStart;
            this._blk = blk;
        }

        @Override
        public Integer call() throws Exception {
            this._encoder.apply(this._input, this._out, this._columnOut, this._rowStart, this._blk);
            return 1;
        }
    }
}

