/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.crf.CRFCliqueTree;
import edu.stanford.nlp.ie.crf.CRFLabel;
import edu.stanford.nlp.ie.crf.CRFLogConditionalObjectiveFunction;
import edu.stanford.nlp.ie.crf.CliquePotentialFunction;
import edu.stanford.nlp.ie.crf.HasCliquePotentialFunction;
import edu.stanford.nlp.ie.crf.LinearCliquePotentialFunction;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

/**
 * Training objective for a pool of {@code numLopExpert} linear-chain CRF "experts" (LOP; presumably a
 * logarithmic opinion pool of CRFs -- NOTE(review): confirm terminology).
 *
 * <p>The experts' per-feature, per-clique-label weight tables are combined into a single table as a
 * weighted sum whose mixing coefficients are the softmax of the first {@code numLopExpert} parameters
 * (see {@link #combineAndScaleLopWeights2D}). The function value is the negated conditional
 * log-likelihood of the training labels under that combined model.
 *
 * <p>Parameter vector layout (as produced by {@link #initial()}):
 * <ul>
 *   <li>{@code x[0 .. numLopExpert)}: raw (pre-softmax) expert scales;</li>
 *   <li>only when {@code backpropTraining}: followed by each expert's weights for the features listed in
 *       its {@code featureIndicesListArray} entry, ordered by expert, then feature, then clique label.</li>
 * </ul>
 *
 * <p>{@link #calculate(double[])} reassigns instance fields when {@code backpropTraining} is set, so one
 * instance should not be evaluated concurrently from several threads.
 */
public class CRFLogConditionalObjectiveFunctionForLOP
extends AbstractCachingDiffFunction
implements HasCliquePotentialFunction {
    private static final Redwood.RedwoodChannels log = Redwood.channels(CRFLogConditionalObjectiveFunctionForLOP.class);
    /** {@code labelIndices.get(j)} indexes the possible labelings of a clique spanning {@code j + 1} positions. */
    List<Index<CRFLabel>> labelIndices;
    /** Index of the class (label) strings. */
    Index<String> classIndex;
    /** Empirical counts {@code [expert][feature][cliqueLabel]}; only computed when {@code backpropTraining}. */
    double[][][] Ehat;
    /** Per expert: sum over all training cliques of that expert's weights for the gold clique label. */
    double[] sumOfObservedLogPotential;
    /** {@code [doc][position][clique][expert][label]}: per-clique sum of an expert's weights for each labeling. */
    double[][][][][] sumOfExpectedLogPotential;
    /** Per expert: the features that expert owns (set form, for membership tests). */
    List<Set<Integer>> featureIndicesSetArray;
    /** Per expert: the features that expert owns (list form, fixes the parameter ordering). */
    List<List<Integer>> featureIndicesListArray;
    /** Size of the label window (current label plus {@code window - 1} previous labels). */
    int window;
    int numClasses;
    /** {@code map[f]} selects the label index whose size gives feature {@code f}'s number of weights. */
    int[] map;
    /** Feature ids as {@code [doc][position][clique][k]}. */
    int[][][][] data;
    /** Each expert's weights in flat form. */
    double[][] lopExpertWeights;
    /** Each expert's weights as {@code [expert][feature][cliqueLabel]}. */
    double[][][] lopExpertWeights2D;
    /** Gold class indices as {@code [doc][position]} (possibly with extra leading labels, see below). */
    int[][] labels;
    /** {@code learnedParamsMapping[p] = {expert, feature, cliqueLabel}} for parameter index {@code p}. */
    int[][] learnedParamsMapping;
    int numLopExpert;
    /** Whether the experts' own weights are trained along with the mixing scales. */
    boolean backpropTraining;
    /** Cached result of {@link #domainDimension()}; negative until first computed. */
    int domainDimension = -1;
    String crfType = "maxent";
    String backgroundSymbol;
    public static boolean VERBOSE = false;

    /**
     * Creates the objective and precomputes whatever does not depend on the parameters: the empirical
     * counts when training expert weights, otherwise the (fixed) per-clique expert potentials.
     */
    CRFLogConditionalObjectiveFunctionForLOP(int[][][][] data, int[][] labels, double[][] lopExpertWeights, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String backgroundSymbol, int numLopExpert, List<Set<Integer>> featureIndicesSetArray, List<List<Integer>> featureIndicesListArray, boolean backpropTraining) {
        this.window = window;
        this.classIndex = classIndex;
        this.numClasses = classIndex.size();
        this.labelIndices = labelIndices;
        this.map = map;
        this.data = data;
        this.lopExpertWeights = lopExpertWeights;
        this.labels = labels;
        this.backgroundSymbol = backgroundSymbol;
        this.numLopExpert = numLopExpert;
        this.featureIndicesSetArray = featureIndicesSetArray;
        this.featureIndicesListArray = featureIndicesListArray;
        this.backpropTraining = backpropTraining;
        this.initialize2DWeights();
        if (backpropTraining) {
            this.computeEHat();
        } else {
            // Expert weights are fixed, so the per-clique potentials can be computed once up front.
            this.logPotential(this.lopExpertWeights2D);
        }
    }

    /**
     * Number of parameters: {@code numLopExpert} scales plus, under backprop training, one weight per
     * (expert, feature in that expert's feature list, clique label).
     */
    @Override
    public int domainDimension() {
        if (this.domainDimension < 0) {
            int dim = this.numLopExpert;
            if (this.backpropTraining) {
                for (int i = 0; i < this.numLopExpert; ++i) {
                    double[][] expertWeights2D = this.lopExpertWeights2D[i];
                    for (int fIndex : this.featureIndicesListArray.get(i)) {
                        dim += expertWeights2D[fIndex].length;
                    }
                }
            }
            this.domainDimension = dim;
        }
        return this.domainDimension;
    }

    /**
     * Starting point: every raw scale is 1.0 (uniform after softmax). Under backprop training the
     * experts' current weights follow, in the layout described on the class.
     */
    @Override
    public double[] initial() {
        double[] initial = new double[this.domainDimension()];
        if (!this.backpropTraining) {
            Arrays.fill(initial, 1.0);
            return initial;
        }
        int[][] mapping = this.ensureLearnedParamsMapping();
        Arrays.fill(initial, 0, this.numLopExpert, 1.0);
        for (int index = this.numLopExpert; index < initial.length; ++index) {
            int[] target = mapping[index];
            initial[index] = this.lopExpertWeights2D[target[0]][target[1]][target[2]];
        }
        return initial;
    }

    /**
     * Builds (once) the table mapping each expert-weight parameter index to its
     * {@code {expert, feature, cliqueLabel}} position. Built lazily so that any method decoding a
     * parameter vector works even when {@link #initial()} was never called (e.g. when the optimizer is
     * seeded from a saved vector).
     */
    private int[][] ensureLearnedParamsMapping() {
        if (this.learnedParamsMapping == null) {
            int[][] mapping = new int[this.domainDimension()][3];
            if (this.backpropTraining) {
                int index = this.numLopExpert;
                for (int lopIter = 0; lopIter < this.numLopExpert; ++lopIter) {
                    double[][] expertWeights2D = this.lopExpertWeights2D[lopIter];
                    for (int fIndex : this.featureIndicesListArray.get(lopIter)) {
                        for (int j = 0; j < expertWeights2D[fIndex].length; ++j) {
                            mapping[index++] = new int[]{lopIter, fIndex, j};
                        }
                    }
                }
            }
            this.learnedParamsMapping = mapping;
        }
        return this.learnedParamsMapping;
    }

    /** Returns a zero-filled {@code [expert][feature][cliqueLabel]} table shaped like the expert weights. */
    public double[][][] empty2D() {
        double[][][] d2 = new double[this.numLopExpert][][];
        for (int lopIter = 0; lopIter < this.numLopExpert; ++lopIter) {
            double[][] d = new double[this.map.length][];
            for (int i = 0; i < this.map.length; ++i) {
                d[i] = new double[this.labelIndices.get(this.map[i]).size()];
            }
            d2[lopIter] = d;
        }
        return d2;
    }

    /** Reshapes every expert's flat weight vector into {@link #lopExpertWeights2D}. */
    private void initialize2DWeights() {
        this.lopExpertWeights2D = new double[this.numLopExpert][][];
        for (int lopIter = 0; lopIter < this.numLopExpert; ++lopIter) {
            this.lopExpertWeights2D[lopIter] = this.to2D(this.lopExpertWeights[lopIter], this.labelIndices, this.map);
        }
    }

    /**
     * Splits a flat weight vector into per-feature rows; feature {@code i} takes the next
     * {@code labelIndices.get(map[i]).size()} consecutive weights.
     */
    public double[][] to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map) {
        double[][] newWeights = new double[map.length][];
        int index = 0;
        for (int i = 0; i < map.length; ++i) {
            int size = labelIndices.get(map[i]).size();
            newWeights[i] = new double[size];
            System.arraycopy(weights, index, newWeights[i], 0, size);
            index += size;
        }
        return newWeights;
    }

    /**
     * Returns an initial label history of {@code length} entries filled with the background class or,
     * when {@code docLabels} carries more labels than the document has positions, the first
     * {@code length} of those labels.
     */
    private int[] initialLabelContext(int[] docLabels, int docLength, int length) {
        int[] context = new int[length];
        Arrays.fill(context, this.classIndex.indexOf(this.backgroundSymbol));
        if (docLabels.length > docLength) {
            System.arraycopy(docLabels, 0, context, 0, length);
        }
        return context;
    }

    /** Returns the last {@code docLength} labels when {@code docLabels} is longer, else {@code docLabels} itself. */
    private static int[] trimDocLabels(int[] docLabels, int docLength) {
        if (docLabels.length <= docLength) {
            return docLabels;
        }
        int[] trimmed = new int[docLength];
        System.arraycopy(docLabels, docLabels.length - docLength, trimmed, 0, docLength);
        return trimmed;
    }

    /** Shifts {@code history} left by one and appends {@code label}; a no-op on an empty history. */
    private static void shiftIn(int[] history, int label) {
        if (history.length == 0) {
            return;
        }
        System.arraycopy(history, 1, history, 0, history.length - 1);
        history[history.length - 1] = label;
    }

    /** Index, within {@code labelIndices.get(j)}, of the labeling formed by the last {@code j + 1} window labels. */
    private int cliqueLabelIndex(int[] windowLabels, int j) {
        int[] cliqueLabel = new int[j + 1];
        System.arraycopy(windowLabels, this.window - 1 - j, cliqueLabel, 0, j + 1);
        return this.labelIndices.get(j).indexOf(new CRFLabel(cliqueLabel));
    }

    /**
     * Counts, per expert, how often each of its features fires with the gold clique label.
     *
     * <p>NOTE(review): this method and {@link #logPotential} seed a {@code window}-long history from
     * extra leading labels and shift once before the first position, whereas {@link #calculate} seeds
     * {@code window - 1} labels; the two therefore use history prefixes offset by one label. Confirm
     * which convention matches how callers prepend the extra labels.
     */
    private void computeEHat() {
        this.Ehat = this.empty2D();
        for (int m = 0; m < this.data.length; ++m) {
            int[][][] docData = this.data[m];
            int[] windowLabels = this.initialLabelContext(this.labels[m], docData.length, this.window);
            int[] docLabels = trimDocLabels(this.labels[m], docData.length);
            for (int i = 0; i < docData.length; ++i) {
                shiftIn(windowLabels, docLabels[i]);
                int[][] docDataI = docData[i];
                for (int j = 0; j < docDataI.length; ++j) {
                    int observedLabelIndex = this.cliqueLabelIndex(windowLabels, j);
                    for (int lopIter = 0; lopIter < this.numLopExpert; ++lopIter) {
                        double[][] ehatOfIter = this.Ehat[lopIter];
                        Set<Integer> indicesSet = this.featureIndicesSetArray.get(lopIter);
                        for (int featureIdx : docDataI[j]) {
                            if (indicesSet.contains(featureIdx)) {
                                ehatOfIter[featureIdx][observedLabelIndex] += 1.0;
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Recomputes {@link #sumOfObservedLogPotential} and {@link #sumOfExpectedLogPotential} from the given
     * expert weights, considering for each expert only the features it owns.
     */
    private void logPotential(double[][][] learnedLopExpertWeights2D) {
        this.sumOfExpectedLogPotential = new double[this.data.length][][][][];
        this.sumOfObservedLogPotential = new double[this.numLopExpert];
        for (int m = 0; m < this.data.length; ++m) {
            int[][][] docData = this.data[m];
            int[] windowLabels = this.initialLabelContext(this.labels[m], docData.length, this.window);
            int[] docLabels = trimDocLabels(this.labels[m], docData.length);
            double[][][][] sumOfELPm = new double[docData.length][][][];
            for (int i = 0; i < docData.length; ++i) {
                shiftIn(windowLabels, docLabels[i]);
                int[][] docDataI = docData[i];
                double[][][] sumOfELPmi = new double[docDataI.length][][];
                for (int j = 0; j < docDataI.length; ++j) {
                    int numLabels = this.labelIndices.get(j).size();
                    int observedLabelIndex = this.cliqueLabelIndex(windowLabels, j);
                    double[][] sumOfELPmij = new double[this.numLopExpert][];
                    for (int lopIter = 0; lopIter < this.numLopExpert; ++lopIter) {
                        double[] sumOfELPmijIter = new double[numLabels];
                        double[][] expertWeights = learnedLopExpertWeights2D[lopIter];
                        Set<Integer> indicesSet = this.featureIndicesSetArray.get(lopIter);
                        for (int featureIdx : docDataI[j]) {
                            if (!indicesSet.contains(featureIdx)) {
                                continue;
                            }
                            double[] featureWeights = expertWeights[featureIdx];
                            this.sumOfObservedLogPotential[lopIter] += featureWeights[observedLabelIndex];
                            for (int l = 0; l < numLabels; ++l) {
                                sumOfELPmijIter[l] += featureWeights[l];
                            }
                        }
                        sumOfELPmij[lopIter] = sumOfELPmijIter;
                    }
                    sumOfELPmi[j] = sumOfELPmij;
                }
                sumOfELPm[i] = sumOfELPmi;
            }
            this.sumOfExpectedLogPotential[m] = sumOfELPm;
        }
    }

    /** Returns {@code sum_k lopScales[k] * lopExpertWeights[k][i]} for every weight index {@code i}. */
    public static double[] combineAndScaleLopWeights(int numLopExpert, double[][] lopExpertWeights, double[] lopScales) {
        double[] newWeights = new double[lopExpertWeights[0].length];
        for (int i = 0; i < newWeights.length; ++i) {
            double tempWeight = 0.0;
            for (int lopIter = 0; lopIter < numLopExpert; ++lopIter) {
                tempWeight += lopExpertWeights[lopIter][i] * lopScales[lopIter];
            }
            newWeights[i] = tempWeight;
        }
        return newWeights;
    }

    /** 2D analogue of {@link #combineAndScaleLopWeights}: a scale-weighted sum of the experts' weight tables. */
    public static double[][] combineAndScaleLopWeights2D(int numLopExpert, double[][][] lopExpertWeights2D, double[] lopScales) {
        double[][] newWeights = new double[lopExpertWeights2D[0].length][];
        for (int i = 0; i < newWeights.length; ++i) {
            int innerDim = lopExpertWeights2D[0][i].length;
            double[] innerWeights = new double[innerDim];
            for (int j = 0; j < innerDim; ++j) {
                double tempWeight = 0.0;
                for (int lopIter = 0; lopIter < numLopExpert; ++lopIter) {
                    tempWeight += lopExpertWeights2D[lopIter][i][j] * lopScales[lopIter];
                }
                innerWeights[j] = tempWeight;
            }
            newWeights[i] = innerWeights;
        }
        return newWeights;
    }

    /**
     * Decodes the expert-weight part of a parameter vector into an {@code [expert][feature][label]} table;
     * entries not covered by the parameter vector stay zero.
     */
    public double[][][] separateLopExpertWeights2D(double[] learnedParams) {
        double[][][] learnedWeights2D = this.empty2D();
        int[][] mapping = this.ensureLearnedParamsMapping();
        for (int paramIndex = this.numLopExpert; paramIndex < learnedParams.length; ++paramIndex) {
            int[] target = mapping[paramIndex];
            learnedWeights2D[target[0]][target[1]][target[2]] = learnedParams[paramIndex];
        }
        return learnedWeights2D;
    }

    /** Like {@link #separateLopExpertWeights2D} but flattens each expert's table to its original length. */
    public double[][] separateLopExpertWeights(double[] learnedParams) {
        double[][] learnedWeights = new double[this.numLopExpert][];
        double[][][] learnedWeights2D = this.separateLopExpertWeights2D(learnedParams);
        for (int i = 0; i < this.numLopExpert; ++i) {
            learnedWeights[i] = CRFLogConditionalObjectiveFunction.to1D(learnedWeights2D[i], this.lopExpertWeights[i].length);
        }
        return learnedWeights;
    }

    /** Returns a copy of the raw (pre-softmax) expert scales, i.e. the first {@code numLopExpert} parameters. */
    public double[] separateLopScales(double[] learnedParams) {
        double[] rawScales = new double[this.numLopExpert];
        System.arraycopy(learnedParams, 0, rawScales, 0, this.numLopExpert);
        return rawScales;
    }

    /** Builds a linear clique potential function over the experts' weights combined according to {@code x}. */
    @Override
    public CliquePotentialFunction getCliquePotentialFunction(double[] x) {
        double[] scales = ArrayMath.softmax(this.separateLopScales(x));
        double[][][] learnedLopExpertWeights2D = this.backpropTraining
                ? this.separateLopExpertWeights2D(x)
                : this.lopExpertWeights2D;
        double[][] combinedWeights2D = combineAndScaleLopWeights2D(this.numLopExpert, learnedLopExpertWeights2D, scales);
        return new LinearCliquePotentialFunction(combinedWeights2D);
    }

    /**
     * Sets {@code value} to the negated conditional log-likelihood of the training labels and fills
     * {@code derivative} for the scales and, under backprop training, for the expert weights.
     *
     * @throws IllegalStateException if the log-likelihood evaluates to NaN
     */
    @Override
    public void calculate(double[] x) {
        double prob = 0.0;
        double[][][] E = this.empty2D();
        double[] eScales = new double[this.numLopExpert];
        double[] scales = ArrayMath.softmax(this.separateLopScales(x));
        double[][][] learnedLopExpertWeights2D = this.lopExpertWeights2D;
        if (this.backpropTraining) {
            learnedLopExpertWeights2D = this.separateLopExpertWeights2D(x);
            // The per-clique potentials depend on the expert weights being learned; refresh them for x.
            this.logPotential(learnedLopExpertWeights2D);
        }
        double[][] combinedWeights2D = combineAndScaleLopWeights2D(this.numLopExpert, learnedLopExpertWeights2D, scales);
        for (int m = 0; m < this.data.length; ++m) {
            int[][][] docData = this.data[m];
            double[][][][] sumOfELPm = this.sumOfExpectedLogPotential[m];
            LinearCliquePotentialFunction cliquePotentialFunc = new LinearCliquePotentialFunction(combinedWeights2D);
            CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, this.labelIndices, this.numClasses, this.classIndex, this.backgroundSymbol, cliquePotentialFunc, null);

            // Log-probability of the gold sequence, one position at a time given the previous labels.
            int[] given = this.initialLabelContext(this.labels[m], docData.length, this.window - 1);
            int[] docLabels = trimDocLabels(this.labels[m], docData.length);
            for (int i = 0; i < docData.length; ++i) {
                int label = docLabels[i];
                double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
                if (VERBOSE) {
                    log.info("P(" + label + "|" + ArrayMath.toString(given) + ")=" + p);
                }
                prob += p;
                // Guarded shift: with window == 1 the history is empty and there is nothing to shift.
                shiftIn(given, label);
            }

            // Model expectations needed for the derivative.
            for (int i = 0; i < docData.length; ++i) {
                double[][][] sumOfELPmi = sumOfELPm[i];
                for (int j = 0; j < docData[i].length; ++j) {
                    int[] features = docData[i][j];
                    double[][] sumOfELPmij = sumOfELPmi[j];
                    Index<CRFLabel> labelIndex = this.labelIndices.get(j);
                    for (int l = 0; l < labelIndex.size(); ++l) {
                        double p = cliqueTree.prob(i, labelIndex.get(l).getLabel());
                        // Scale-weighted mixture of the experts' potentials for labeling l; the same for every expert.
                        double mixture = 0.0;
                        for (int k = 0; k < this.numLopExpert; ++k) {
                            mixture += scales[k] * sumOfELPmij[k][l];
                        }
                        for (int lopIter = 0; lopIter < this.numLopExpert; ++lopIter) {
                            double expected = scales[lopIter] * (sumOfELPmij[lopIter][l] - mixture);
                            eScales[lopIter] += p * expected;
                            if (!this.backpropTraining) {
                                continue;
                            }
                            double[][] eOfIter = E[lopIter];
                            Set<Integer> indicesSet = this.featureIndicesSetArray.get(lopIter);
                            for (int featureIdx : features) {
                                if (indicesSet.contains(featureIdx)) {
                                    eOfIter[featureIdx][l] += p;
                                }
                            }
                        }
                    }
                }
            }
        }
        if (Double.isNaN(prob)) {
            throw new IllegalStateException("Got NaN for prob in CRFLogConditionalObjectiveFunctionForLOP.calculate()");
        }
        this.value = -prob;
        if (VERBOSE) {
            log.info("value is " + this.value);
        }

        // Scale derivatives: expected minus observed, each centred on the scale-weighted mixture.
        double observedMixture = 0.0;
        for (int k = 0; k < this.numLopExpert; ++k) {
            observedMixture += scales[k] * this.sumOfObservedLogPotential[k];
        }
        for (int lopIter = 0; lopIter < this.numLopExpert; ++lopIter) {
            double observed = scales[lopIter] * (this.sumOfObservedLogPotential[lopIter] - observedMixture);
            double expected = eScales[lopIter];
            this.derivative[lopIter] = expected - observed;
            if (VERBOSE) {
                log.info("deriv(" + lopIter + ") = " + expected + " - " + observed + " = " + this.derivative[lopIter]);
            }
        }

        // Expert-weight derivatives, in the same order as the parameter layout.
        if (this.backpropTraining) {
            int dIndex = this.numLopExpert;
            for (int lopIter = 0; lopIter < this.numLopExpert; ++lopIter) {
                double scale = scales[lopIter];
                double[][] eOfExpert = E[lopIter];
                double[][] ehatOfExpert = this.Ehat[lopIter];
                for (int fIndex : this.featureIndicesListArray.get(lopIter)) {
                    for (int j = 0; j < eOfExpert[fIndex].length; ++j) {
                        this.derivative[dIndex++] = scale * (eOfExpert[fIndex][j] - ehatOfExpert[fIndex][j]);
                        if (VERBOSE) {
                            log.info("deriv[" + lopIter + "](" + fIndex + "," + j + ") = " + scale + " * (" + eOfExpert[fIndex][j] + " - " + ehatOfExpert[fIndex][j] + ") = " + this.derivative[dIndex - 1]);
                        }
                    }
                }
            }
            assert (dIndex == this.domainDimension());
        }
    }
}

