/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.processor;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.ValueSource;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.processor.InferenceProcessorAttributes;
import org.opensearch.ml.processor.ModelExecutor;
import org.opensearch.script.ScriptService;
import org.opensearch.script.TemplateScript;
import org.opensearch.transport.client.Client;

public class MLInferenceIngestProcessor
extends AbstractProcessor
implements ModelExecutor {
    private static final Logger logger = LogManager.getLogger(MLInferenceIngestProcessor.class);
    // Dot separator constant; not referenced by the code visible in this class.
    public static final String DOT_SYMBOL = ".";
    // Bundles the model id, input maps, output maps, model config and max prediction task count given to the constructor.
    private final InferenceProcessorAttributes inferenceProcessorAttributes;
    // When true, a JSON-path lookup that yields nothing is skipped instead of raising an error;
    // the flag is also forwarded to getModelOutputValue and IngestDocument.setFieldValue.
    private final boolean ignoreMissing;
    // Forwarded unchanged to getMLModelInferenceRequest when the prediction request is built.
    private final String functionName;
    // Forwarded to getModelOutputValue when reading values out of the model output.
    private final boolean fullResponsePath;
    // When true, a failed prediction batch still hands the document to the handler without an exception.
    private final boolean ignoreFailure;
    // When false, output fields whose target paths all already exist in the document are not written.
    private final boolean override;
    // Model input template forwarded to getMLModelInferenceRequest.
    private final String modelInput;
    // Used to wrap output values and compile field-name templates when writing results.
    private final ScriptService scriptService;
    // NOTE(review): static, yet reassigned by every constructor call, so all processor instances
    // share whichever client was passed to the most recently constructed one.
    private static Client client;
    // Processor type name, returned by getType() and used as the processor type in configuration reads.
    public static final String TYPE = "ml_inference";
    // Document field that receives the model output when no output map is configured.
    public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
    // Configuration keys read by Factory.create.
    public static final String IGNORE_MISSING = "ignore_missing";
    public static final String OVERRIDE = "override";
    public static final String FUNCTION_NAME = "function_name";
    public static final String FULL_RESPONSE_PATH = "full_response_path";
    public static final String MODEL_INPUT = "model_input";
    // Default for the "max_prediction_tasks" setting.
    public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
    // Model input template applied to remote models when "model_input" is not configured.
    public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
    // Forwarded to getMLModelInferenceRequest when the prediction request is built.
    private final NamedXContentRegistry xContentRegistry;

    /**
     * Creates a processor instance from already-parsed configuration values.
     *
     * <p>The model id, input/output maps, model config and max prediction task count are bundled
     * into an {@link InferenceProcessorAttributes}; the remaining values are stored as-is.
     * Note that {@code client} is written to a static field, so it replaces the client used by
     * every other instance of this class as well.
     *
     * @param modelId           id of the model to invoke
     * @param inputMaps         per-task mappings of model input name to document field, may be null
     * @param outputMaps        per-task mappings of new document field to model output name, may be null
     * @param modelConfigMaps   static model configuration, may be null
     * @param maxPredictionTask configured upper bound on prediction tasks
     * @param tag               processor tag
     * @param description       processor description
     * @param ignoreMissing     whether missing fields are tolerated
     * @param functionName      model function name
     * @param fullResponsePath  whether output paths address the full model response
     * @param ignoreFailure     whether prediction failures are swallowed
     * @param override          whether existing document fields may be overwritten
     * @param modelInput        model input template
     * @param scriptService     script service used when writing fields
     * @param client            client used to run the prediction action
     * @param xContentRegistry  registry used when building the prediction request
     */
    protected MLInferenceIngestProcessor(
        String modelId,
        List<Map<String, String>> inputMaps,
        List<Map<String, String>> outputMaps,
        Map<String, String> modelConfigMaps,
        int maxPredictionTask,
        String tag,
        String description,
        boolean ignoreMissing,
        String functionName,
        boolean fullResponsePath,
        boolean ignoreFailure,
        boolean override,
        String modelInput,
        ScriptService scriptService,
        Client client,
        NamedXContentRegistry xContentRegistry
    ) {
        super(tag, description);

        // Model-related settings travel together in one attributes object.
        this.inferenceProcessorAttributes = new InferenceProcessorAttributes(modelId, inputMaps, outputMaps, modelConfigMaps, maxPredictionTask);

        // Behavior flags.
        this.ignoreMissing = ignoreMissing;
        this.fullResponsePath = fullResponsePath;
        this.ignoreFailure = ignoreFailure;
        this.override = override;

        // Request-building inputs.
        this.functionName = functionName;
        this.modelInput = modelInput;
        this.xContentRegistry = xContentRegistry;

        // Collaborators; the client lives in a static field shared by all instances.
        this.scriptService = scriptService;
        MLInferenceIngestProcessor.client = client;
    }

    /**
     * Asynchronously enriches {@code ingestDocument} with model predictions.
     *
     * <p>One prediction task is started per configured input map, or exactly one task when no
     * input map is configured. A grouped listener collects the task outcomes and then calls
     * {@code handler}: with the document on success, with the document and no exception on
     * failure when {@code ignoreFailure} is set, and otherwise with a null document and the
     * exception.
     *
     * @param ingestDocument document being ingested
     * @param handler        callback receiving the resulting document or the failure
     */
    public void execute(final IngestDocument ingestDocument, final BiConsumer<IngestDocument, Exception> handler) {
        final List<Map<String, String>> inputMaps = this.inferenceProcessorAttributes.getInputMaps();
        final List<Map<String, String>> outputMaps = this.inferenceProcessorAttributes.getOutputMaps();
        final int inputMapCount = inputMaps == null ? 0 : inputMaps.size();
        // Even without an input map a single prediction is still issued.
        final int taskCount = Math.max(inputMapCount, 1);

        ActionListener<Collection<Void>> completion = new ActionListener<Collection<Void>>() {

            public void onResponse(Collection<Void> ignoredResults) {
                handler.accept(ingestDocument, null);
            }

            public void onFailure(Exception failure) {
                if (!MLInferenceIngestProcessor.this.ignoreFailure) {
                    handler.accept(null, failure);
                    return;
                }
                handler.accept(ingestDocument, null);
            }
        };
        GroupedActionListener<Void> batchPredictionListener = new GroupedActionListener<Void>(completion, taskCount);

        for (int taskIndex = 0; taskIndex < taskCount; taskIndex++) {
            try {
                this.processPredictions(ingestDocument, batchPredictionListener, inputMaps, outputMaps, taskIndex, inputMapCount);
            } catch (Exception failure) {
                // A task that cannot even be started is reported to the group as failed.
                batchPredictionListener.onFailure(failure);
            }
        }
    }

    /**
     * Synchronous execution is not supported by this processor.
     *
     * @param ingestDocument ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
        final String reason = "this method should not get executed.";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Runs the prediction task for the input/output mapping at {@code inputMapIndex}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Seed the model parameters with the configured model config.</li>
     *   <li>If an output map is configured, resolve the target document paths of every output
     *       field and drop the fields whose paths all already exist when {@code override} is
     *       off. If nothing is left to write, the task completes immediately without calling
     *       the model.</li>
     *   <li>Collect model inputs from the document: every top-level document key when no input
     *       map is configured, otherwise the fields named by the input map.</li>
     *   <li>Send the prediction request and, on response, write the model output into the
     *       document.</li>
     * </ol>
     *
     * <p>An empty {@code processOutputMap} list is handled like a missing one, so the output goes
     * to {@link #DEFAULT_OUTPUT_FIELD_NAME}; this matches what the response callback does and
     * avoids indexing into an empty list.
     *
     * @param ingestDocument          document being enriched
     * @param batchPredictionListener grouped listener notified once per task
     * @param processInputMap         configured input maps, may be null
     * @param processOutputMap        configured output maps, may be null or empty
     * @param inputMapIndex           index of the mapping handled by this task
     * @param inputMapSize            number of configured input maps, 0 when none
     * @throws IOException if building the prediction request fails
     */
    private void processPredictions(final IngestDocument ingestDocument, final GroupedActionListener<Void> batchPredictionListener, List<Map<String, String>> processInputMap, final List<Map<String, String>> processOutputMap, final int inputMapIndex, int inputMapSize) throws IOException {
        // modelParameters = model config + document-derived inputs; modelConfigs stays config-only.
        Map<String, String> modelParameters = new HashMap<String, String>();
        Map<String, String> modelConfigs = new HashMap<String, String>();
        Map<String, String> configuredModelConfigs = this.inferenceProcessorAttributes.getModelConfigMaps();
        if (configuredModelConfigs != null) {
            modelParameters.putAll(configuredModelConfigs);
            modelConfigs.putAll(configuredModelConfigs);
        }

        // Output fields that will actually be written, keyed by new document field name.
        final Map<String, List<String>> newOutputMapping = new HashMap<String, List<String>>();
        if (processOutputMap != null && !processOutputMap.isEmpty()) {
            Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<String, Object>(ingestDocument.getSourceAndMetadata());
            ingestDocumentSourceAndMetaData.put("_ingest", ingestDocument.getIngestMetadata());

            Map<String, String> outputMapping = processOutputMap.get(inputMapIndex);
            for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
                String newDocumentFieldName = entry.getKey();
                List<String> dotPaths = this.writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
                int existingFields = 0;
                for (String path : dotPaths) {
                    if (ingestDocument.hasField(path)) {
                        ++existingFields;
                    }
                }
                if (!this.override && existingFields == dotPaths.size()) {
                    logger.debug("{} already exists in the ingest document. Removing it from output mapping", (Object)newDocumentFieldName);
                    continue;
                }
                newOutputMapping.put(newDocumentFieldName, dotPaths);
            }
            if (newOutputMapping.isEmpty()) {
                // Nothing to write for this task, so skip the model call entirely.
                batchPredictionListener.onResponse(null);
                return;
            }
        }

        if (inputMapSize == 0) {
            // No input map: each top-level document key is offered to the model under its own name.
            for (String documentField : ingestDocument.getSourceAndMetadata().keySet()) {
                this.getMappedModelInputFromDocuments(ingestDocument, modelParameters, documentField, documentField);
            }
        } else {
            Map<String, String> inputMapping = processInputMap.get(inputMapIndex);
            for (Map.Entry<String, String> entry : inputMapping.entrySet()) {
                String modelInputFieldName = entry.getKey();
                String documentFieldName = entry.getValue();
                this.getMappedModelInputFromDocuments(ingestDocument, modelParameters, documentFieldName, modelInputFieldName);
            }
        }

        // Parameters that did not come from the model config are the document-derived inputs.
        Map<String, String> inputMappings = new HashMap<String, String>();
        for (Map.Entry<String, String> parameter : modelParameters.entrySet()) {
            if (!modelConfigs.containsKey(parameter.getKey())) {
                inputMappings.put(parameter.getKey(), parameter.getValue());
            }
        }

        ActionRequest actionRequest = this.getMLModelInferenceRequest(this.xContentRegistry, modelParameters, modelConfigs, inputMappings, this.inferenceProcessorAttributes.getModelId(), this.functionName, this.modelInput);
        client.execute(MLPredictionTaskAction.INSTANCE, actionRequest, new ActionListener<MLTaskResponse>() {

            public void onResponse(MLTaskResponse mlTaskResponse) {
                MLOutput mlOutput = mlTaskResponse.getOutput();
                if (processOutputMap == null || processOutputMap.isEmpty()) {
                    MLInferenceIngestProcessor.this.appendFieldValue(mlOutput, null, MLInferenceIngestProcessor.DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
                } else {
                    Map<String, String> outputMapping = processOutputMap.get(inputMapIndex);
                    for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
                        String newDocumentFieldName = entry.getKey();
                        String modelOutputFieldName = entry.getValue();
                        // Fields filtered out before the call (already present, override off) stay untouched.
                        if (!newOutputMapping.containsKey(newDocumentFieldName)) {
                            continue;
                        }
                        MLInferenceIngestProcessor.this.appendFieldValue(mlOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
                    }
                }
                batchPredictionListener.onResponse(null);
            }

            public void onFailure(Exception e) {
                batchPredictionListener.onFailure(e);
            }
        });
    }

    /**
     * Reads one model input value from the document and records it in {@code modelParameters}.
     *
     * <p>The document field is first looked up as a plain field path. If that fails, the name is
     * evaluated as a JSON path against the document's source and metadata. A JSON-path result
     * that is null or an empty list counts as missing: it is skipped when {@code ignoreMissing}
     * is set and rejected otherwise. A name that is neither an existing field nor a valid JSON
     * path is always rejected.
     *
     * @param ingestDocument      document to read from
     * @param modelParameters     map receiving the serialized value
     * @param documentFieldName   field path or JSON path inside the document
     * @param modelInputFieldName key under which the value is stored
     * @throws IllegalArgumentException if the field cannot be found and may not be ignored
     */
    private void getMappedModelInputFromDocuments(IngestDocument ingestDocument, Map<String, String> modelParameters, String documentFieldName, String modelInputFieldName) {
        // Fast path: the name addresses an existing field directly.
        String directPath = this.getFieldPath(ingestDocument, documentFieldName);
        if (directPath != null) {
            Object directValue = ingestDocument.getFieldValue(directPath, Object.class);
            this.updateModelParameters(modelInputFieldName, this.toString(directValue), modelParameters);
            return;
        }

        if (!StringUtils.isValidJSONPath(documentFieldName)) {
            throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
        }

        // Fallback: evaluate the name as a JSON path over source and metadata.
        Map<String, Object> sourceAndMetadata = ingestDocument.getSourceAndMetadata();
        Object resolved = JsonPath.using(suppressExceptionConfiguration).parse(sourceAndMetadata).read(documentFieldName);

        boolean nothingFound = resolved == null || (resolved instanceof List && ((List<?>)resolved).isEmpty());
        if (nothingFound) {
            if (this.ignoreMissing) {
                return;
            }
            throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
        }
        this.updateModelParameters(modelInputFieldName, this.toString(resolved), modelParameters);
    }

    /**
     * Stores a document-derived value under {@code modelInputFieldName}.
     *
     * <p>When the key is not yet present the value is stored as-is. When the key already holds a
     * value (for example one seeded from the model config), the existing value and the new value
     * are combined into a two-element list and the serialized list is stored. The previous code
     * cast the existing {@code String} to {@code List}, which always failed with a
     * {@code ClassCastException} on this path.
     *
     * <p>NOTE(review): the existing value is kept as a single list element even if it is itself
     * a serialized list; confirm whether such values should be flattened instead.
     *
     * @param modelInputFieldName        key in {@code modelParameters}
     * @param originalFieldValueAsString serialized document value to store
     * @param modelParameters            map to update
     */
    private void updateModelParameters(String modelInputFieldName, String originalFieldValueAsString, Map<String, String> modelParameters) {
        if (!modelParameters.containsKey(modelInputFieldName)) {
            modelParameters.put(modelInputFieldName, originalFieldValueAsString);
            return;
        }
        // ArrayList rather than List.of so that null values do not raise a NullPointerException.
        List<Object> combinedValues = new ArrayList<Object>(2);
        combinedValues.add(modelParameters.get(modelInputFieldName));
        combinedValues.add(originalFieldValueAsString);
        modelParameters.put(modelInputFieldName, this.toString(combinedValues));
    }

    /**
     * Returns {@code documentFieldName} when it is non-empty and the document reports the field
     * as present, otherwise {@code null}.
     *
     * @param ingestDocument    document to probe
     * @param documentFieldName candidate field path
     * @return the unchanged field name, or {@code null} when it cannot be used as a field path
     */
    private String getFieldPath(IngestDocument ingestDocument, String documentFieldName) {
        boolean usable = !Strings.isNullOrEmpty(documentFieldName) && ingestDocument.hasField(documentFieldName, true);
        return usable ? documentFieldName : null;
    }

    /**
     * Writes a value taken from the model output into the document.
     *
     * <p>The target field name is expanded into concrete dot paths. With exactly one path the
     * whole output value is written there. Otherwise the output value must be a list with one
     * element per path, and each element is written to the path at the same position.
     *
     * @param mlOutput             model output, must not be null
     * @param modelOutputFieldName name of the value inside the model output, may be null
     * @param newDocumentFieldName document field (possibly addressing array elements) to write
     * @param ingestDocument       document to update
     * @throws RuntimeException         if {@code mlOutput} is null or the list sizes differ
     * @throws IllegalArgumentException if several paths exist but the output value is not a list
     */
    private void appendFieldValue(MLOutput mlOutput, String modelOutputFieldName, String newDocumentFieldName, IngestDocument ingestDocument) {
        if (mlOutput == null) {
            throw new RuntimeException("model inference output is null");
        }
        Object outputValue = this.getModelOutputValue(mlOutput, modelOutputFieldName, this.ignoreMissing, this.fullResponsePath);

        // Path expansion works on a combined copy of source, metadata and ingest metadata.
        Map<String, Object> documentView = new HashMap<String, Object>(ingestDocument.getSourceAndMetadata());
        documentView.put("_ingest", ingestDocument.getIngestMetadata());
        List<String> targetPaths = this.writeNewDotPathForNestedObject(documentView, newDocumentFieldName);

        if (targetPaths.size() == 1) {
            this.writeModelOutputToField(ingestDocument, targetPaths.get(0), outputValue);
            return;
        }

        if (!(outputValue instanceof List)) {
            throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
        }
        List<?> outputValues = (List<?>)outputValue;
        if (targetPaths.size() != outputValues.size()) {
            throw new RuntimeException("the prediction field: " + modelOutputFieldName + " is an array in size of " + outputValues.size() + " but the document field array from field " + newDocumentFieldName + " is in size of " + targetPaths.size());
        }
        int position = 0;
        for (String targetPath : targetPaths) {
            this.writeModelOutputToField(ingestDocument, targetPath, outputValues.get(position));
            position++;
        }
    }

    /** Wraps {@code value}, compiles {@code path} as a field template and sets the field on the document. */
    private void writeModelOutputToField(IngestDocument ingestDocument, String path, Object value) {
        ValueSource wrappedValue = ValueSource.wrap(value, this.scriptService);
        TemplateScript.Factory fieldTemplate = ConfigurationUtils.compileTemplate(TYPE, this.tag, path, path, this.scriptService);
        ingestDocument.setFieldValue(fieldTemplate, wrappedValue, this.ignoreMissing);
    }

    /**
     * Returns the processor type name.
     *
     * @return {@link #TYPE}
     */
    public String getType() {
        return MLInferenceIngestProcessor.TYPE;
    }

    public static class Factory
    implements Processor.Factory {
        private final ScriptService scriptService;
        private final Client client;
        private final NamedXContentRegistry xContentRegistry;

        public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) {
            this.scriptService = scriptService;
            this.client = client;
            this.xContentRegistry = xContentRegistry;
        }

        public MLInferenceIngestProcessor create(Map<String, Processor.Factory> registry, String processorTag, String description, Map<String, Object> config) throws Exception {
            String modelId = ConfigurationUtils.readStringProperty((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)"model_id");
            Map modelConfigInput = ConfigurationUtils.readOptionalMap((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)"model_config");
            List inputMaps = ConfigurationUtils.readOptionalList((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)"input_map");
            List outputMaps = ConfigurationUtils.readOptionalList((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)"output_map");
            int maxPredictionTask = ConfigurationUtils.readIntProperty((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)"max_prediction_tasks", (Integer)10);
            boolean ignoreMissing = ConfigurationUtils.readBooleanProperty((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)MLInferenceIngestProcessor.IGNORE_MISSING, (boolean)false);
            boolean override = ConfigurationUtils.readBooleanProperty((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)MLInferenceIngestProcessor.OVERRIDE, (boolean)false);
            String functionName = ConfigurationUtils.readStringProperty((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)MLInferenceIngestProcessor.FUNCTION_NAME, (String)FunctionName.REMOTE.name());
            String modelInput = ConfigurationUtils.readOptionalStringProperty((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)MLInferenceIngestProcessor.MODEL_INPUT);
            if (functionName.equalsIgnoreCase("remote")) {
                modelInput = modelInput != null ? modelInput : MLInferenceIngestProcessor.DEFAULT_MODEl_INPUT;
            } else if (modelInput == null) {
                throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
            }
            boolean defaultFullResponsePath = !functionName.equalsIgnoreCase(FunctionName.REMOTE.name());
            boolean fullResponsePath = ConfigurationUtils.readBooleanProperty((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)MLInferenceIngestProcessor.FULL_RESPONSE_PATH, (boolean)defaultFullResponsePath);
            boolean ignoreFailure = ConfigurationUtils.readBooleanProperty((String)MLInferenceIngestProcessor.TYPE, (String)processorTag, config, (String)"ignore_failure", (boolean)false);
            Map modelConfigMaps = null;
            if (modelConfigInput != null) {
                modelConfigMaps = StringUtils.getParameterMap((Map)modelConfigInput);
            }
            if (inputMaps != null && inputMaps.size() > maxPredictionTask) {
                throw new IllegalArgumentException("The number of prediction task setting in this process is " + inputMaps.size() + ". It exceeds the max_prediction_tasks of " + maxPredictionTask + ". Please reduce the size of input_map or increase max_prediction_tasks.");
            }
            if (inputMaps != null && outputMaps != null && outputMaps.size() != inputMaps.size()) {
                throw new IllegalArgumentException("The length of output_map and the length of input_map do no match.");
            }
            return new MLInferenceIngestProcessor(modelId, inputMaps, outputMaps, modelConfigMaps, maxPredictionTask, processorTag, description, ignoreMissing, functionName, fullResponsePath, ignoreFailure, override, modelInput, this.scriptService, this.client, this.xContentRegistry);
        }
    }
}

