/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;

/**
 * A square table of classification counts indexed by label.
 *
 * <p>Cell {@code [r][c]} holds how many instances whose correct label has index {@code r}
 * were classified with the label of index {@code c}. When built from a label collection, an
 * extra (last) row/column is reserved for the default label.
 *
 * <p>Instances are mutable and not synchronized.
 */
public class ConfusionMatrix {
    /** Maps each label to its row/column index; iteration follows insertion order. */
    private final Map<String, Integer> labelMap = Maps.newLinkedHashMap();
    /** Counts indexed as {@code [correctIndex][classifiedIndex]}; always square. */
    private final int[][] confusionMatrix;
    private String defaultLabel = "unknown";

    /**
     * Creates an all-zero matrix for {@code labels} plus one extra slot for the default label.
     *
     * @param labels known labels; each receives the next index in iteration order
     * @param defaultLabel label bound to the last index ({@code labels.size()})
     */
    public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
        int size = labels.size() + 1;
        this.confusionMatrix = new int[size][size];
        this.defaultLabel = defaultLabel;
        int i = 0;
        for (String label : labels) {
            this.labelMap.put(label, i++);
        }
        // If defaultLabel also occurs in labels, this re-binds it to the last index.
        this.labelMap.put(defaultLabel, i);
    }

    /**
     * Creates a matrix whose counts (and, if present, label bindings) are copied from {@code m}.
     *
     * @param m square matrix of counts; values are rounded to the nearest int
     * @throws IllegalArgumentException if {@code m} is not square or its label bindings do not
     *     assign exactly one label to every row
     */
    public ConfusionMatrix(Matrix m) {
        this.confusionMatrix = new int[m.numRows()][m.numRows()];
        this.setMatrix(m);
    }

    /**
     * Returns the backing count array itself (not a copy); writes to it modify this matrix.
     *
     * @return the live {@code [correct][classified]} count array
     */
    public int[][] getConfusionMatrix() {
        return this.confusionMatrix;
    }

    /**
     * Returns the bound labels in index-insertion order.
     *
     * @return an unmodifiable view of the label set
     */
    public Collection<String> getLabels() {
        return Collections.unmodifiableCollection(this.labelMap.keySet());
    }

    /**
     * Returns the percentage of instances with correct label {@code label} that were classified
     * as {@code label}. Yields NaN when that label has no instances.
     *
     * @param label a bound label
     * @return accuracy for the label in percent
     * @throws IllegalArgumentException if {@code label} is not bound
     */
    public double getAccuracy(String label) {
        int labelId = labelIndex(label);
        return 100.0 * this.confusionMatrix[labelId][labelId] / rowTotal(labelId);
    }

    /**
     * Returns the percentage of all counted instances that lie on the diagonal. Yields NaN when
     * the matrix is empty.
     *
     * @return overall accuracy in percent
     */
    public double getAccuracy() {
        int total = 0;
        int correct = 0;
        for (int i = 0; i < this.confusionMatrix.length; ++i) {
            correct += this.confusionMatrix[i][i];
            total += rowTotal(i);
        }
        return 100.0 * correct / total;
    }

    /**
     * Returns the mean of {@link #getAccuracy(String)} over every label except the default
     * label, or 0.0 when there is no such label.
     *
     * @return mean per-label accuracy in percent
     */
    public double getReliability() {
        int count = 0;
        double accuracy = 0.0;
        for (String label : this.labelMap.keySet()) {
            if (label.equals(this.defaultLabel)) {
                continue;
            }
            accuracy += this.getAccuracy(label);
            // Only labels that contribute to the sum may contribute to the denominator.
            ++count;
        }
        return count == 0 ? 0.0 : accuracy / count;
    }

    /**
     * Computes Cohen's kappa: {@code (n*a - b) / (n*n - b)} where {@code n} is the total count,
     * {@code a} the diagonal sum and {@code b} the sum over i of row-total(i) * column-total(i).
     * Yields NaN for an empty matrix.
     *
     * @return the kappa statistic
     */
    public double getKappa() {
        int size = this.confusionMatrix.length;
        double agreed = 0.0;
        double chance = 0.0;
        // Derived from the cells so it is correct regardless of how the counts were filled in
        // (addInstance, putCount, merge or setMatrix).
        double total = 0.0;
        for (int i = 0; i < size; ++i) {
            agreed += this.confusionMatrix[i][i];
            double rowSum = 0.0;
            double colSum = 0.0;
            for (int j = 0; j < size; ++j) {
                rowSum += this.confusionMatrix[i][j];
                colSum += this.confusionMatrix[j][i];
            }
            total += rowSum;
            chance += rowSum * colSum;
        }
        // All arithmetic in double: an int n*n overflows for modest sample counts.
        return (total * agreed - chance) / (total * total - chance);
    }

    /**
     * Accumulates, for every row, the fraction of that row's count that lies on the diagonal.
     * A tiny epsilon in the denominator keeps empty rows from producing NaN.
     *
     * @return running mean/standard deviation over the per-row fractions
     */
    public RunningAverageAndStdDev getNormalizedStats() {
        FullRunningAverageAndStdDev summer = new FullRunningAverageAndStdDev();
        for (int d = 0; d < this.confusionMatrix.length; ++d) {
            double total = rowTotal(d);
            summer.addDatum(this.confusionMatrix[d][d] / (total + 1.0E-6));
        }
        return summer;
    }

    /**
     * Returns the number of instances of {@code label} classified as {@code label}.
     *
     * @throws IllegalArgumentException if {@code label} is not bound
     */
    public int getCorrect(String label) {
        int labelId = labelIndex(label);
        return this.confusionMatrix[labelId][labelId];
    }

    /**
     * Returns the number of instances whose correct label is {@code label} (its row total).
     *
     * @throws IllegalArgumentException if {@code label} is not bound
     */
    public int getTotal(String label) {
        return rowTotal(labelIndex(label));
    }

    /**
     * Records one instance with the given correct label and the label from {@code classifiedResult}.
     *
     * @throws IllegalArgumentException if either label is not bound
     */
    public void addInstance(String correctLabel, ClassifierResult classifiedResult) {
        this.incrementCount(correctLabel, classifiedResult.getLabel());
    }

    /**
     * Records one instance with the given correct and classified labels.
     *
     * @throws IllegalArgumentException if either label is not bound
     */
    public void addInstance(String correctLabel, String classifiedLabel) {
        this.incrementCount(correctLabel, classifiedLabel);
    }

    /**
     * Returns the count stored at {@code [correctLabel][classifiedLabel]}.
     *
     * @throws IllegalArgumentException if either label is not bound
     */
    public int getCount(String correctLabel, String classifiedLabel) {
        int correctId = labelIndex(correctLabel);
        int classifiedId = labelIndex(classifiedLabel);
        return this.confusionMatrix[correctId][classifiedId];
    }

    /**
     * Overwrites the count stored at {@code [correctLabel][classifiedLabel]}.
     *
     * @throws IllegalArgumentException if either label is not bound
     */
    public void putCount(String correctLabel, String classifiedLabel, int count) {
        int correctId = labelIndex(correctLabel);
        int classifiedId = labelIndex(classifiedLabel);
        this.confusionMatrix[correctId][classifiedId] = count;
    }

    /** Returns the label that was designated as the default label. */
    public String getDefaultLabel() {
        return this.defaultLabel;
    }

    /**
     * Adds {@code count} to the cell {@code [correctLabel][classifiedLabel]}.
     *
     * @throws IllegalArgumentException if either label is not bound
     */
    public void incrementCount(String correctLabel, String classifiedLabel, int count) {
        this.putCount(correctLabel, classifiedLabel, count + this.getCount(correctLabel, classifiedLabel));
    }

    /**
     * Adds one to the cell {@code [correctLabel][classifiedLabel]}.
     *
     * @throws IllegalArgumentException if either label is not bound
     */
    public void incrementCount(String correctLabel, String classifiedLabel) {
        this.incrementCount(correctLabel, classifiedLabel, 1);
    }

    /**
     * Adds every count of {@code b} into this matrix, matching cells by label.
     *
     * @param b matrix to merge; must have the same number of labels and bind all of this
     *     matrix's labels
     * @return this matrix
     * @throws IllegalArgumentException if the label counts differ or a label is missing in b
     */
    public ConfusionMatrix merge(ConfusionMatrix b) {
        Preconditions.checkArgument(this.labelMap.size() == b.getLabels().size(),
                "The label sizes do not match");
        for (String correctLabel : this.labelMap.keySet()) {
            for (String classifiedLabel : this.labelMap.keySet()) {
                this.incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel));
            }
        }
        return this;
    }

    /**
     * Copies the counts into a new {@link DenseMatrix} whose row and column label bindings are
     * both set to a copy of this matrix's label-to-index map.
     *
     * @return a new matrix holding the counts as doubles
     */
    public Matrix getMatrix() {
        int length = this.confusionMatrix.length;
        DenseMatrix m = new DenseMatrix(length, length);
        for (int r = 0; r < length; ++r) {
            for (int c = 0; c < length; ++c) {
                m.set(r, c, this.confusionMatrix[r][c]);
            }
        }
        HashMap<String, Integer> labels = Maps.newHashMap(this.labelMap);
        m.setRowLabelBindings(labels);
        m.setColumnLabelBindings(labels);
        return m;
    }

    /**
     * Replaces all counts with the rounded values of {@code m}. If {@code m} carries row (or,
     * failing that, column) label bindings, they replace this matrix's labels.
     *
     * @param m square matrix with the same dimension as this one
     * @throws IllegalArgumentException if {@code m} is not square, has a different dimension, or
     *     its label bindings do not assign exactly one label to every row
     */
    public void setMatrix(Matrix m) {
        int length = this.confusionMatrix.length;
        if (m.numRows() != m.numCols()) {
            throw new IllegalArgumentException("ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols() + ") must be square");
        }
        // Without this, a larger m would be silently truncated and a smaller one would fail
        // with an index error part-way through the copy.
        if (m.numRows() != length) {
            throw new IllegalArgumentException("ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols()
                    + ") must be " + length + 'x' + length);
        }
        for (int r = 0; r < length; ++r) {
            for (int c = 0; c < length; ++c) {
                this.confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
            }
        }
        Map<String, Integer> labels = m.getRowLabelBindings();
        if (labels == null) {
            labels = m.getColumnLabelBindings();
        }
        if (labels != null) {
            String[] sorted = ConfusionMatrix.sortLabels(labels);
            ConfusionMatrix.verifyLabels(length, sorted);
            this.labelMap.clear();
            for (int i = 0; i < length; ++i) {
                this.labelMap.put(sorted[i], i);
            }
        }
    }

    /**
     * Inverts a label-to-index map into an index-to-label array.
     *
     * @throws IllegalArgumentException if an index falls outside {@code [0, labels.size())}
     */
    private static String[] sortLabels(Map<String, Integer> labels) {
        String[] sorted = new String[labels.size()];
        for (Map.Entry<String, Integer> entry : labels.entrySet()) {
            int index = entry.getValue();
            Preconditions.checkArgument(index >= 0 && index < sorted.length, "One label, one row");
            sorted[index] = entry.getKey();
        }
        return sorted;
    }

    /**
     * Checks that {@code sorted} has exactly {@code length} entries and none is missing.
     *
     * @throws IllegalArgumentException otherwise
     */
    private static void verifyLabels(int length, String[] sorted) {
        Preconditions.checkArgument(sorted.length == length, "One label, one row");
        for (int i = 0; i < length; ++i) {
            Preconditions.checkArgument(sorted[i] != null, "One label, one row");
        }
    }

    /**
     * Resolves a label to its index.
     *
     * @throws IllegalArgumentException if the label is not bound
     */
    private int labelIndex(String label) {
        Integer id = this.labelMap.get(label);
        Preconditions.checkArgument(id != null, "Label not found: %s", label);
        return id;
    }

    /** Returns the sum of all counts in the given row. */
    private int rowTotal(int row) {
        int total = 0;
        for (int count : this.confusionMatrix[row]) {
            total += count;
        }
        return total;
    }

    /**
     * Renders the matrix as a text table with short column tags (see {@link #getSmallLabel}).
     * The default label's row and column are omitted when no instance has it as correct label.
     */
    @Override
    public String toString() {
        StringBuilder returnString = new StringBuilder(200);
        returnString.append("=======================================================").append('\n');
        returnString.append("Confusion Matrix\n");
        returnString.append("-------------------------------------------------------").append('\n');
        // The default label may be unbound (e.g. a matrix built from Matrix without it); treat
        // that as "nothing unclassified" instead of failing on the lookup.
        int unclassified = this.labelMap.containsKey(this.defaultLabel) ? this.getTotal(this.defaultLabel) : 0;
        boolean hideDefault = unclassified == 0;
        for (Map.Entry<String, Integer> entry : this.labelMap.entrySet()) {
            if (hideDefault && entry.getKey().equals(this.defaultLabel)) {
                continue;
            }
            returnString.append(StringUtils.rightPad(ConfusionMatrix.getSmallLabel(entry.getValue()), 5)).append('\t');
        }
        returnString.append("<--Classified as").append('\n');
        for (Map.Entry<String, Integer> entry : this.labelMap.entrySet()) {
            String correctLabel = entry.getKey();
            if (hideDefault && correctLabel.equals(this.defaultLabel)) {
                continue;
            }
            int labelTotal = 0;
            for (String classifiedLabel : this.labelMap.keySet()) {
                if (hideDefault && classifiedLabel.equals(this.defaultLabel)) {
                    continue;
                }
                int count = this.getCount(correctLabel, classifiedLabel);
                returnString.append(StringUtils.rightPad(Integer.toString(count), 5)).append('\t');
                labelTotal += count;
            }
            returnString.append(" |  ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
                    .append(StringUtils.rightPad(ConfusionMatrix.getSmallLabel(entry.getValue()), 5))
                    .append(" = ").append(correctLabel).append('\n');
        }
        if (unclassified > 0) {
            returnString.append("Default Category: ").append(this.defaultLabel).append(": ").append(unclassified).append('\n');
        }
        returnString.append('\n');
        return returnString.toString();
    }

    /**
     * Encodes a non-negative index as lowercase letters, most significant first, using base 26
     * with 'a' as digit zero (0 -> "a", 25 -> "z", 26 -> "ba").
     */
    static String getSmallLabel(int i) {
        int val = i;
        StringBuilder returnString = new StringBuilder();
        do {
            int n = val % 26;
            returnString.insert(0, (char) ('a' + n));
        } while ((val /= 26) > 0);
        return returnString.toString();
    }
}

