/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import com.google.common.annotations.VisibleForTesting;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.AbstractScoreHybridizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflowExecuteRequest;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
import org.opensearch.neuralsearch.stats.events.EventStatName;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.query.QuerySearchResult;

public class RRFProcessor
extends AbstractScoreHybridizationProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(RRFProcessor.class);
    public static final String TYPE = "score-ranker-processor";
    private final String tag;
    private final String description;
    private final ScoreNormalizationTechnique normalizationTechnique;
    private final ScoreCombinationTechnique combinationTechnique;
    private final NormalizationProcessorWorkflow normalizationWorkflow;
    private final Map<String, Runnable> combTechniqueIncrementers = Map.of("rrf", () -> EventStatsManager.increment(EventStatName.COMB_TECHNIQUE_RRF_EXECUTIONS));

    @Override
    <Result extends SearchPhaseResult> void hybridizeScores(SearchPhaseResults<Result> searchPhaseResult, SearchPhaseContext searchPhaseContext, Optional<PipelineProcessingContext> requestContextOptional) {
        if (this.shouldSkipProcessor(searchPhaseResult)) {
            log.debug("Query results are not compatible with RRF processor");
            return;
        }
        List<QuerySearchResult> querySearchResults = this.getQueryPhaseSearchResults(searchPhaseResult);
        Optional<FetchSearchResult> fetchSearchResult = this.getFetchSearchResults(searchPhaseResult);
        boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain()) && searchPhaseContext.getRequest().source().explain() != false;
        this.recordStats(this.combinationTechnique);
        NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder().querySearchResults(querySearchResults).fetchSearchResultOptional(fetchSearchResult).normalizationTechnique(this.normalizationTechnique).combinationTechnique(this.combinationTechnique).explain(explain).pipelineProcessingContext(requestContextOptional.orElse(null)).searchPhaseContext(searchPhaseContext).build();
        this.normalizationWorkflow.execute(normalizationExecuteDTO);
    }

    public SearchPhaseName getBeforePhase() {
        return SearchPhaseName.QUERY;
    }

    public SearchPhaseName getAfterPhase() {
        return SearchPhaseName.FETCH;
    }

    public String getType() {
        return TYPE;
    }

    public boolean isIgnoreFailure() {
        return false;
    }

    @VisibleForTesting
    <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
        if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer)) {
            return true;
        }
        QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer)searchPhaseResult;
        return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
    }

    @VisibleForTesting
    boolean isHybridQuery(SearchPhaseResult searchPhaseResult) {
        return Objects.nonNull(searchPhaseResult.queryResult()) && Objects.nonNull(searchPhaseResult.queryResult().topDocs()) && Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs) && searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0 && HybridSearchResultFormatUtil.isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
    }

    <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(SearchPhaseResults<Result> results) {
        return results.getAtomicArray().asList().stream().map(result -> result == null ? null : result.queryResult()).collect(Collectors.toList());
    }

    @VisibleForTesting
    <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(SearchPhaseResults<Result> searchPhaseResults) {
        Optional optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
        return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
    }

    private void recordStats(ScoreCombinationTechnique combinationTechnique) {
        EventStatsManager.increment(EventStatName.RRF_PROCESSOR_EXECUTIONS);
        Optional.of(this.combTechniqueIncrementers.get(combinationTechnique.techniqueName())).ifPresent(Runnable::run);
    }

    @Generated
    public RRFProcessor(String tag, String description, ScoreNormalizationTechnique normalizationTechnique, ScoreCombinationTechnique combinationTechnique, NormalizationProcessorWorkflow normalizationWorkflow) {
        this.tag = tag;
        this.description = description;
        this.normalizationTechnique = normalizationTechnique;
        this.combinationTechnique = combinationTechnique;
        this.normalizationWorkflow = normalizationWorkflow;
    }

    @Generated
    public String getTag() {
        return this.tag;
    }

    @Generated
    public String getDescription() {
        return this.description;
    }
}

