/*
 * Decompiled with CFR 0.152.
 */
package ciir.umass.edu.learning.neuralnet;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.learning.neuralnet.Layer;
import ciir.umass.edu.learning.neuralnet.Neuron;
import ciir.umass.edu.learning.neuralnet.PropParameter;
import ciir.umass.edu.learning.neuralnet.Synapse;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

public class RankNet
extends Ranker {
    public static int nIteration = 100;
    public static int nHiddenLayer = 1;
    public static int nHiddenNodePerLayer = 10;
    public static double learningRate = 5.0E-5;
    protected List<Layer> layers = new ArrayList<Layer>();
    protected Layer inputLayer = null;
    protected Layer outputLayer = null;
    protected List<List<Double>> bestModelOnValidation = new ArrayList<List<Double>>();
    protected int totalPairs = 0;
    protected int misorderedPairs = 0;
    protected double error = 0.0;
    protected double lastError = Double.MAX_VALUE;
    protected int straightLoss = 0;

    public RankNet() {
    }

    public RankNet(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    protected void setInputOutput(int nInput, int nOutput) {
        this.inputLayer = new Layer(nInput + 1);
        this.outputLayer = new Layer(nOutput);
        this.layers.clear();
        this.layers.add(this.inputLayer);
        this.layers.add(this.outputLayer);
    }

    protected void setInputOutput(int nInput, int nOutput, int nType) {
        this.inputLayer = new Layer(nInput + 1, nType);
        this.outputLayer = new Layer(nOutput, nType);
        this.layers.clear();
        this.layers.add(this.inputLayer);
        this.layers.add(this.outputLayer);
    }

    protected void addHiddenLayer(int size) {
        this.layers.add(this.layers.size() - 1, new Layer(size));
    }

    protected void wire() {
        int j;
        int i;
        for (i = 0; i < this.inputLayer.size() - 1; ++i) {
            for (j = 0; j < this.layers.get(1).size(); ++j) {
                this.connect(0, i, 1, j);
            }
        }
        for (i = 1; i < this.layers.size() - 1; ++i) {
            for (j = 0; j < this.layers.get(i).size(); ++j) {
                for (int k = 0; k < this.layers.get(i + 1).size(); ++k) {
                    this.connect(i, j, i + 1, k);
                }
            }
        }
        for (i = 1; i < this.layers.size(); ++i) {
            for (j = 0; j < this.layers.get(i).size(); ++j) {
                this.connect(0, this.inputLayer.size() - 1, i, j);
            }
        }
    }

    protected void connect(int sourceLayer, int sourceNeuron, int targetLayer, int targetNeuron) {
        new Synapse(this.layers.get(sourceLayer).get(sourceNeuron), this.layers.get(targetLayer).get(targetNeuron));
    }

    protected void addInput(DataPoint p) {
        for (int k = 0; k < this.inputLayer.size() - 1; ++k) {
            this.inputLayer.get(k).addOutput(p.getFeatureValue(this.features[k]));
        }
        this.inputLayer.get(this.inputLayer.size() - 1).addOutput(1.0);
    }

    protected void propagate(int i) {
        for (int k = 1; k < this.layers.size(); ++k) {
            this.layers.get(k).computeOutput(i);
        }
    }

    protected int[][] batchFeedForward(RankList rl) {
        int[][] pairMap = new int[rl.size()][];
        for (int i = 0; i < rl.size(); ++i) {
            this.addInput(rl.get(i));
            this.propagate(i);
            int count = 0;
            for (int j = 0; j < rl.size(); ++j) {
                if (!(rl.get(i).getLabel() > rl.get(j).getLabel())) continue;
                ++count;
            }
            pairMap[i] = new int[count];
            int k = 0;
            for (int j = 0; j < rl.size(); ++j) {
                if (!(rl.get(i).getLabel() > rl.get(j).getLabel())) continue;
                pairMap[i][k++] = j;
            }
        }
        return pairMap;
    }

    protected void batchBackPropagate(int[][] pairMap, float[][] pairWeight) {
        for (int i = 0; i < pairMap.length; ++i) {
            int j;
            PropParameter p = new PropParameter(i, pairMap);
            this.outputLayer.computeDelta(p);
            for (j = this.layers.size() - 2; j >= 1; --j) {
                this.layers.get(j).updateDelta(p);
            }
            this.outputLayer.updateWeight(p);
            for (j = this.layers.size() - 2; j >= 1; --j) {
                this.layers.get(j).updateWeight(p);
            }
        }
    }

    protected void clearNeuronOutputs() {
        for (int k = 0; k < this.layers.size(); ++k) {
            this.layers.get(k).clearOutputs();
        }
    }

    protected float[][] computePairWeight(int[][] pairMap, RankList rl) {
        return null;
    }

    protected RankList internalReorder(RankList rl) {
        return rl;
    }

    protected void saveBestModelOnValidation() {
        for (int i = 0; i < this.layers.size() - 1; ++i) {
            List<Double> l = this.bestModelOnValidation.get(i);
            l.clear();
            for (int j = 0; j < this.layers.get(i).size(); ++j) {
                Neuron n = this.layers.get(i).get(j);
                for (int k = 0; k < n.getOutLinks().size(); ++k) {
                    l.add(n.getOutLinks().get(k).getWeight());
                }
            }
        }
    }

    protected void restoreBestModelOnValidation() {
        try {
            for (int i = 0; i < this.layers.size() - 1; ++i) {
                List<Double> l = this.bestModelOnValidation.get(i);
                int c = 0;
                for (int j = 0; j < this.layers.get(i).size(); ++j) {
                    Neuron n = this.layers.get(i).get(j);
                    for (int k = 0; k < n.getOutLinks().size(); ++k) {
                        n.getOutLinks().get(k).setWeight(l.get(c++));
                    }
                }
            }
        }
        catch (Exception ex) {
            throw RankLibError.create("Error in NeuralNetwork.restoreBestModelOnValidation(): ", ex);
        }
    }

    protected double crossEntropy(double o1, double o2, double targetValue) {
        double oij = o1 - o2;
        double ce = -targetValue * oij + SimpleMath.logBase2(1.0 + Math.exp(oij));
        return ce;
    }

    protected void estimateLoss() {
        this.misorderedPairs = 0;
        this.error = 0.0;
        for (int j = 0; j < this.samples.size(); ++j) {
            RankList rl = (RankList)this.samples.get(j);
            for (int k = 0; k < rl.size() - 1; ++k) {
                double o1 = this.eval(rl.get(k));
                for (int l = k + 1; l < rl.size(); ++l) {
                    if (!(rl.get(k).getLabel() > rl.get(l).getLabel())) continue;
                    double o2 = this.eval(rl.get(l));
                    this.error += this.crossEntropy(o1, o2, 1.0);
                    if (!(o1 < o2)) continue;
                    ++this.misorderedPairs;
                }
            }
        }
        this.lastError = this.error = SimpleMath.round(this.error / (double)this.totalPairs, 4);
    }

    @Override
    public void init() {
        int i;
        this.PRINT("Initializing... ");
        this.setInputOutput(this.features.length, 1);
        for (i = 0; i < nHiddenLayer; ++i) {
            this.addHiddenLayer(nHiddenNodePerLayer);
        }
        this.wire();
        this.totalPairs = 0;
        for (i = 0; i < this.samples.size(); ++i) {
            RankList rl = ((RankList)this.samples.get(i)).getCorrectRanking();
            for (int j = 0; j < rl.size() - 1; ++j) {
                for (int k = j + 1; k < rl.size(); ++k) {
                    if (!(rl.get(j).getLabel() > rl.get(k).getLabel())) continue;
                    ++this.totalPairs;
                }
            }
        }
        if (this.validationSamples != null) {
            for (i = 0; i < this.layers.size(); ++i) {
                this.bestModelOnValidation.add(new ArrayList());
            }
        }
        Neuron.learningRate = learningRate;
        this.PRINTLN("[Done]");
    }

    @Override
    public void learn() {
        this.PRINTLN("-----------------------------------------");
        this.PRINTLN("Training starts...");
        this.PRINTLN("--------------------------------------------------");
        this.PRINTLN(new int[]{7, 14, 9, 9}, new String[]{"#epoch", "% mis-ordered", this.scorer.name() + "-T", this.scorer.name() + "-V"});
        this.PRINTLN(new int[]{7, 14, 9, 9}, new String[]{" ", "  pairs", " ", " "});
        this.PRINTLN("--------------------------------------------------");
        for (int i = 1; i <= nIteration; ++i) {
            for (int j = 0; j < this.samples.size(); ++j) {
                RankList rl = this.internalReorder((RankList)this.samples.get(j));
                int[][] pairMap = this.batchFeedForward(rl);
                float[][] pairWeight = this.computePairWeight(pairMap, rl);
                this.batchBackPropagate(pairMap, pairWeight);
                this.clearNeuronOutputs();
            }
            this.scoreOnTrainingData = this.scorer.score(this.rank(this.samples));
            this.estimateLoss();
            this.PRINT(new int[]{7, 14}, new String[]{i + "", SimpleMath.round((double)this.misorderedPairs / (double)this.totalPairs, 4) + ""});
            if (i % 1 == 0) {
                this.PRINT(new int[]{9}, new String[]{SimpleMath.round(this.scoreOnTrainingData, 4) + ""});
                if (this.validationSamples != null) {
                    double score = this.scorer.score(this.rank(this.validationSamples));
                    if (score > this.bestScoreOnValidationData) {
                        this.bestScoreOnValidationData = score;
                        this.saveBestModelOnValidation();
                    }
                    this.PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
                }
            }
            this.PRINTLN("");
        }
        if (this.validationSamples != null) {
            this.restoreBestModelOnValidation();
        }
        this.scoreOnTrainingData = SimpleMath.round(this.scorer.score(this.rank(this.samples)), 4);
        this.PRINTLN("--------------------------------------------------");
        this.PRINTLN("Finished sucessfully.");
        this.PRINTLN(this.scorer.name() + " on training data: " + this.scoreOnTrainingData);
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(this.rank(this.validationSamples));
            this.PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        this.PRINTLN("---------------------------------");
    }

    @Override
    public double eval(DataPoint p) {
        int k;
        for (k = 0; k < this.inputLayer.size() - 1; ++k) {
            this.inputLayer.get(k).setOutput(p.getFeatureValue(this.features[k]));
        }
        this.inputLayer.get(this.inputLayer.size() - 1).setOutput(1.0);
        for (k = 1; k < this.layers.size(); ++k) {
            this.layers.get(k).computeOutput();
        }
        return this.outputLayer.get(0).getOutput();
    }

    @Override
    public Ranker createNew() {
        return new RankNet();
    }

    @Override
    public String toString() {
        String output = "";
        for (int i = 0; i < this.layers.size() - 1; ++i) {
            for (int j = 0; j < this.layers.get(i).size(); ++j) {
                output = output + i + " " + j + " ";
                Neuron n = this.layers.get(i).get(j);
                for (int k = 0; k < n.getOutLinks().size(); ++k) {
                    output = output + n.getOutLinks().get(k).getWeight() + (k == n.getOutLinks().size() - 1 ? "" : " ");
                }
                output = output + "\n";
            }
        }
        return output;
    }

    @Override
    public String model() {
        int i;
        String output = "## " + this.name() + "\n";
        output = output + "## Epochs = " + nIteration + "\n";
        output = output + "## No. of features = " + this.features.length + "\n";
        output = output + "## No. of hidden layers = " + (this.layers.size() - 2) + "\n";
        for (i = 1; i < this.layers.size() - 1; ++i) {
            output = output + "## Layer " + i + ": " + this.layers.get(i).size() + " neurons\n";
        }
        for (i = 0; i < this.features.length; ++i) {
            output = output + this.features[i] + (i == this.features.length - 1 ? "" : " ");
        }
        output = output + "\n";
        output = output + (this.layers.size() - 2) + "\n";
        for (i = 1; i < this.layers.size() - 1; ++i) {
            output = output + this.layers.get(i).size() + "\n";
        }
        output = output + this.toString();
        return output;
    }

    @Override
    public void loadFromString(String fullText) {
        try {
            int i;
            String content = "";
            BufferedReader in = new BufferedReader(new StringReader(fullText));
            ArrayList<String> l = new ArrayList<String>();
            while ((content = in.readLine()) != null) {
                if ((content = content.trim()).length() == 0 || content.indexOf("##") == 0) continue;
                l.add(content);
            }
            in.close();
            String[] tmp = ((String)l.get(0)).split(" ");
            this.features = new int[tmp.length];
            for (int i2 = 0; i2 < tmp.length; ++i2) {
                this.features[i2] = Integer.parseInt(tmp[i2]);
            }
            int nHiddenLayer = Integer.parseInt((String)l.get(1));
            int[] nn = new int[nHiddenLayer];
            for (i = 2; i < 2 + nHiddenLayer; ++i) {
                nn[i - 2] = Integer.parseInt((String)l.get(i));
            }
            this.setInputOutput(this.features.length, 1);
            for (int j = 0; j < nHiddenLayer; ++j) {
                this.addHiddenLayer(nn[j]);
            }
            this.wire();
            while (i < l.size()) {
                String[] s = ((String)l.get(i)).split(" ");
                int iLayer = Integer.parseInt(s[0]);
                int iNeuron = Integer.parseInt(s[1]);
                Neuron n = this.layers.get(iLayer).get(iNeuron);
                for (int k = 0; k < n.getOutLinks().size(); ++k) {
                    n.getOutLinks().get(k).setWeight(Double.parseDouble(s[k + 2]));
                }
                ++i;
            }
        }
        catch (Exception ex) {
            throw RankLibError.create("Error in RankNet::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        this.PRINTLN("No. of epochs: " + nIteration);
        this.PRINTLN("No. of hidden layers: " + nHiddenLayer);
        this.PRINTLN("No. of hidden nodes per layer: " + nHiddenNodePerLayer);
        this.PRINTLN("Learning rate: " + learningRate);
    }

    @Override
    public String name() {
        return "RankNet";
    }

    protected void printNetworkConfig() {
        for (int i = 1; i < this.layers.size(); ++i) {
            System.out.println("Layer-" + (i + 1));
            for (int j = 0; j < this.layers.get(i).size(); ++j) {
                Neuron n = this.layers.get(i).get(j);
                System.out.print("Neuron-" + (j + 1) + ": " + n.getInLinks().size() + " inputs\t");
                for (int k = 0; k < n.getInLinks().size(); ++k) {
                    System.out.print(n.getInLinks().get(k).getWeight() + "\t");
                }
                System.out.println("");
            }
        }
    }

    protected void printWeightVector() {
        for (int j = 0; j < this.outputLayer.get(0).getInLinks().size(); ++j) {
            System.out.print(this.outputLayer.get(0).getInLinks().get(j).getWeight() + " ");
        }
        System.out.println("");
    }
}

