/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.api.mlcontext;

import java.util.Set;
import org.apache.log4j.Logger;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.apache.sysds.api.ConfigurableAPI;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.api.mlcontext.MLContextException;
import org.apache.sysds.api.mlcontext.MLContextUtil;
import org.apache.sysds.api.mlcontext.MLResults;
import org.apache.sysds.api.mlcontext.ProjectInfo;
import org.apache.sysds.api.mlcontext.Script;
import org.apache.sysds.api.mlcontext.ScriptExecutor;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.Expression;
import org.apache.sysds.parser.IntIdentifier;
import org.apache.sysds.parser.ParseInfo;
import org.apache.sysds.parser.StringIdentifier;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.utils.Explain;
import org.apache.sysds.utils.MLContextProxy;

/**
 * Entry point for executing SystemDS {@link Script}s from a Spark application.
 *
 * <p>Initializing an instance registers it as the active context (see
 * {@link #getActiveMLContext()}), activates {@link MLContextProxy}, sets the global execution
 * mode, and installs the default SystemDS and compiler configuration. Several setters
 * ({@link #setStatistics(boolean)}, {@link #setStatisticsMaxHeavyHitters(int)},
 * {@link #setExecutionType(ExecutionType)}, {@link #setLineage(boolean)}) also write static
 * fields of {@link DMLScript}, so those settings are shared by every context in the JVM.
 */
public class MLContext
implements ConfigurableAPI {
    protected static Logger log = Logger.getLogger(MLContext.class);
    private SparkSession spark = null;
    /** Most recently executed script; cleared on {@link #close()} and consulted by {@link InternalProxy}. */
    private Script executionScript = null;
    private static MLContext activeMLContext = null;
    /** Set once the welcome banner has been printed, so it is printed at most once per JVM. */
    public static boolean welcomePrint = false;
    private InternalProxy internalProxy = new InternalProxy();
    private boolean explain = false;
    private boolean statistics = false;
    private boolean gpu = false;
    private boolean forceGPU = false;
    private int statisticsMaxHeavyHitters = 10;
    private ExplainLevel explainLevel = null;
    private ExecutionType executionType = ExecutionType.DRIVER_AND_SPARK;
    private boolean maintainSymbolTable = false;
    /** Passed to the next executor; automatically cleared after the first {@link #execute(Script)}. */
    private boolean initBeforeExecution = true;

    /**
     * Returns the context most recently initialized by a constructor.
     *
     * @return the active context, or {@code null} if none was created or it has been closed
     */
    public static MLContext getActiveMLContext() {
        return activeMLContext;
    }

    /**
     * Creates a context bound to the given Spark session.
     *
     * @param spark the Spark session to use
     */
    public MLContext(SparkSession spark) {
        this.initMLContext(spark);
    }

    /**
     * Creates a context from a Spark context, obtaining (or creating) its Spark session.
     *
     * @param sparkContext the Spark context to use
     */
    public MLContext(SparkContext sparkContext) {
        this.initMLContext(SparkSession.builder().sparkContext(sparkContext).getOrCreate());
    }

    /**
     * Creates a context from a Java Spark context, obtaining (or creating) its Spark session.
     *
     * @param javaSparkContext the Java Spark context to use
     */
    public MLContext(JavaSparkContext javaSparkContext) {
        this.initMLContext(SparkSession.builder().sparkContext(javaSparkContext.sc()).getOrCreate());
    }

    /**
     * Shared constructor logic: checks the Spark version, prints the welcome banner once,
     * and installs global SystemDS state for this context.
     */
    private void initMLContext(SparkSession spark) {
        try {
            MLContextUtil.verifySparkVersionSupported(spark);
        }
        catch (MLContextException e) {
            this.warnUnsupportedSparkVersion();
        }
        if (!welcomePrint) {
            System.out.println(MLContextUtil.welcomeMessage());
            welcomePrint = true;
        }
        this.spark = spark;
        DMLScript.setGlobalExecMode(this.executionType.getExecMode());
        activeMLContext = this;
        MLContextProxy.setActive(true);
        MLContextUtil.setDefaultConfig();
        MLContextUtil.setCompilerConfig();
    }

    /**
     * Logs a single recommendation for the minimum Spark version. Project info is preferred;
     * the pom.xml lookup is only a fallback when project info is unavailable, so that at most
     * one warning is emitted.
     */
    private void warnUnsupportedSparkVersion() {
        ProjectInfo projectInfo = this.info();
        if (projectInfo != null) {
            log.warn("Apache Spark " + projectInfo.minimumRecommendedSparkVersion() + " or above is recommended for SystemDS " + projectInfo.version());
            return;
        }
        try {
            String minSparkVersion = MLContextUtil.getMinimumRecommendedSparkVersionFromPom();
            log.warn("Apache Spark " + minSparkVersion + " or above is recommended for this version of SystemDS.");
        }
        catch (MLContextException e1) {
            log.error("Minimum recommended Spark version could not be determined from SystemDS jar file manifest or pom.xml");
        }
    }

    /** Restores the default SystemDS configuration. */
    @Override
    public void resetConfig() {
        MLContextUtil.setDefaultConfig();
    }

    /**
     * Sets a single property on the current global {@link DMLConfig}.
     *
     * @param propertyName  configuration property name
     * @param propertyValue value to assign
     * @throws MLContextException if the configuration rejects the property
     */
    @Override
    public void setConfigProperty(String propertyName, String propertyValue) {
        DMLConfig config = ConfigurationManager.getDMLConfig();
        try {
            config.setTextValue(propertyName, propertyValue);
        }
        catch (DMLRuntimeException e) {
            throw new MLContextException(e);
        }
    }

    /**
     * Executes a script with a fresh {@link ScriptExecutor} configured from this context's
     * current settings. After the first call, initialization before execution is disabled.
     *
     * @param script the script to run
     * @return the script's results
     * @throws MLContextException if execution fails
     */
    public MLResults execute(Script script) {
        ScriptExecutor scriptExecutor = new ScriptExecutor();
        scriptExecutor.setExecutionType(this.executionType);
        scriptExecutor.setExplain(this.explain);
        scriptExecutor.setExplainLevel(this.explainLevel);
        scriptExecutor.setGPU(this.gpu);
        scriptExecutor.setForceGPU(this.forceGPU);
        scriptExecutor.setStatistics(this.statistics);
        scriptExecutor.setStatisticsMaxHeavyHitters(this.statisticsMaxHeavyHitters);
        scriptExecutor.setInit(this.initBeforeExecution);
        // Only the first execution should perform initialization.
        this.initBeforeExecution = false;
        scriptExecutor.setMaintainSymbolTable(this.maintainSymbolTable);
        return this.execute(script, scriptExecutor);
    }

    /**
     * Executes a script with a caller-supplied executor. The script is remembered as the
     * current execution script and, if unnamed, is named with the current epoch millis.
     *
     * @param script         the script to run
     * @param scriptExecutor the executor to run it with
     * @return the script's results
     * @throws MLContextException wrapping any runtime exception raised during execution
     */
    public MLResults execute(Script script, ScriptExecutor scriptExecutor) {
        try {
            this.executionScript = script;
            if (script.getName() == null || script.getName().equals("")) {
                script.setName(String.valueOf(System.currentTimeMillis()));
            }
            return scriptExecutor.execute(script);
        }
        catch (RuntimeException e) {
            throw new MLContextException("Exception when executing script", e);
        }
    }

    /** Replaces the remembered execution script (used by {@link #close()} and {@link InternalProxy}). */
    public void setExecutionScript(Script executionScript) {
        this.executionScript = executionScript;
    }

    /**
     * Loads the SystemDS configuration from a file.
     *
     * @param configFilePath path of the configuration file
     */
    public void setConfig(String configFilePath) {
        MLContextUtil.setConfig(configFilePath);
    }

    /** @return the Spark session, or {@code null} after {@link #close()} */
    public SparkSession getSparkSession() {
        return this.spark;
    }

    /** @return whether explain output is requested for subsequent executions */
    public boolean isExplain() {
        return this.explain;
    }

    /** @param explain whether subsequent executions should produce explain output */
    public void setExplain(boolean explain) {
        this.explain = explain;
    }

    /**
     * Enables or disables lineage tracing. This writes the static {@code DMLScript.LINEAGE}
     * flag and therefore affects the whole JVM.
     *
     * @param lineage whether lineage is enabled
     */
    public void setLineage(boolean lineage) {
        DMLScript.LINEAGE = lineage;
    }

    /**
     * Enables lineage and configures the lineage reuse cache (global static state).
     *
     * @param reuse the reuse cache type to configure
     */
    public void setLineage(LineageCacheConfig.ReuseCacheType reuse) {
        DMLScript.LINEAGE_REUSE = reuse;
        this.setLineage(true);
        LineageCacheConfig.setConfig(reuse);
    }

    /** @return whether the symbol table is maintained across executions */
    public boolean isMaintainSymbolTable() {
        return this.maintainSymbolTable;
    }

    /** @param maintainSymbolTable whether executors should maintain the symbol table */
    public void setMaintainSymbolTable(boolean maintainSymbolTable) {
        this.maintainSymbolTable = maintainSymbolTable;
    }

    /** @param explainLevel explain level passed to subsequent executions (may be {@code null}) */
    public void setExplainLevel(ExplainLevel explainLevel) {
        this.explainLevel = explainLevel;
    }

    /**
     * Sets the explain level from its case-insensitive name.
     *
     * @param explainLevel one of {@code none, hops, runtime, recompile_hops, recompile_runtime}
     * @throws MLContextException if the name is {@code null} or not a valid level
     */
    public void setExplainLevel(String explainLevel) {
        if (explainLevel != null) {
            for (ExplainLevel level : ExplainLevel.values()) {
                if (level.toString().equalsIgnoreCase(explainLevel)) {
                    this.setExplainLevel(level);
                    return;
                }
            }
        }
        throw new MLContextException("Failed to parse explain level: " + explainLevel + " (valid types: none, hops, runtime, recompile_hops, recompile_runtime).");
    }

    /** @param enable whether GPU execution is enabled for subsequent executions */
    public void setGPU(boolean enable) {
        this.gpu = enable;
    }

    /** @param enable whether GPU execution is forced for subsequent executions */
    public void setForceGPU(boolean enable) {
        this.forceGPU = enable;
    }

    /** @return whether GPU execution is enabled */
    public boolean isGPU() {
        return this.gpu;
    }

    /** @return whether GPU execution is forced */
    public boolean isForceGPU() {
        return this.forceGPU;
    }

    /** @return the proxy used by SystemDS internals to query this context */
    public InternalProxy getInternalProxy() {
        return this.internalProxy;
    }

    /** @return whether statistics output is enabled */
    public boolean isStatistics() {
        return this.statistics;
    }

    /**
     * Enables or disables statistics; also writes the static {@code DMLScript.STATISTICS}.
     *
     * @param statistics whether statistics are enabled
     */
    public void setStatistics(boolean statistics) {
        DMLScript.STATISTICS = statistics;
        this.statistics = statistics;
    }

    /**
     * Sets the number of heavy hitters reported in statistics; also writes the static
     * {@code DMLScript.STATISTICS_COUNT}.
     *
     * @param maxHeavyHitters maximum number of heavy hitters to report
     */
    public void setStatisticsMaxHeavyHitters(int maxHeavyHitters) {
        DMLScript.STATISTICS_COUNT = maxHeavyHitters;
        this.statisticsMaxHeavyHitters = maxHeavyHitters;
    }

    /**
     * Releases this context: resets the static Spark context, deactivates the proxy, cleans up
     * working directories, clears the last execution script, restores default configuration,
     * and drops the Spark session reference. The local cleanup steps run even if
     * working-directory cleanup fails.
     *
     * @throws MLContextException if working directories cannot be cleaned up
     */
    public void close() {
        SparkExecutionContext.resetSparkContextStatic();
        MLContextProxy.setActive(false);
        activeMLContext = null;
        try {
            DMLScript.cleanupHadoopExecution(ConfigurationManager.getDMLConfig());
        }
        catch (Exception ex) {
            throw new MLContextException("Failed to cleanup working directories.", ex);
        }
        finally {
            // Always finish tearing down local state so a failed cleanup does not leave the
            // context half-closed.
            if (this.executionScript != null) {
                this.executionScript.clearAll();
            }
            this.resetConfig();
            this.spark = null;
        }
    }

    /**
     * Looks up SystemDS project information.
     *
     * @return the project info, or {@code null} (with a logged warning) if it is unavailable
     */
    public ProjectInfo info() {
        try {
            return ProjectInfo.getProjectInfo();
        }
        catch (Exception e) {
            log.warn("Project information not available");
            return null;
        }
    }

    /** @return the SystemDS version, or a placeholder message if project info is unavailable */
    public String version() {
        ProjectInfo projectInfo = this.info();
        return projectInfo == null ? "Version not available" : projectInfo.version();
    }

    /** @return the SystemDS build time, or a placeholder message if project info is unavailable */
    public String buildTime() {
        ProjectInfo projectInfo = this.info();
        return projectInfo == null ? "Build time not available" : projectInfo.buildTime();
    }

    /** @return the configured maximum number of statistics heavy hitters */
    public int getStatisticsMaxHeavyHitters() {
        return this.statisticsMaxHeavyHitters;
    }

    /** @return whether the next {@link #execute(Script)} performs initialization */
    public boolean isInitBeforeExecution() {
        return this.initBeforeExecution;
    }

    /** @param initBeforeExecution whether the next {@link #execute(Script)} performs initialization */
    public void setInitBeforeExecution(boolean initBeforeExecution) {
        this.initBeforeExecution = initBeforeExecution;
    }

    /** @return the execution type used for subsequent executions */
    public ExecutionType getExecutionType() {
        return this.executionType;
    }

    /**
     * Sets the execution type and immediately updates the global execution mode.
     *
     * @param executionType the execution type (must not be {@code null})
     */
    public void setExecutionType(ExecutionType executionType) {
        DMLScript.setGlobalExecMode(executionType.getExecMode());
        this.executionType = executionType;
    }

    /**
     * Hooks used by SystemDS internals to consult this context's current execution script.
     */
    public class InternalProxy {
        /**
         * For a {@code read} expression whose target is a registered script input backed by a
         * matrix, disables the metadata check and injects the matrix's dimensions, nnz, type and
         * (when available) format parameters. Frames and non-read expressions are left untouched.
         *
         * @param source the expression assigned to {@code target}
         * @param target the variable name being assigned
         * @throws MLContextException if the registered input is neither a matrix nor a scalar
         */
        public void setAppropriateVarsForRead(Expression source, String target) {
            boolean isTargetRegistered = this.isRegisteredAsInput(target);
            if (!isTargetRegistered || !(source instanceof DataExpression)) {
                return;
            }
            DataExpression exp = (DataExpression) source;
            if (!exp.isRead()) {
                return;
            }
            exp.setCheckMetadata(false);
            Expression datatypeExp = exp.getVarParam("data_type");
            String datatype = datatypeExp != null ? datatypeExp.toString() : "matrix";
            if ("frame".equalsIgnoreCase(datatype)) {
                return;
            }
            MatrixObject mo = this.getMatrixObject(target);
            if (mo == null) {
                // Registered input is a scalar; nothing to inject.
                return;
            }
            ParseInfo parseInfo = (ParseInfo) source;
            exp.addVarParam("rows", new IntIdentifier(mo.getNumRows(), parseInfo));
            exp.addVarParam("cols", new IntIdentifier(mo.getNumColumns(), parseInfo));
            exp.addVarParam("nnz", new IntIdentifier(mo.getNnz(), parseInfo));
            exp.addVarParam("data_type", new StringIdentifier("matrix", source));
            exp.addVarParam("value_type", new StringIdentifier("double", source));
            if (mo.getMetaData() instanceof MetaDataFormat) {
                MetaDataFormat metaData = (MetaDataFormat) mo.getMetaData();
                exp.addVarParam("format", new StringIdentifier(metaData.getFileFormat().toString(), source));
                if (metaData.getFileFormat() == Types.FileFormat.BINARY) {
                    // The same block size is used for both row and column blocking.
                    exp.addVarParam("rows_in_block", new IntIdentifier(mo.getBlocksize(), parseInfo));
                    exp.addVarParam("cols_in_block", new IntIdentifier(mo.getBlocksize(), parseInfo));
                }
            }
        }

        /** @return whether {@code parameterName} is an input variable of the current execution script */
        private boolean isRegisteredAsInput(String parameterName) {
            Script script = MLContext.this.executionScript;
            if (script == null) {
                return false;
            }
            Set<String> inputVariableNames = script.getInputVariables();
            return inputVariableNames != null && inputVariableNames.contains(parameterName);
        }

        /**
         * Fetches a matrix from the current execution script's symbol table.
         *
         * @return the matrix, or {@code null} if the variable holds a scalar
         * @throws MLContextException if there is no script/symbol table or the variable is
         *         missing or of another type
         */
        private MatrixObject getMatrixObject(String parameterName) {
            Script script = MLContext.this.executionScript;
            LocalVariableMap symbolTable = script != null ? script.getSymbolTable() : null;
            if (symbolTable != null) {
                Data data = symbolTable.get(parameterName);
                if (data instanceof MatrixObject) {
                    return (MatrixObject) data;
                }
                if (data instanceof ScalarObject) {
                    return null;
                }
            }
            throw new MLContextException("getMatrixObject not set for parameter: " + parameterName);
        }
    }

    /** Where scripts are executed. */
    public static enum ExecutionType {
        DRIVER,
        SPARK,
        HADOOP,
        DRIVER_AND_SPARK,
        DRIVER_AND_HADOOP;

        /** @return the SystemDS execution mode corresponding to this execution type */
        public Types.ExecMode getExecMode() {
            switch (this) {
                case DRIVER:
                    return Types.ExecMode.SINGLE_NODE;
                case SPARK:
                    return Types.ExecMode.SPARK;
                default:
                    // HADOOP, DRIVER_AND_SPARK and DRIVER_AND_HADOOP all map to hybrid mode.
                    return Types.ExecMode.HYBRID;
            }
        }
    }

    /** Level of detail for explain output. */
    public static enum ExplainLevel {
        NONE,
        HOPS,
        RUNTIME,
        RECOMPILE_HOPS,
        RECOMPILE_RUNTIME;

        /** @return the {@link Explain.ExplainType} corresponding to this level */
        public Explain.ExplainType getExplainType() {
            switch (this) {
                case NONE:
                    return Explain.ExplainType.NONE;
                case HOPS:
                    return Explain.ExplainType.HOPS;
                case RUNTIME:
                    return Explain.ExplainType.RUNTIME;
                case RECOMPILE_HOPS:
                    return Explain.ExplainType.RECOMPILE_HOPS;
                case RECOMPILE_RUNTIME:
                    return Explain.ExplainType.RECOMPILE_RUNTIME;
                default:
                    return Explain.ExplainType.HOPS;
            }
        }
    }
}

