/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.builder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.builder.GraphData;
import org.apache.flink.ml.builder.GraphExecutionHelper;
import org.apache.flink.ml.builder.GraphModel;
import org.apache.flink.ml.builder.GraphNode;
import org.apache.flink.ml.builder.TableId;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.Preconditions;

@PublicEvolving
public final class Graph
implements Estimator<Graph, GraphModel> {
    private static final long serialVersionUID = 6354253958813529308L;
    private final Map<Param<?>, Object> paramMap = new HashMap();
    private final List<GraphNode> nodes;
    private final TableId[] estimatorInputIds;
    private final TableId[] modelInputIds;
    private final TableId[] outputIds;
    @Nullable
    private final TableId[] inputModelDataIds;
    @Nullable
    private final TableId[] outputModelDataIds;

    public Graph(List<GraphNode> nodes, TableId[] estimatorInputIds, TableId[] modelInputs, TableId[] outputs, TableId[] inputModelDataIds, TableId[] outputModelDataIds) {
        this.nodes = (List)Preconditions.checkNotNull(nodes);
        this.estimatorInputIds = (TableId[])Preconditions.checkNotNull((Object)estimatorInputIds);
        this.modelInputIds = (TableId[])Preconditions.checkNotNull((Object)modelInputs);
        this.outputIds = (TableId[])Preconditions.checkNotNull((Object)outputs);
        this.inputModelDataIds = inputModelDataIds;
        this.outputModelDataIds = outputModelDataIds;
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    @Override
    public GraphModel fit(Table ... inputTables) {
        GraphNode node;
        Preconditions.checkArgument((this.estimatorInputIds.length == inputTables.length ? 1 : 0) != 0, (String)"number of provided tables %s does not match the expected number of tables %s", (Object[])new Object[]{inputTables.length, this.estimatorInputIds.length});
        ArrayList<GraphNode> modelNodes = new ArrayList<GraphNode>();
        GraphExecutionHelper executionHelper = new GraphExecutionHelper(this.nodes);
        executionHelper.setTables(this.estimatorInputIds, inputTables);
        while ((node = executionHelper.pollNextReadyNode()) != null) {
            Stage<Object> stage = node.stage;
            if (node.stageType == GraphNode.StageType.ESTIMATOR) {
                stage = ((Estimator)stage).fit(executionHelper.getTables(node.estimatorInputIds));
            }
            if (node.inputModelDataIds != null) {
                Table[] nodeInputModelData = executionHelper.getTables(node.inputModelDataIds);
                ((Model)stage).setModelData(nodeInputModelData);
            }
            Table[] nodeOutputs = ((AlgoOperator)stage).transform(executionHelper.getTables(node.algoOpInputIds));
            executionHelper.setTables(node.outputIds, nodeOutputs);
            if (node.outputModelDataIds != null) {
                Table[] nodeOutputModelData = ((Model)stage).getModelData();
                executionHelper.setTables(node.outputModelDataIds, nodeOutputModelData);
            }
            modelNodes.add(new GraphNode(node.nodeId, stage, GraphNode.StageType.ALGO_OPERATOR, null, node.algoOpInputIds, node.outputIds, node.inputModelDataIds, node.outputModelDataIds));
        }
        return new GraphModel(modelNodes, this.modelInputIds, this.outputIds, this.inputModelDataIds, this.outputModelDataIds);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    @Override
    public void save(String path) throws IOException {
        GraphData graphData = new GraphData(this.nodes, this.estimatorInputIds, this.modelInputIds, this.outputIds, this.inputModelDataIds, this.outputModelDataIds);
        ReadWriteUtils.saveGraph(this, graphData, path);
    }

    public static Graph load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (Graph)ReadWriteUtils.loadGraph(tEnv, path, Graph.class.getName());
    }
}

