/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.mcpserver;

import com.google.common.collect.ImmutableMap;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.spec.McpSchema;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import lombok.Generated;
import org.opensearch.OpenSearchException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.action.mcpserver.McpAsyncServerHolder;
import org.opensearch.ml.common.MLIndex;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.transport.mcpserver.requests.McpToolBaseInput;
import org.opensearch.ml.common.transport.mcpserver.requests.register.McpToolRegisterInput;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.rest.mcpserver.ToolFactoryWrapper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.client.Client;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

/**
 * Loads MCP tool definitions stored in the MCP tools index into the in-process MCP async server and keeps the
 * served tools in sync with the index by periodically re-reading it.
 */
public class McpToolsHelper {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(McpToolsHelper.class);
    /** Maximum number of tool documents fetched by a single search of the MCP tools index. */
    public static final int MAX_TOOL_NUMBER = 1000;
    /** Delay, in seconds, between two consecutive runs of the tool sync job. */
    private static final int SYNC_MCP_TOOLS_JOB_INTERVAL = 10;
    private final Client client;
    private final ThreadPool threadPool;
    private final ToolFactoryWrapper toolFactoryWrapper;

    public McpToolsHelper(Client client, ThreadPool threadPool, ToolFactoryWrapper toolFactoryWrapper) {
        this.client = client;
        this.threadPool = threadPool;
        this.toolFactoryWrapper = toolFactoryWrapper;
    }

    /**
     * Reads every tool (with its document version) from the MCP tools index and registers it on the MCP server.
     * Tools not yet loaded are added; tools whose indexed version is newer than the loaded one are replaced.
     * On success the next sync run is scheduled and {@code listener} receives {@code true}.
     *
     * @param listener completed with {@code true} once all tools were submitted to the server, or with the
     *                 search/parse failure
     */
    public void autoLoadAllMcpTools(ActionListener<Boolean> listener) {
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<Boolean> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<Map<String, Tuple<McpToolRegisterInput, Long>>> searchListener = ActionListener.wrap(tools -> {
                tools.forEach(this::loadOrRefreshTool);
                startSyncMcpToolsJob();
                restoreListener.onResponse(true);
            }, e -> {
                log.error("Failed to auto load all MCP tools to MCP server", e);
                restoreListener.onFailure(e);
            });
            searchAllToolsWithVersion(searchListener);
        } catch (Exception e) {
            log.error("Failed to auto load all MCP tools to MCP server", e);
            listener.onFailure(e);
        }
    }

    /**
     * Adds one indexed tool to the MCP server, or replaces the served copy when the index holds a newer version.
     * Failures are logged per tool so that one broken definition does not prevent the others from loading.
     */
    private void loadOrRefreshTool(String toolName, Tuple<McpToolRegisterInput, Long> toolWithVersion) {
        McpToolRegisterInput tool = toolWithVersion.v1();
        long indexedVersion = toolWithVersion.v2();
        // Single read: a containsKey/get pair can race with a concurrent removal and NPE while unboxing.
        Long loadedVersion = McpAsyncServerHolder.IN_MEMORY_MCP_TOOLS.get(toolName);
        if (loadedVersion != null && loadedVersion >= indexedVersion) {
            return;
        }
        McpServerFeatures.AsyncToolSpecification specification;
        try {
            specification = createToolSpecification(tool);
        } catch (Exception ex) {
            log.error("Failed to auto load tool: {}", tool.getName(), ex);
            return;
        }
        Mono<Void> addTool = Mono.defer(() -> McpAsyncServerHolder.getMcpAsyncServerInstance().addTool(specification));
        // When refreshing, the stale copy must be gone before the new one is added, so chain instead of
        // subscribing to both independently. A failed removal (e.g. already removed) is ignored.
        Mono<Void> load = loadedVersion == null
            ? addTool
            : McpAsyncServerHolder.getMcpAsyncServerInstance().removeTool(toolName).onErrorResume(ex -> Mono.empty()).then(addTool);
        load
            .doOnSuccess(ignored -> McpAsyncServerHolder.IN_MEMORY_MCP_TOOLS.put(toolName, indexedVersion))
            .doOnError(ex -> log.error("Failed to auto load tool: {}", tool.getName(), ex))
            .subscribe();
    }

    /**
     * Schedules one run of {@link #autoLoadAllMcpTools} on the general ML thread pool after
     * {@value #SYNC_MCP_TOOLS_JOB_INTERVAL} seconds. A successful run reschedules itself from
     * {@code autoLoadAllMcpTools}; a failed run is rescheduled here so that a transient failure does not stop the
     * periodic sync for good.
     */
    public void startSyncMcpToolsJob() {
        ActionListener<Boolean> listener = ActionListener.wrap(
            r -> log.debug("Auto reload mcp tools schedule job run successfully!"),
            e -> {
                log.error(e.getMessage(), e);
                startSyncMcpToolsJob();
            }
        );
        threadPool
            .schedule(() -> autoLoadAllMcpTools(listener), TimeValue.timeValueSeconds(SYNC_MCP_TOOLS_JOB_INTERVAL), "opensearch_ml_general");
    }

    /**
     * Builds the MCP server specification for a tool: name (falls back to the tool type), description (falls back
     * to the factory default), input schema, and a handler that runs the underlying ML tool.
     *
     * @param tool the tool definition
     * @return the async tool specification to register on the MCP server
     * @throws OpenSearchException if no tool factory is registered for {@code tool.getType()}
     */
    public McpServerFeatures.AsyncToolSpecification createToolSpecification(McpToolBaseInput tool) {
        String toolName = Optional.ofNullable(tool.getName()).orElse(tool.getType());
        Tool.Factory factory = toolFactoryWrapper.getToolsFactories().get(tool.getType());
        if (factory == null) {
            throw new OpenSearchException("Failed to find tool factory for tool type: " + tool.getType());
        }
        Tool actualTool = factory.create(Optional.ofNullable(tool.getParameters()).orElse(ImmutableMap.of()));
        // Prefer the user-supplied input_schema; if it is absent (attributes missing OR without that key) use the
        // tool's own default instead of serializing a missing value to the literal "null".
        String schema = Optional.ofNullable(tool.getAttributes())
            .map(attributes -> attributes.get("input_schema"))
            .map(inputSchema -> StringUtils.gson.toJson(inputSchema))
            .orElseGet(() -> Optional.ofNullable(actualTool.getAttributes())
                .map(attributes -> (String) attributes.get("input_schema"))
                .orElse("{}"));
        String description = Optional.ofNullable(tool.getDescription()).orElseGet(factory::getDefaultDescription);
        McpSchema.Tool mcpTool = new McpSchema.Tool(toolName, String.valueOf(description), schema);
        return new McpServerFeatures.AsyncToolSpecification(mcpTool, (exchange, arguments) -> Mono.create(sink -> {
            ActionListener<String> actionListener = ActionListener.wrap(
                result -> sink.success(new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(result)), false)),
                e -> {
                    log.error("Failed to execute tool, tool name: {}", toolName, e);
                    sink.error(e);
                }
            );
            actualTool.run(StringUtils.getParameterMap(arguments), actionListener);
        }));
    }

    /**
     * Searches the MCP tools index for the given tool names (with document versions requested) and returns the
     * parsed definitions.
     */
    public void searchToolsWithVersion(List<String> toolNames, ActionListener<List<McpToolRegisterInput>> listener) {
        SearchRequest searchRequest = buildSearchRequest(toolNames);
        searchRequest.source().version(true);
        client.search(searchRequest, createSearchResponseListener(listener));
    }

    /**
     * Searches the MCP tools index for the given tool names with sequence number and primary term requested,
     * returning the raw search response.
     */
    public void searchToolsWithPrimaryTermAndSeqNo(List<String> toolNames, ActionListener<SearchResponse> listener) {
        SearchRequest searchRequest = buildSearchRequest(toolNames);
        searchRequest.source().seqNoAndPrimaryTerm(true);
        client.search(searchRequest, listener);
    }

    /**
     * Builds a search over the MCP tools index that matches any of {@code toolNames} on the {@code name} field.
     * The size is raised to {@link #MAX_TOOL_NUMBER}; the default of 10 hits would silently drop matches.
     * NOTE(review): an empty {@code toolNames} yields a bool query with no clauses, which matches every tool.
     */
    private SearchRequest buildSearchRequest(List<String> toolNames) {
        BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery();
        for (String toolName : toolNames) {
            queryBuilder.should(QueryBuilders.matchQuery("name", toolName));
        }
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.query(queryBuilder);
        searchSourceBuilder.size(MAX_TOOL_NUMBER);
        SearchRequest searchRequest = new SearchRequest();
        searchRequest.indices(MLIndex.MCP_TOOLS.getIndexName());
        searchRequest.source(searchSourceBuilder);
        return searchRequest;
    }

    /**
     * Reads every tool from the MCP tools index, keyed by tool name and paired with its document version.
     * The search runs with a stashed thread context, which is restored before {@code listener} is notified.
     * {@code listener} is notified exactly once.
     */
    public void searchAllToolsWithVersion(ActionListener<Map<String, Tuple<McpToolRegisterInput, Long>>> listener) {
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<Map<String, Tuple<McpToolRegisterInput, Long>>> restoreListener = ActionListener
                .runBefore(listener, context::restore);
            ActionListener<SearchResponse> actionListener = ActionListener.wrap(r -> {
                Map<String, Tuple<McpToolRegisterInput, Long>> mcpTools = new HashMap<>();
                try {
                    for (SearchHit hit : Objects.requireNonNull(r.getHits().getHits())) {
                        McpToolRegisterInput mcpTool = parseMcpTool(hit.getSourceAsString());
                        mcpTools.put(mcpTool.getName(), Tuple.tuple(mcpTool, hit.getVersion()));
                    }
                } catch (IOException parseFailure) {
                    // Fail once and stop; previously the listener was failed and then also completed.
                    restoreListener.onFailure(parseFailure);
                    return;
                }
                restoreListener.onResponse(mcpTools);
            }, e -> onSearchFailure(e, restoreListener));
            client.search(buildSearchRequest(), actionListener);
        } catch (Exception e) {
            log.error("Failed to search mcp tools index", e);
            listener.onFailure(e);
        }
    }

    /**
     * Reads every tool definition from the MCP tools index. The search runs with a stashed thread context,
     * which is restored before {@code listener} is notified (on both the success and the failure path).
     */
    public void searchAllTools(ActionListener<List<McpToolRegisterInput>> listener) {
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<List<McpToolRegisterInput>> restoreListener = ActionListener.runBefore(listener, context::restore);
            client.search(buildSearchRequest(), createSearchResponseListener(restoreListener));
        } catch (Exception e) {
            log.error("Failed to search mcp tools index", e);
            listener.onFailure(e);
        }
    }

    /** Builds a match-all search over the MCP tools index, requesting versions and up to {@link #MAX_TOOL_NUMBER} hits. */
    private SearchRequest buildSearchRequest() {
        MatchAllQueryBuilder queryBuilder = QueryBuilders.matchAllQuery();
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.version(true);
        searchSourceBuilder.query(queryBuilder);
        searchSourceBuilder.size(MAX_TOOL_NUMBER);
        SearchRequest searchRequest = new SearchRequest();
        searchRequest.indices(MLIndex.MCP_TOOLS.getIndexName());
        searchRequest.source(searchSourceBuilder);
        return searchRequest;
    }

    /**
     * Adapts a tool-list listener to a search-response listener: parses every hit, then notifies {@code listener}
     * exactly once with either the parsed tools or the first failure.
     */
    private ActionListener<SearchResponse> createSearchResponseListener(ActionListener<List<McpToolRegisterInput>> listener) {
        return ActionListener.wrap(r -> {
            List<McpToolRegisterInput> mcpTools;
            try {
                mcpTools = parseMcpTools(r);
            } catch (IOException parseFailure) {
                listener.onFailure(parseFailure);
                return;
            }
            listener.onResponse(mcpTools);
        }, e -> onSearchFailure(e, listener));
    }

    /** Parses the source of every hit in {@code response} into a tool definition, in hit order. */
    private List<McpToolRegisterInput> parseMcpTools(SearchResponse response) throws IOException {
        SearchHit[] hits = Objects.requireNonNull(response.getHits().getHits());
        List<McpToolRegisterInput> mcpTools = new ArrayList<>(hits.length);
        for (SearchHit hit : hits) {
            mcpTools.add(parseMcpTool(hit.getSourceAsString()));
        }
        return mcpTools;
    }

    /** Logs a failed search of the MCP tools index and fails {@code listener}, keeping the original cause. */
    private static <T> void onSearchFailure(Exception e, ActionListener<T> listener) {
        String errMsg = String.format(Locale.ROOT, "Failed to search mcp tools index with error: %s", e.getMessage());
        log.error(errMsg, e);
        listener.onFailure(new OpenSearchException(errMsg, e));
    }

    /**
     * Parses one JSON tool document into an {@link McpToolRegisterInput}.
     *
     * @throws IOException if the JSON cannot be read; the input is logged before rethrowing
     */
    private McpToolRegisterInput parseMcpTool(String input) throws IOException {
        try (
            XContentParser parser = JsonXContent.jsonXContent
                .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, input)
        ) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
            return McpToolRegisterInput.parse(parser);
        } catch (IOException e) {
            log.error("Failed to parse mcp tools configuration: {}", input, e);
            throw e;
        }
    }
}

