/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.loadingservice;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.MessageSupplier;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.RemovalNotification;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

/**
 * Loads trained models for inference and keeps the loaded models in a local, weight-bounded cache.
 *
 * <p>Callers ask for a model on behalf of a {@link Consumer} (an ingest pipeline or a search). A request is
 * answered straight from the cache when possible. Otherwise the caller's listener is parked in
 * {@link #loadingListeners} until the asynchronous load completes, or, for pipeline requests for a model
 * that no ingest pipeline currently references, the model is loaded for that single caller without being
 * cached.
 *
 * <p>The service also listens to cluster state changes: models newly referenced by {@code inference}
 * ingest processors are loaded eagerly, and models that are no longer referenced are dropped from the
 * cache unless a search has used them.
 *
 * <p>Changes to {@link #loadingListeners} and to {@link #referencedModels} are made while holding the
 * monitor of {@link #loadingListeners}.
 */
public class ModelLoadingService implements ClusterStateListener {

    /** Node-scoped setting: maximum total weight (bytes) of the local model cache. Defaults to 40%. */
    public static final Setting<ByteSizeValue> INFERENCE_MODEL_CACHE_SIZE = Setting.memorySizeSetting(
        "xpack.ml.inference_model.cache_size",
        "40%",
        Setting.Property.NodeScope
    );

    /** Node-scoped setting: expire-after-access time for cached models. Defaults to 5 minutes, minimum 1 millisecond. */
    public static final Setting<TimeValue> INFERENCE_MODEL_CACHE_TTL = Setting.timeSetting(
        "xpack.ml.inference_model.time_to_live",
        new TimeValue(5L, TimeUnit.MINUTES),
        new TimeValue(1L, TimeUnit.MILLISECONDS),
        Setting.Property.NodeScope
    );

    private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);

    private final TrainedModelStatsService modelStatsService;
    /** Loaded models keyed by model id, weighed by {@code LocalModel.ramBytesUsed()}. */
    private final Cache<String, ModelAndConsumer> localModelCache;
    /** Ids of models referenced by ingest pipelines, as of the last processed cluster state. */
    private final Set<String> referencedModels = new HashSet<>();
    /** Listeners waiting for an in-flight load, keyed by model id. Also used as the lock object of this service. */
    private final Map<String, Queue<ActionListener<LocalModel>>> loadingListeners = new HashMap<>();
    private final TrainedModelProvider provider;
    /** Models whose eviction has already been audited; further evictions are only trace-logged until reset. */
    private final Set<String> shouldNotAudit;
    private final ThreadPool threadPool;
    private final InferenceAuditor auditor;
    private final ByteSizeValue maxCacheSize;
    private final String localNode;
    private final CircuitBreaker trainedModelCircuitBreaker;

    /**
     * Creates the service, builds the model cache from {@code settings} and registers the service as a
     * cluster state listener.
     *
     * @param trainedModelProvider       source of model configurations and inference definitions
     * @param auditor                    receives audit messages about model loading and eviction
     * @param threadPool                 used to load newly referenced models off the cluster state thread
     * @param clusterService             the service this instance registers itself with as a listener
     * @param modelStatsService          passed through to every {@link LocalModel} created here
     * @param settings                   source of {@link #INFERENCE_MODEL_CACHE_SIZE} and {@link #INFERENCE_MODEL_CACHE_TTL}
     * @param localNode                  passed through to every {@link LocalModel} created here
     * @param trainedModelCircuitBreaker breaker charged with the memory of loaded models; must not be null
     */
    public ModelLoadingService(
        TrainedModelProvider trainedModelProvider,
        InferenceAuditor auditor,
        ThreadPool threadPool,
        ClusterService clusterService,
        TrainedModelStatsService modelStatsService,
        Settings settings,
        String localNode,
        CircuitBreaker trainedModelCircuitBreaker
    ) {
        this.provider = trainedModelProvider;
        this.threadPool = threadPool;
        this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings);
        this.auditor = auditor;
        this.modelStatsService = modelStatsService;
        this.shouldNotAudit = new HashSet<>();
        this.localModelCache = CacheBuilder.<String, ModelAndConsumer>builder()
            .setMaximumWeight(this.maxCacheSize.getBytes())
            .weigher((id, modelAndConsumer) -> modelAndConsumer.model.ramBytesUsed())
            .removalListener(notification -> cacheEvictionListener(notification))
            .setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
            .build();
        this.localNode = localNode;
        this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker");
        // Register last: the listener must not observe a partially initialised instance, and a failed
        // argument check above must not leave a half-built listener registered with the cluster service.
        clusterService.addListener(this);
    }

    /**
     * Reports whether a model is currently in the local cache. Package-private; presumably a test hook.
     *
     * @param modelId the model id to look up
     * @return {@code true} if the cache holds an entry for {@code modelId}
     */
    boolean isModelCached(String modelId) {
        return localModelCache.get(modelId) != null;
    }

    /**
     * Requests a model on behalf of an ingest pipeline.
     *
     * @param modelId             the model to load
     * @param modelActionListener notified with the model or with the load failure
     */
    public void getModelForPipeline(String modelId, ActionListener<LocalModel> modelActionListener) {
        getModel(modelId, Consumer.PIPELINE, modelActionListener);
    }

    /**
     * Requests a model on behalf of a search.
     *
     * @param modelId             the model to load
     * @param modelActionListener notified with the model or with the load failure
     */
    public void getModelForSearch(String modelId, ActionListener<LocalModel> modelActionListener) {
        getModel(modelId, Consumer.SEARCH, modelActionListener);
    }

    /**
     * Answers from the cache when possible without taking the lock; otherwise falls back to
     * {@link #loadModelIfNecessary}, which re-checks the cache under the lock.
     */
    private void getModel(String modelId, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
        ModelAndConsumer cachedModel = localModelCache.get(modelId);
        if (cachedModel != null) {
            if (answerFromCachedModel(cachedModel, consumer, modelActionListener)) {
                logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId));
            }
            return;
        }
        if (loadModelIfNecessary(modelId, consumer, modelActionListener)) {
            logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId));
        }
    }

    /**
     * Records {@code consumer} on a cache entry, acquires the model and completes the listener.
     *
     * @return {@code true} if the model was acquired and handed to the listener, {@code false} if
     *         {@code acquire()} tripped the circuit breaker and the listener was failed instead
     */
    private boolean answerFromCachedModel(ModelAndConsumer cachedModel, Consumer consumer, ActionListener<LocalModel> listener) {
        cachedModel.consumers.add(consumer);
        try {
            cachedModel.model.acquire();
        } catch (CircuitBreakingException e) {
            listener.onFailure(e);
            return false;
        }
        listener.onResponse(cachedModel.model);
        return true;
    }

    /**
     * Under the lock: serves the request from the cache, or joins an in-flight load, or starts a new load.
     *
     * <p>A pipeline request for a model that is not in {@link #referencedModels} is loaded for this caller
     * only and is not cached. Every other new request registers a listener queue and starts a cached load.
     *
     * @return {@code true} if the model was already cached or already loading (the listener has been
     *         answered or queued); {@code false} if this call started a new load
     */
    private boolean loadModelIfNecessary(String modelId, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
        synchronized (loadingListeners) {
            // Re-check under the lock: the load may have completed since the caller's unlocked lookup.
            ModelAndConsumer cachedModel = localModelCache.get(modelId);
            if (cachedModel != null) {
                answerFromCachedModel(cachedModel, consumer, modelActionListener);
                return true;
            }
            Queue<ActionListener<LocalModel>> pending = loadingListeners.get(modelId);
            if (pending != null) {
                pending.add(modelActionListener);
                return true;
            }
            if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) {
                loadWithoutCaching(modelId, modelActionListener);
            } else {
                logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
                loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
                loadModel(modelId, consumer);
            }
            return false;
        }
    }

    /**
     * Loads a model's configuration and definition and routes the outcome to {@link #handleLoadSuccess}
     * or {@link #handleLoadFailure}, which notify the listeners queued for {@code modelId}.
     *
     * <p>The configuration's estimated heap size is reserved on the circuit breaker before the definition
     * is fetched and is given back if fetching the definition fails.
     */
    private void loadModel(String modelId, Consumer consumer) {
        provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(trainedModelConfig -> {
            trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
            provider.getTrainedModelForInference(modelId, ActionListener.wrap(inferenceDefinition -> {
                try {
                    // Replace the up-front estimate with the real size of the loaded definition.
                    updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig);
                } catch (CircuitBreakingException ex) {
                    handleLoadFailure(modelId, ex);
                    return;
                }
                handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition);
            }, failure -> {
                // Give back the reservation made before the definition was requested.
                trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
                logger.warn(new ParameterizedMessage("[{}] failed to load model definition", modelId), failure);
                handleLoadFailure(modelId, failure);
            }));
        }, failure -> {
            logger.warn(new ParameterizedMessage("[{}] failed to load model configuration", modelId), failure);
            handleLoadFailure(modelId, failure);
        }));
    }

    /**
     * Loads a model for a single caller and hands it directly to {@code modelActionListener}; the model is
     * neither cached nor shared with other listeners. Circuit breaker accounting mirrors {@link #loadModel}.
     */
    private void loadWithoutCaching(String modelId, ActionListener<LocalModel> modelActionListener) {
        logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
        provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(trainedModelConfig -> {
            trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
            provider.getTrainedModelForInference(modelId, ActionListener.wrap(inferenceDefinition -> {
                InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null
                    ? inferenceConfigFromTargetType(inferenceDefinition.getTargetType())
                    : trainedModelConfig.getInferenceConfig();
                try {
                    updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig);
                } catch (CircuitBreakingException ex) {
                    modelActionListener.onFailure(ex);
                    return;
                }
                modelActionListener.onResponse(
                    new LocalModel(
                        trainedModelConfig.getModelId(),
                        localNode,
                        inferenceDefinition,
                        trainedModelConfig.getInput(),
                        trainedModelConfig.getDefaultFieldMap(),
                        inferenceConfig,
                        trainedModelConfig.getLicenseLevel(),
                        modelStatsService,
                        trainedModelCircuitBreaker
                    )
                );
            }, e -> {
                // Give back the reservation made before the definition was requested.
                trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
                modelActionListener.onFailure(e);
            }));
        }, modelActionListener::onFailure));
    }

    /**
     * Adjusts the circuit breaker from the configuration's estimated heap size to the definition's actual
     * {@code ramBytesUsed()}.
     *
     * @throws CircuitBreakingException if the actual size exceeds the estimate and the breaker rejects the
     *         extra bytes; in that case the original estimate is released before rethrowing
     */
    private void updateCircuitBreakerEstimate(String modelId, InferenceDefinition inferenceDefinition, TrainedModelConfig trainedModelConfig)
        throws CircuitBreakingException {
        long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory();
        if (estimateDiff < 0L) {
            // The model is smaller than estimated: return the surplus.
            trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
        } else if (estimateDiff > 0L) {
            try {
                trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
            } catch (CircuitBreakingException ex) {
                // The load is being abandoned, so the initial reservation must be released as well.
                trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
                throw ex;
            }
        }
    }

    /**
     * Builds the {@link LocalModel}, caches it and answers every listener queued for {@code modelId}.
     *
     * <p>If no listener queue exists any more (it is removed by {@link #clusterChanged} when the model
     * stops being referenced), the model is released and not cached.
     */
    private void handleLoadSuccess(String modelId, Consumer consumer, TrainedModelConfig trainedModelConfig, InferenceDefinition inferenceDefinition) {
        InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null
            ? inferenceConfigFromTargetType(inferenceDefinition.getTargetType())
            : trainedModelConfig.getInferenceConfig();
        LocalModel loadedModel = new LocalModel(
            trainedModelConfig.getModelId(),
            localNode,
            inferenceDefinition,
            trainedModelConfig.getInput(),
            trainedModelConfig.getDefaultFieldMap(),
            inferenceConfig,
            trainedModelConfig.getLicenseLevel(),
            modelStatsService,
            trainedModelCircuitBreaker
        );
        final Queue<ActionListener<LocalModel>> listeners;
        synchronized (loadingListeners) {
            listeners = loadingListeners.remove(modelId);
            if (listeners == null) {
                loadedModel.release();
                return;
            }
            // Extra acquire held while the listeners are notified; balanced by the release at the end.
            // NOTE(review): presumably guards against the entry being evicted before listeners acquire - confirm.
            loadedModel.acquire();
            localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer));
            // A freshly cached model may have its next eviction audited again.
            shouldNotAudit.remove(modelId);
        }
        // Listeners are notified outside the lock.
        for (ActionListener<LocalModel> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
            loadedModel.acquire();
            listener.onResponse(loadedModel);
        }
        loadedModel.release();
    }

    /**
     * Fails every listener queued for {@code modelId} with {@code failure}. Does nothing if the queue has
     * already been removed.
     */
    private void handleLoadFailure(String modelId, Exception failure) {
        final Queue<ActionListener<LocalModel>> listeners;
        synchronized (loadingListeners) {
            listeners = loadingListeners.remove(modelId);
            if (listeners == null) {
                return;
            }
        }
        // Listeners are notified outside the lock.
        for (ActionListener<LocalModel> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
            listener.onFailure(failure);
        }
    }

    /**
     * Cache removal callback: audits evictions, persists the model's stats and releases the model.
     *
     * <p>Only removals with reason {@code EVICTED} are audited. The model is released in a {@code finally}
     * block so it is released even if auditing or persisting stats throws.
     */
    private void cacheEvictionListener(RemovalNotification<String, ModelAndConsumer> notification) {
        LocalModel model = notification.getValue().model;
        try {
            if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
                MessageSupplier msg = () -> new ParameterizedMessage(
                    "model cache entry evicted.current cache [{}] current max [{}] model size [{}]. If this is undesired, consider updating setting [{}] or [{}].",
                    new ByteSizeValue(localModelCache.weight()).getStringRep(),
                    maxCacheSize.getStringRep(),
                    new ByteSizeValue(model.ramBytesUsed()).getStringRep(),
                    INFERENCE_MODEL_CACHE_SIZE.getKey(),
                    INFERENCE_MODEL_CACHE_TTL.getKey()
                );
                auditIfNecessary(notification.getKey(), msg);
            }
            logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]", model.getModelId()));
            // The flag is true when the model is no longer referenced by any pipeline.
            model.persistStats(referencedModels.contains(notification.getKey()) == false);
        } finally {
            model.release();
        }
    }

    /**
     * Reconciles the service with the inference processors in the new cluster state.
     *
     * <p>Only acts when the ingest metadata changed and the local node is an ingest node. Loads that are
     * still pending for models no longer referenced are cancelled and their listeners failed. Cached
     * models that are no longer referenced are invalidated unless a search has consumed them. Models that
     * became referenced are registered as loading and loaded asynchronously.
     *
     * @param event the cluster state change
     */
    @Override
    public void clusterChanged(ClusterChangedEvent event) {
        if (event.changedCustomMetadataSet().contains("ingest") == false
            || event.state().nodes().getLocalNode().isIngestNode() == false) {
            return;
        }
        ClusterState state = event.state();
        IngestMetadata currentIngestMetadata = state.metadata().custom("ingest");
        Set<String> allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata);
        if (allReferencedModelKeys.equals(referencedModels)) {
            return;
        }

        List<Tuple<String, List<ActionListener<LocalModel>>>> drainWithFailure = new ArrayList<>();
        final Set<String> referencedModelsBeforeClusterState;
        Set<String> loadingModelBeforeClusterState = null;
        final Set<String> removedModels;
        synchronized (loadingListeners) {
            referencedModelsBeforeClusterState = new HashSet<>(referencedModels);
            if (logger.isTraceEnabled()) {
                loadingModelBeforeClusterState = new HashSet<>(loadingListeners.keySet());
            }

            // Pending loads for models that are no longer referenced are cancelled. The ids are collected
            // first because removing from the map while iterating its key set would throw
            // ConcurrentModificationException.
            List<String> cancelledLoads = new ArrayList<>();
            for (String loadingModelId : loadingListeners.keySet()) {
                if (allReferencedModelKeys.contains(loadingModelId) == false) {
                    cancelledLoads.add(loadingModelId);
                }
            }
            for (String cancelledModelId : cancelledLoads) {
                List<ActionListener<LocalModel>> orphanedListeners = new ArrayList<>(loadingListeners.remove(cancelledModelId));
                drainWithFailure.add(Tuple.tuple(cancelledModelId, orphanedListeners));
            }

            removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);
            for (String removedModelId : removedModels) {
                ModelAndConsumer modelAndConsumer = localModelCache.get(removedModelId);
                // Models that a search has used stay cached until they expire or are evicted.
                if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
                    localModelCache.invalidate(removedModelId);
                }
            }
            referencedModels.removeAll(removedModels);
            shouldNotAudit.removeAll(removedModels);

            // From here on allReferencedModelKeys holds only the models that became referenced with this event.
            allReferencedModelKeys.removeAll(referencedModels);
            referencedModels.addAll(allReferencedModelKeys);
            for (String newModelId : allReferencedModelKeys) {
                // Mark as loading so that concurrent requests queue up instead of starting their own load.
                loadingListeners.computeIfAbsent(newModelId, s -> new ArrayDeque<>());
            }
        }

        if (logger.isTraceEnabled()) {
            if (loadingListeners.keySet().equals(loadingModelBeforeClusterState) == false) {
                logger.trace(
                    "cluster state event changed loading models: before {} after {}",
                    loadingModelBeforeClusterState,
                    loadingListeners.keySet()
                );
            }
            if (referencedModels.equals(referencedModelsBeforeClusterState) == false) {
                logger.trace(
                    "cluster state event changed referenced models: before {} after {}",
                    referencedModelsBeforeClusterState,
                    referencedModels
                );
            }
        }

        // Listeners are failed outside the lock.
        for (Tuple<String, List<ActionListener<LocalModel>>> modelAndListeners : drainWithFailure) {
            for (ActionListener<LocalModel> listener : modelAndListeners.v2()) {
                // The pattern and its argument go to the exception together so that the model id is
                // substituted into the message.
                listener.onFailure(
                    new ElasticsearchException(
                        "Cancelling load of model [{}] as it is no longer referenced by a pipeline",
                        modelAndListeners.v1()
                    )
                );
            }
        }
        removedModels.forEach(this::auditUnreferencedModel);
        loadModelsForPipeline(allReferencedModelKeys);
    }

    /**
     * Audits and info-logs {@code msg} for a model once; while the model id is in {@link #shouldNotAudit}
     * the message is only trace-logged.
     */
    private void auditIfNecessary(String modelId, MessageSupplier msg) {
        if (shouldNotAudit.contains(modelId)) {
            logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage()));
            return;
        }
        // Render once so that the audit entry and the log line carry the same text.
        String formattedMessage = msg.get().getFormattedMessage();
        auditor.info(modelId, formattedMessage);
        shouldNotAudit.add(modelId);
        logger.info("[{}] {}", modelId, formattedMessage);
    }

    /**
     * Audits and starts loading each of the given models on the {@code ml_utility} executor.
     *
     * @param modelIds ids of the newly referenced models; nothing is scheduled when empty
     */
    private void loadModelsForPipeline(Set<String> modelIds) {
        if (modelIds.isEmpty()) {
            return;
        }
        threadPool.executor("ml_utility").execute(() -> {
            for (String modelId : modelIds) {
                auditNewReferencedModel(modelId);
                loadModel(modelId, Consumer.PIPELINE);
            }
        });
    }

    /** Audits that a model became referenced by ingest processors. */
    private void auditNewReferencedModel(String modelId) {
        auditor.info(modelId, "referenced by ingest processors. Attempting to load model into cache");
    }

    /** Audits that a model is no longer referenced by any ingest processor. */
    private void auditUnreferencedModel(String modelId) {
        auditor.info(modelId, "no longer referenced by any processors");
    }

    /**
     * Adds {@code object} to {@code queue} and returns the same queue, so the call can be used as an expression.
     */
    private static <T> Queue<T> addFluently(Queue<T> queue, T object) {
        queue.add(object);
        return queue;
    }

    /**
     * Collects the {@code model_id} of every {@code inference} processor found in the top-level
     * {@code processors} list of each pipeline.
     *
     * @param ingestMetadata the current ingest metadata, may be {@code null}
     * @return a new mutable set of model ids; empty when {@code ingestMetadata} is {@code null}
     */
    private static Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata) {
        Set<String> allReferencedModelKeys = new HashSet<>();
        if (ingestMetadata == null) {
            return allReferencedModelKeys;
        }
        ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
            Object processors = pipelineConfiguration.getConfigAsMap().get("processors");
            if (processors instanceof List == false) {
                return;
            }
            for (Object processor : (List<?>) processors) {
                if (processor instanceof Map == false) {
                    continue;
                }
                Object processorConfig = ((Map<?, ?>) processor).get("inference");
                if (processorConfig instanceof Map == false) {
                    continue;
                }
                Object modelId = ((Map<?, ?>) processorConfig).get("model_id");
                if (modelId != null) {
                    assert modelId instanceof String;
                    allReferencedModelKeys.add(modelId.toString());
                }
            }
        });
        return allReferencedModelKeys;
    }

    /**
     * Returns the empty-parameter inference configuration matching a model's target type. Used when the
     * model configuration carries no inference configuration of its own.
     *
     * @throws ElasticsearchException (bad request) for any target type other than regression or classification
     */
    private static InferenceConfig inferenceConfigFromTargetType(TargetType targetType) {
        switch (targetType) {
            case REGRESSION:
                return RegressionConfig.EMPTY_PARAMS;
            case CLASSIFICATION:
                return ClassificationConfig.EMPTY_PARAMS;
            default:
                throw ExceptionsHelper.badRequestException("unsupported target type [{}]", targetType);
        }
    }

    /**
     * Queues a listener for {@code modelId} without starting a load, creating the queue if needed.
     * Package-private; presumably a test hook.
     *
     * @param modelId             the model whose load the listener waits for
     * @param modelLoadedListener notified when a load of {@code modelId} completes or fails
     */
    void addModelLoadedListener(String modelId, ActionListener<LocalModel> modelLoadedListener) {
        synchronized (loadingListeners) {
            loadingListeners.computeIfAbsent(modelId, modelKey -> new ArrayDeque<>()).add(modelLoadedListener);
        }
    }

    /** The kind of caller a model is loaded for. */
    public static enum Consumer {
        PIPELINE,
        SEARCH;
    }

    /** Cache value: a loaded model together with every kind of consumer that has requested it. */
    private static class ModelAndConsumer {
        private final LocalModel model;
        private final EnumSet<Consumer> consumers;

        private ModelAndConsumer(LocalModel model, Consumer consumer) {
            this.model = model;
            this.consumers = EnumSet.of(consumer);
        }
    }
}

