/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.cost;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.cost.FederatedCost;
import org.apache.sysds.hops.cost.FederatedCostEstimator;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;

/**
 * Memo-table entry that pairs a {@link Hop} with one federated output option
 * ({@link FEDInstruction.FederatedOutput}), the {@link FederatedCost} estimated for that option,
 * and the {@code HopRel} selected for each of the hop's inputs.
 *
 * <p>The memo table passed around as {@code hopRelMemo} maps a hop ID to the candidate
 * {@code HopRel}s for that hop.
 */
public class HopRel {
    /** The hop described by this entry. */
    protected final Hop hopRef;
    /** The output type (e.g. LOUT or FOUT) assumed for {@link #hopRef} in this entry. */
    protected final FEDInstruction.FederatedOutput fedOut;
    /** Cost computed by {@code FederatedCostEstimator.costEstimate} when this entry was built. */
    protected final FederatedCost cost;
    /**
     * Hop IDs registered through {@link #addCostPointer(long)}.
     * NOTE(review): presumably the IDs of hops whose cost already accounts for this entry's
     * cost; confirm against FederatedCostEstimator.
     */
    protected final Set<Long> costPointerSet = new HashSet<>();
    /** One selected input {@code HopRel} per input of {@link #hopRef}, in input order. */
    protected final List<HopRel> inputDependency = new ArrayList<>();

    /**
     * Creates a memo entry for {@code associatedHop} with output type {@code fedOut}.
     * The input dependencies are selected from {@code hopRelMemo} first; the cost is then
     * estimated by {@code FederatedCostEstimator} with those dependencies already in place.
     *
     * <p>Note that {@code this} is handed to the estimator before {@link #cost} is assigned,
     * so calling {@link #getCost()} on it during estimation would fail.
     *
     * @param associatedHop hop described by this entry
     * @param fedOut        output type assumed for the hop
     * @param hopRelMemo    memo table mapping hop IDs to their candidate {@code HopRel}s;
     *                      must contain a non-empty entry for every input of {@code associatedHop}
     * @throws DMLException if an input has no usable candidate in {@code hopRelMemo}
     */
    public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, Map<Long, List<HopRel>> hopRelMemo) {
        this.hopRef = associatedHop;
        this.fedOut = fedOut;
        setInputDependency(hopRelMemo);
        this.cost = new FederatedCostEstimator().costEstimate(this, hopRelMemo);
    }

    /**
     * Registers {@code hopID} in {@link #costPointerSet}.
     *
     * @param hopID hop ID to record
     */
    public void addCostPointer(long hopID) {
        costPointerSet.add(hopID);
    }

    /**
     * Returns whether a cost pointer other than {@code currentHopID} has been registered.
     *
     * @param currentHopID hop ID to ignore when checking
     * @return {@code true} if {@link #costPointerSet} contains at least one ID different from
     *         {@code currentHopID}
     */
    public boolean existingCostPointer(long currentHopID) {
        if (costPointerSet.contains(currentHopID)) {
            // The current hop itself does not count; another pointer must exist besides it.
            return costPointerSet.size() > 1;
        }
        return !costPointerSet.isEmpty();
    }

    /** @return {@code true} if this entry assumes local output (LOUT) */
    public boolean hasLocalOutput() {
        return fedOut == FEDInstruction.FederatedOutput.LOUT;
    }

    /** @return {@code true} if this entry assumes federated output (FOUT) */
    public boolean hasFederatedOutput() {
        return fedOut == FEDInstruction.FederatedOutput.FOUT;
    }

    /** @return the output type assumed for the hop in this entry */
    public FEDInstruction.FederatedOutput getFederatedOutput() {
        return fedOut;
    }

    /**
     * Returns the selected input {@code HopRel}s, one per input of the hop, in input order.
     * This is the internal list itself (not a copy), so modifications affect this entry.
     *
     * @return the live input-dependency list
     */
    public List<HopRel> getInputDependency() {
        return inputDependency;
    }

    /** @return the hop described by this entry */
    public Hop getHopRef() {
        return hopRef;
    }

    /**
     * Looks up the memo candidates for {@code hop}.
     *
     * @param hop        hop whose candidates are requested
     * @param hopRelMemo memo table keyed by hop ID
     * @return the candidate list stored for the hop
     * @throws DMLException if {@code hopRelMemo} has no entry for the hop's ID
     */
    private static List<HopRel> getMemoEntries(Hop hop, Map<Long, List<HopRel>> hopRelMemo) {
        List<HopRel> entries = hopRelMemo.get(hop.getHopID());
        if (entries == null) {
            throw new DMLException("No entry in Memo Table found for hop with ID " + hop.getHopID());
        }
        return entries;
    }

    /**
     * Returns the first FOUT candidate stored for {@code hop}.
     *
     * @return the first candidate with FOUT output, or {@code null} if there is none
     * @throws DMLException if {@code hopRelMemo} has no entry for the hop
     */
    private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>> hopRelMemo) {
        return getMemoEntries(hop, hopRelMemo).stream()
            .filter(candidate -> candidate.fedOut == FEDInstruction.FederatedOutput.FOUT)
            .findFirst()
            .orElse(null);
    }

    /**
     * Returns the candidate for {@code input} with the lowest total cost.
     *
     * @throws DMLException if {@code hopRelMemo} has no entry, or an empty entry, for the input
     */
    private HopRel getMinOfInput(Map<Long, List<HopRel>> hopRelMemo, Hop input) {
        return getMemoEntries(input, hopRelMemo).stream()
            .min(Comparator.comparingDouble(candidate -> candidate.cost.getTotal()))
            .orElseThrow(() -> new DMLException("No element in Memo Table found for input"));
    }

    /**
     * Fills {@link #inputDependency} with one {@code HopRel} per input of {@link #hopRef}.
     *
     * <p>If this entry has FOUT output and {@link #hopRef} is not a federated data op, exactly
     * one input is bound to its cheapest FOUT candidate (ties keep the earlier input), and every
     * other input to its overall cheapest candidate. Otherwise every input is bound to its
     * overall cheapest candidate.
     *
     * @throws DMLException if a memo entry is missing or empty, or if no input offers a FOUT
     *                      candidate when one is required
     */
    private void setInputDependency(Map<Long, List<HopRel>> hopRelMemo) {
        List<Hop> inputs = hopRef.getInput();
        if (inputs != null && !inputs.isEmpty()) {
            // Compare the field directly: this runs from the constructor, so avoid calling the
            // overridable hasFederatedOutput().
            if (fedOut == FEDInstruction.FederatedOutput.FOUT && !hopRef.isFederatedDataOp()) {
                // Pick the input whose FOUT candidate is cheapest. Until a first candidate is
                // found the index keeps advancing, so if no input has one, the last slot is left
                // null and validateInputDependency() rejects it.
                int foutIndex = 0;
                HopRel cheapestFout = getFOUTHopRel(inputs.get(0), hopRelMemo);
                for (int i = 1; i < inputs.size(); ++i) {
                    HopRel candidate = getFOUTHopRel(inputs.get(i), hopRelMemo);
                    if (cheapestFout == null
                        || (candidate != null && candidate.getCost() < cheapestFout.getCost())) {
                        cheapestFout = candidate;
                        foutIndex = i;
                    }
                }
                for (int i = 0; i < inputs.size(); ++i) {
                    inputDependency.add(i == foutIndex
                        ? cheapestFout
                        : getMinOfInput(hopRelMemo, inputs.get(i)));
                }
            } else {
                for (Hop input : inputs) {
                    inputDependency.add(getMinOfInput(hopRelMemo, input));
                }
            }
        }
        validateInputDependency();
    }

    /**
     * Checks that every selected input dependency is non-null.
     *
     * @throws DMLException naming the first input position whose dependency is {@code null}
     */
    private void validateInputDependency() {
        for (int i = 0; i < inputDependency.size(); ++i) {
            if (inputDependency.get(i) == null) {
                throw new DMLException("HopRel input number " + i + " (" + hopRef.getInput(i) + ") is null for root: \n" + this);
            }
        }
    }

    /** @return the total of the estimated cost */
    public double getCost() {
        return cost.getTotal();
    }

    /** @return the full cost object estimated for this entry */
    public FederatedCost getCostObject() {
        return cost;
    }

    @Override
    public String toString() {
        StringBuilder strB = new StringBuilder();
        strB.append(getClass().getSimpleName());
        strB.append(" {HopID: ");
        strB.append(hopRef.getHopID());
        strB.append(", Opcode: ");
        strB.append(hopRef.getOpString());
        strB.append(", FedOut: ");
        strB.append(fedOut);
        strB.append(", Cost: ");
        strB.append(cost);
        strB.append(", Number of inputs: ");
        strB.append(inputDependency.size());
        strB.append("}");
        return strB.toString();
    }
}

