/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.examples.classification;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
import org.apache.flink.ml.examples.util.PeriodicSourceFunction;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

public class OnlineLogisticRegressionExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);
        StreamTableEnvironment tEnv = StreamTableEnvironment.create((StreamExecutionEnvironment)env);
        List<Row> trainData1 = Arrays.asList(Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.1, 2.0}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.2, 2.0}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.3, 2.0}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.4, 2.0}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.5, 2.0}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{11.0, 12.0}), 1.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{12.0, 11.0}), 1.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{13.0, 12.0}), 1.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{14.0, 12.0}), 1.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{15.0, 12.0}), 1.0}));
        List<Row> trainData2 = Arrays.asList(Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.2, 3.0}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.8, 1.0}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.7, 1.0}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.6, 2.0}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.2, 2.0}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{14.0, 17.0}), 1.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{15.0, 10.0}), 1.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{16.0, 16.0}), 1.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{17.0, 10.0}), 1.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{18.0, 13.0}), 1.0}));
        List<Row> predictData = Arrays.asList(Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.8, 2.7}), 0.0}), Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{15.5, 11.2}), 1.0}));
        RowTypeInfo typeInfo = new RowTypeInfo(new TypeInformation[]{DenseVectorTypeInfo.INSTANCE, Types.DOUBLE}, new String[]{"features", "label"});
        PeriodicSourceFunction trainSource = new PeriodicSourceFunction(1000L, Arrays.asList(trainData1, trainData2));
        DataStreamSource trainStream = env.addSource((SourceFunction)trainSource, (TypeInformation)typeInfo);
        Table trainTable = tEnv.fromDataStream((DataStream)trainStream).as("features", new String[0]);
        PeriodicSourceFunction predictSource = new PeriodicSourceFunction(1000L, Collections.singletonList(predictData));
        DataStreamSource predictStream = env.addSource((SourceFunction)predictSource, (TypeInformation)typeInfo);
        Table predictTable = tEnv.fromDataStream((DataStream)predictStream).as("features", new String[0]);
        Row initModelData = Row.of((Object[])new Object[]{Vectors.dense((double[])new double[]{0.41233679404769874, -0.18088118293232122}), 0L});
        Table initModelDataTable = tEnv.fromDataStream((DataStream)env.fromElements((Object[])new Row[]{initModelData}));
        OnlineLogisticRegression olr = ((OnlineLogisticRegression)((OnlineLogisticRegression)((OnlineLogisticRegression)((OnlineLogisticRegression)((OnlineLogisticRegression)((OnlineLogisticRegression)new OnlineLogisticRegression().setFeaturesCol("features")).setLabelCol("label")).setPredictionCol("prediction")).setReg(Double.valueOf(0.2))).setElasticNet(Double.valueOf(0.5))).setGlobalBatchSize(Integer.valueOf(10))).setInitialModelData(initModelDataTable);
        OnlineLogisticRegressionModel onlineModel = olr.fit(new Table[]{trainTable});
        Table outputTable = onlineModel.transform(new Table[]{predictTable})[0];
        CloseableIterator it = outputTable.execute().collect();
        while (it.hasNext()) {
            Row row = (Row)it.next();
            DenseVector features = (DenseVector)row.getField(olr.getFeaturesCol());
            Double expectedResult = (Double)row.getField(olr.getLabelCol());
            Double predictionResult = (Double)row.getField(olr.getPredictionCol());
            DenseVector rawPredictionResult = (DenseVector)row.getField(olr.getRawPredictionCol());
            System.out.printf("Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n", features, expectedResult, predictionResult, rawPredictionResult);
        }
    }
}

