/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * {@link Evaluation} that scores a binary soft classification by comparing an actual (ground-truth)
 * field against a predicted-probability field, using one or more {@link SoftClassificationMetric}s.
 *
 * <p>A document is treated as belonging to the positive class when its actual field matches
 * {@code 1} or {@code true} (see {@link BinaryClassInfo}).
 */
public class BinarySoftClassification implements Evaluation {

    public static final ParseField NAME = new ParseField("binary_soft_classification");

    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
    private static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field");
    private static final ParseField METRICS = new ParseField("metrics");

    // The constructor args array is untyped; a[2] is the list produced by declareNamedObjects below.
    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<BinarySoftClassification, Void> PARSER = new ConstructingObjectParser<>(
        NAME.getPreferredName(),
        a -> new BinarySoftClassification((String) a[0], (String) a[1], (List<SoftClassificationMetric>) a[2]));

    // Must stay after the PARSER and ParseField declarations: static initializers run in textual order.
    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_PROBABILITY_FIELD);
        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
            (p, c, n) -> (SoftClassificationMetric) p.namedObject(SoftClassificationMetric.class, n, null), METRICS);
    }

    private final String actualField;
    private final String predictedProbabilityField;
    private final List<SoftClassificationMetric> metrics;

    /**
     * Parses a {@code BinarySoftClassification} from x-content.
     *
     * @param parser the parser positioned at the evaluation object
     * @return the parsed evaluation
     */
    public static BinarySoftClassification fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Creates an evaluation.
     *
     * @param actualField               name of the ground-truth field; must not be null
     * @param predictedProbabilityField name of the predicted-probability field; must not be null
     * @param metrics                   metrics to compute; {@code null} selects the default metrics. The list is
     *                                  copied, never modified.
     * @throws org.elasticsearch.ElasticsearchStatusException (bad request) if {@code metrics} is empty
     */
    public BinarySoftClassification(String actualField, String predictedProbabilityField, @Nullable List<SoftClassificationMetric> metrics) {
        this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
        this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD);
        this.metrics = initMetrics(metrics);
    }

    /**
     * Resolves the metric list: defaults when {@code null}, otherwise a sorted copy of the input.
     */
    private static List<SoftClassificationMetric> initMetrics(@Nullable List<SoftClassificationMetric> parsedMetrics) {
        // Copy before sorting so the caller's list is never mutated and immutable inputs (e.g. List.of) are accepted.
        List<SoftClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
        if (metrics.isEmpty()) {
            throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
        }
        // Deterministic order: equals/hashCode compare the list, so the same metrics in any input order compare equal.
        metrics.sort(Comparator.comparing(SoftClassificationMetric::getMetricName));
        return metrics;
    }

    /** Returns a fresh, mutable list of the metrics used when none are specified. */
    private static List<SoftClassificationMetric> defaultMetrics() {
        List<SoftClassificationMetric> defaultMetrics = new ArrayList<>(4);
        defaultMetrics.add(new AucRoc(false));
        defaultMetrics.add(new Precision(Arrays.asList(0.25, 0.5, 0.75)));
        defaultMetrics.add(new Recall(Arrays.asList(0.25, 0.5, 0.75)));
        defaultMetrics.add(new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)));
        return defaultMetrics;
    }

    /**
     * Deserializes an evaluation written by {@link #writeTo(StreamOutput)}.
     */
    public BinarySoftClassification(StreamInput in) throws IOException {
        this.actualField = in.readString();
        this.predictedProbabilityField = in.readString();
        this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class);
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(actualField);
        out.writeString(predictedProbabilityField);
        out.writeNamedWriteableList(metrics);
    }

    /**
     * Writes the fields plus a {@code metrics} object keyed by each metric's name.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
        builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField);

        builder.startObject(METRICS.getPreferredName());
        for (SoftClassificationMetric metric : metrics) {
            // Cast kept to select the ToXContent overload of field(...).
            builder.field(metric.getMetricName(), (ToXContent) metric);
        }
        builder.endObject();

        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        BinarySoftClassification that = (BinarySoftClassification) o;
        return Objects.equals(actualField, that.actualField)
            && Objects.equals(predictedProbabilityField, that.predictedProbabilityField)
            && Objects.equals(metrics, that.metrics);
    }

    @Override
    public int hashCode() {
        return Objects.hash(actualField, predictedProbabilityField, metrics);
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * Builds a size-0 search over documents that have both fields and match {@code queryBuilder},
     * carrying every aggregation the configured metrics need.
     *
     * @param queryBuilder an additional filter applied to the evaluated documents
     * @return the search source to execute
     */
    @Override
    public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
            .filter(QueryBuilders.existsQuery(actualField))
            .filter(QueryBuilders.existsQuery(predictedProbabilityField))
            .filter(queryBuilder);
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);

        // The class info depends only on this instance's fields, so build it once for all metrics.
        List<SoftClassificationMetric.ClassInfo> classInfos = Collections.singletonList(new BinaryClassInfo());
        for (SoftClassificationMetric metric : metrics) {
            for (AggregationBuilder agg : metric.aggs(actualField, classInfos)) {
                searchSourceBuilder.aggregation(agg);
            }
        }
        return searchSourceBuilder;
    }

    /**
     * Computes each metric from the aggregations of a search built by {@link #buildSearch(QueryBuilder)}.
     * Fails the listener with a bad-request error when no documents matched.
     */
    @Override
    public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
        if (searchResponse.getHits().getTotalHits().value == 0L) {
            listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields",
                actualField, predictedProbabilityField));
            return;
        }
        List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
        Aggregations aggs = searchResponse.getAggregations();
        BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
        for (SoftClassificationMetric metric : metrics) {
            results.add(metric.evaluate(binaryClassInfo, aggs));
        }
        listener.onResponse(results);
    }

    /**
     * Describes the single positive class: documents whose actual field matches {@code 1} or {@code true}.
     * Non-static because it reads the enclosing evaluation's field names.
     */
    private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo {

        private final QueryBuilder matchingQuery;

        private BinaryClassInfo() {
            // NOTE(review): the field name is concatenated into query_string syntax unescaped; field names
            // containing query_string special characters may not parse as intended — confirm.
            this.matchingQuery = QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
        }

        @Override
        public String getName() {
            return String.valueOf(true);
        }

        @Override
        public QueryBuilder matchingQuery() {
            return matchingQuery;
        }

        @Override
        public String getProbabilityField() {
            return predictedProbabilityField;
        }
    }
}

