/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.variancethresholdselector;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel;
import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModelData;
import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorParams;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Preconditions;

/**
 * An {@link Estimator} that fits a {@link VarianceThresholdSelectorModel} keeping only the features
 * whose variance exceeds {@link #getVarianceThreshold()}.
 *
 * <p>For every index of the vectors in the input column, the population variance {@code E[x^2] -
 * E[x]^2} (i.e. normalized by the number of rows, not {@code n - 1}) is computed. Indices whose
 * variance is strictly greater than the threshold are retained; all others are dropped.
 */
public class VarianceThresholdSelector
        implements Estimator<VarianceThresholdSelector, VarianceThresholdSelectorModel>,
                VarianceThresholdSelectorParams<VarianceThresholdSelector> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** Creates a selector with every parameter initialized to its default value. */
    public VarianceThresholdSelector() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Computes per-feature variances over the input column and builds the model.
     *
     * @param inputs exactly one table; its column named by {@link #getInputCol()} is read as
     *     {@link Vector} values
     * @return a model whose data holds the input dimension and the retained feature indices, with
     *     this estimator's parameters copied onto it
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    @Override
    public VarianceThresholdSelectorModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);

        // Copied to a local so the map lambda captures only a String, not this estimator.
        final String inputCol = getInputCol();
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        DataStream<Vector> inputData =
                tEnv.toDataStream(inputs[0])
                        .map(row -> (Vector) row.getField(inputCol), VectorTypeInfo.INSTANCE);

        DataStream<VarianceThresholdSelectorModelData> modelData =
                DataStreamUtils.aggregate(
                        inputData,
                        new VarianceThresholdSelectorAggregator(getVarianceThreshold()),
                        Types.TUPLE(
                                Types.LONG,
                                DenseVectorTypeInfo.INSTANCE,
                                DenseVectorTypeInfo.INSTANCE),
                        TypeInformation.of(VarianceThresholdSelectorModelData.class));

        VarianceThresholdSelectorModel model =
                new VarianceThresholdSelectorModel().setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams(model, getParamMap());
        return model;
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /** Writes this estimator's parameter metadata to {@code path}. */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Restores an estimator from parameter metadata stored at {@code path}.
     *
     * @param tEnv not used by this implementation
     * @param path directory previously passed to {@link #save(String)}
     */
    public static VarianceThresholdSelector load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }

    /**
     * Accumulates per-feature statistics and selects features by variance.
     *
     * <p>Accumulator layout: {@code f0} = number of rows seen, {@code f1} = element-wise sum,
     * {@code f2} = element-wise sum of squares. Both vectors are sized lazily from the first input.
     */
    private static class VarianceThresholdSelectorAggregator
            implements AggregateFunction<
                    Vector,
                    Tuple3<Long, DenseVector, DenseVector>,
                    VarianceThresholdSelectorModelData> {
        private final double varianceThreshold;

        public VarianceThresholdSelectorAggregator(double varianceThreshold) {
            this.varianceThreshold = varianceThreshold;
        }

        @Override
        public Tuple3<Long, DenseVector, DenseVector> createAccumulator() {
            return Tuple3.of(0L, new DenseVector(new double[0]), new DenseVector(new double[0]));
        }

        /**
         * Folds one vector into the accumulator (mutated and returned).
         *
         * @throws IllegalArgumentException if the vector's size differs from previously seen ones
         */
        @Override
        public Tuple3<Long, DenseVector, DenseVector> add(
                Vector vector, Tuple3<Long, DenseVector, DenseVector> numAndSums) {
            if (numAndSums.f0 == 0L) {
                numAndSums.f1 = new DenseVector(vector.size());
                numAndSums.f2 = new DenseVector(vector.size());
            }
            // Guard before indexing f2 below; a mismatch would otherwise surface as an
            // out-of-range index or depend on BLAS's own checks.
            checkSameSize(numAndSums.f1.size(), vector.size());

            numAndSums.f0 = numAndSums.f0 + 1L;
            BLAS.axpy(1.0, vector, numAndSums.f1);
            double[] squareSums = numAndSums.f2.values;
            for (int i = 0; i < vector.size(); i++) {
                double value = vector.get(i);
                squareSums[i] += value * value;
            }
            return numAndSums;
        }

        /**
         * Returns the model data listing indices whose population variance exceeds the threshold.
         *
         * @throws IllegalStateException if no rows were accumulated
         */
        @Override
        public VarianceThresholdSelectorModelData getResult(
                Tuple3<Long, DenseVector, DenseVector> numAndSums) {
            final long numRows = numAndSums.f0;
            final DenseVector sumVector = numAndSums.f1;
            final DenseVector squareSumVector = numAndSums.f2;
            Preconditions.checkState(numRows > 0L, "The training set is empty.");

            final double threshold = varianceThreshold;
            int[] indices =
                    IntStream.range(0, sumVector.size())
                            .filter(
                                    i ->
                                            populationVariance(
                                                            sumVector.values[i],
                                                            squareSumVector.values[i],
                                                            numRows)
                                                    > threshold)
                            .toArray();
            return new VarianceThresholdSelectorModelData(sumVector.size(), indices);
        }

        /**
         * Combines two partial accumulators. An empty side is ignored; otherwise the sums of
         * {@code numAndSums1} are added into {@code acc}, which is mutated and returned.
         *
         * @throws IllegalArgumentException if the two partials were built from different sizes
         */
        @Override
        public Tuple3<Long, DenseVector, DenseVector> merge(
                Tuple3<Long, DenseVector, DenseVector> numAndSums1,
                Tuple3<Long, DenseVector, DenseVector> acc) {
            if (numAndSums1.f0 == 0L) {
                return acc;
            }
            if (acc.f0 == 0L) {
                return numAndSums1;
            }
            checkSameSize(acc.f1.size(), numAndSums1.f1.size());

            acc.f0 = acc.f0 + numAndSums1.f0;
            BLAS.axpy(1.0, numAndSums1.f1, acc.f1);
            BLAS.axpy(1.0, numAndSums1.f2, acc.f2);
            return acc;
        }

        /**
         * Returns {@code squareSum / count - (sum / count)^2}. This one-pass form can suffer
         * floating-point cancellation when a feature's mean is large relative to its spread.
         */
        private static double populationVariance(double sum, double squareSum, long count) {
            double mean = sum / count;
            return squareSum / count - mean * mean;
        }

        /** Throws if {@code actual} differs from the dimension {@code expected} seen so far. */
        private static void checkSameSize(int expected, int actual) {
            if (expected != actual) {
                throw new IllegalArgumentException(
                        String.format(
                                "All input vectors must have the same size, expected %d but got %d.",
                                expected, actual));
            }
        }
    }
}

