/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AbstractConfusionMatrixMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection;

/**
 * Evaluation metric registered under the name {@code confusion_matrix}.
 *
 * <p>For every configured threshold this metric requests four aggregations, one per
 * {@link AbstractConfusionMatrixMetric.Condition} (TP, FP, TN, FN), and reports the
 * document count of each of them in a {@link Result}.
 */
public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
    /** Name under which this metric is parsed and reported. */
    public static final ParseField NAME = new ParseField("confusion_matrix", new String[0]);

    /**
     * Parser whose single constructor argument is the {@code AT} double array.
     * The unchecked cast is confined to this field: the only constructor argument is
     * declared with {@code declareDoubleArray} in the static initializer below.
     */
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<ConfusionMatrix, Void> PARSER =
        new ConstructingObjectParser<>(NAME.getPreferredName(), a -> new ConfusionMatrix((List<Double>) a[0]));

    // Kept directly after PARSER so the parser's configuration is readable in one place;
    // static initialization order (NAME, PARSER, then this block) is the same as before.
    static {
        PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT);
    }

    /**
     * Builds a {@code ConfusionMatrix} from the given parser.
     *
     * @param parser the parser to read the metric definition from
     * @return the parsed metric
     */
    public static ConfusionMatrix fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Creates the metric for the given thresholds.
     *
     * @param at the thresholds, handed unchanged to the superclass constructor
     */
    public ConfusionMatrix(List<Double> at) {
        super(at);
    }

    /**
     * Creates the metric by reading it from a stream; all reading is delegated to the superclass.
     *
     * @param in the stream to read from
     * @throws IOException if the superclass constructor fails to read from the stream
     */
    public ConfusionMatrix(StreamInput in) throws IOException {
        super(in);
    }

    /**
     * @return the registered metric name derived from {@link OutlierDetection#NAME} and {@link #NAME}
     */
    public String getWriteableName() {
        return MlEvaluationNamedXContentProvider.registeredMetricName(OutlierDetection.NAME, NAME);
    }

    /**
     * @return the preferred name of {@link #NAME}
     */
    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * Two instances are equal when they are of exactly the same class and have equal thresholds.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        ConfusionMatrix that = (ConfusionMatrix) o;
        return Arrays.equals(this.thresholds, that.thresholds);
    }

    /**
     * Hash code computed from the thresholds only, consistent with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        return Arrays.hashCode(this.thresholds);
    }

    /**
     * Builds the aggregations this metric needs: for each threshold, in threshold order,
     * one aggregation per condition in the order TP, FP, TN, FN.
     *
     * @param actualField               passed through to {@code buildAgg}
     * @param predictedProbabilityField passed through to {@code buildAgg}
     * @return a new list holding four aggregation builders per threshold
     */
    @Override
    protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
        // Exactly four aggregations are added per threshold, so the final size is known up front.
        List<AggregationBuilder> aggs = new ArrayList<>(4 * this.thresholds.length);
        for (double threshold : this.thresholds) {
            aggs.add(this.buildAgg(actualField, predictedProbabilityField, threshold, AbstractConfusionMatrixMetric.Condition.TP));
            aggs.add(this.buildAgg(actualField, predictedProbabilityField, threshold, AbstractConfusionMatrixMetric.Condition.FP));
            aggs.add(this.buildAgg(actualField, predictedProbabilityField, threshold, AbstractConfusionMatrixMetric.Condition.TN));
            aggs.add(this.buildAgg(actualField, predictedProbabilityField, threshold, AbstractConfusionMatrixMetric.Condition.FN));
        }
        return aggs;
    }

    /**
     * Reads the document counts of the TP/FP/TN/FN aggregations for every threshold.
     *
     * <p>Each aggregation is looked up by the name produced by {@code aggName} and cast to
     * {@link Filter}. NOTE(review): this assumes every aggregation requested by
     * {@link #aggsAt(String, String)} is present in {@code aggs} - confirm with callers.
     *
     * @param aggs the aggregations to read the counts from
     * @return a {@link Result} holding this metric's thresholds and the per-threshold counts
     */
    @Override
    public EvaluationMetricResult evaluate(Aggregations aggs) {
        int count = this.thresholds.length;
        long[] tp = new long[count];
        long[] fp = new long[count];
        long[] tn = new long[count];
        long[] fn = new long[count];
        for (int i = 0; i < count; ++i) {
            double threshold = this.thresholds[i];
            Filter tpAgg = (Filter) aggs.get(this.aggName(threshold, AbstractConfusionMatrixMetric.Condition.TP));
            Filter fpAgg = (Filter) aggs.get(this.aggName(threshold, AbstractConfusionMatrixMetric.Condition.FP));
            Filter tnAgg = (Filter) aggs.get(this.aggName(threshold, AbstractConfusionMatrixMetric.Condition.TN));
            Filter fnAgg = (Filter) aggs.get(this.aggName(threshold, AbstractConfusionMatrixMetric.Condition.FN));
            tp[i] = tpAgg.getDocCount();
            fp[i] = fpAgg.getDocCount();
            tn[i] = tnAgg.getDocCount();
            fn[i] = fnAgg.getDocCount();
        }
        return new Result(this.thresholds, tp, fp, tn, fn);
    }

    /**
     * Result of the {@code confusion_matrix} metric: for each threshold, the counts
     * labelled {@code tp}, {@code fp}, {@code tn} and {@code fn}.
     */
    public static class Result implements EvaluationMetricResult {
        private final double[] thresholds;
        private final long[] tp;
        private final long[] fp;
        private final long[] tn;
        private final long[] fn;

        /**
         * Creates a result from parallel arrays; index {@code i} of every count array belongs
         * to {@code thresholds[i]}. The arrays are stored as given, without copying, and their
         * lengths are only checked by assertions.
         *
         * @param thresholds the thresholds
         * @param tp         per-threshold {@code tp} counts
         * @param fp         per-threshold {@code fp} counts
         * @param tn         per-threshold {@code tn} counts
         * @param fn         per-threshold {@code fn} counts
         */
        public Result(double[] thresholds, long[] tp, long[] fp, long[] tn, long[] fn) {
            assert (thresholds.length == tp.length);
            assert (thresholds.length == fp.length);
            assert (thresholds.length == tn.length);
            assert (thresholds.length == fn.length);
            this.thresholds = thresholds;
            this.tp = tp;
            this.fp = fp;
            this.tn = tn;
            this.fn = fn;
        }

        /**
         * Reads a result from a stream, in the same field order {@link #writeTo(StreamOutput)}
         * writes it: thresholds, tp, fp, tn, fn.
         *
         * @param in the stream to read from
         * @throws IOException if reading from the stream fails
         */
        public Result(StreamInput in) throws IOException {
            this.thresholds = in.readDoubleArray();
            this.tp = in.readLongArray();
            this.fp = in.readLongArray();
            this.tn = in.readLongArray();
            this.fn = in.readLongArray();
        }

        /**
         * @return the registered metric name derived from {@link OutlierDetection#NAME} and {@link ConfusionMatrix#NAME}
         */
        public String getWriteableName() {
            return MlEvaluationNamedXContentProvider.registeredMetricName(OutlierDetection.NAME, NAME);
        }

        /**
         * @return the preferred name of {@link ConfusionMatrix#NAME}
         */
        @Override
        public String getMetricName() {
            return NAME.getPreferredName();
        }

        /**
         * Writes the thresholds followed by the tp, fp, tn and fn arrays.
         *
         * @param out the stream to write to
         * @throws IOException if writing to the stream fails
         */
        public void writeTo(StreamOutput out) throws IOException {
            out.writeDoubleArray(this.thresholds);
            out.writeLongArray(this.tp);
            out.writeLongArray(this.fp);
            out.writeLongArray(this.tn);
            out.writeLongArray(this.fn);
        }

        /**
         * Renders one object whose field names are the thresholds (via {@link String#valueOf(double)})
         * and whose values are objects with the fields {@code tp}, {@code fp}, {@code tn} and {@code fn}.
         *
         * @param builder the builder to write to
         * @param params  not used by this implementation
         * @return the same {@code builder}
         * @throws IOException if writing to the builder fails
         */
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            for (int i = 0; i < this.thresholds.length; ++i) {
                builder.startObject(String.valueOf(this.thresholds[i]));
                builder.field("tp", this.tp[i]);
                builder.field("fp", this.fp[i]);
                builder.field("tn", this.tn[i]);
                builder.field("fn", this.fn[i]);
                builder.endObject();
            }
            builder.endObject();
            return builder;
        }
    }
}

