/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.normalization;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import lombok.Generated;
import lombok.NonNull;
import org.apache.commons.lang3.Range;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.common.TriConsumer;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.dto.ExplainDTO;
import org.opensearch.neuralsearch.processor.dto.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationUtils;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationUtil;

public class RRFNormalizationTechnique
implements ScoreNormalizationTechnique,
ExplainableTechnique {
    @Generated
    private static final Logger log = LogManager.getLogger(RRFNormalizationTechnique.class);
    public static final String TECHNIQUE_NAME = "rrf";
    public static final int DEFAULT_RANK_CONSTANT = 60;
    public static final String PARAM_NAME_RANK_CONSTANT = "rank_constant";
    private static final Set<String> SUPPORTED_PARAMS = Set.of("rank_constant");
    private static final int MIN_RANK_CONSTANT = 1;
    private static final int MAX_RANK_CONSTANT = 10000;
    private static final Range<Integer> RANK_CONSTANT_RANGE = Range.of((Comparable)Integer.valueOf(1), (Comparable)Integer.valueOf(10000));
    private final int rankConstant;
    private static final Comparator<ShardResultPerSubQuery> comparator = Comparator.comparing(sr -> sr.scoreDoc, ScoreDoc.COMPARATOR).thenComparingInt(sr -> sr.referenceShardId);

    public RRFNormalizationTechnique(Map<String, Object> params, ScoreNormalizationUtil scoreNormalizationUtil) {
        scoreNormalizationUtil.validateParameters(params, SUPPORTED_PARAMS, Map.of());
        this.rankConstant = this.getRankConstant(params);
    }

    @Override
    public void normalize(NormalizeScoresDTO normalizeScoresDTO) {
        List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
        Map<Integer, Map<String, Integer>> sortedDocIdsPerSubqueryByGlobalRank = normalizeScoresDTO.isSingleShard() ? Map.of() : this.sortDocumentsAsPerGlobalRankInIndividualQuery(queryTopDocs);
        for (int referenceShardId = 0; referenceShardId < queryTopDocs.size(); ++referenceShardId) {
            this.processTopDocs(queryTopDocs.get(referenceShardId), (TriConsumer<DocIdAtSearchShard, Float, Integer>)((TriConsumer)(docId, score, subQueryIndex) -> {}), sortedDocIdsPerSubqueryByGlobalRank, referenceShardId);
        }
    }

    private Map<Integer, Map<String, Integer>> sortDocumentsAsPerGlobalRankInIndividualQuery(@NonNull List<CompoundTopDocs> queryTopDocs) {
        Objects.requireNonNull(queryTopDocs, "queryTopDocs is marked non-null but is null");
        HashMap<Integer, PriorityQueue<ShardResultPerSubQuery>> scoreDocsPerSubquery = new HashMap<Integer, PriorityQueue<ShardResultPerSubQuery>>();
        for (int referenceShardId = 0; referenceShardId < queryTopDocs.size(); ++referenceShardId) {
            CompoundTopDocs compoundTopDocs = queryTopDocs.get(referenceShardId);
            if (Objects.isNull(compoundTopDocs)) continue;
            List<TopDocs> topDocs = compoundTopDocs.getTopDocs();
            for (int topDocIndex = 0; topDocIndex < topDocs.size(); ++topDocIndex) {
                TopDocs topDoc = topDocs.get(topDocIndex);
                scoreDocsPerSubquery.putIfAbsent(topDocIndex, new PriorityQueue<ShardResultPerSubQuery>(comparator));
                ScoreDoc[] scoreDocs = topDoc.scoreDocs;
                for (int scoreDocIndex = 0; scoreDocIndex < scoreDocs.length; ++scoreDocIndex) {
                    ((PriorityQueue)scoreDocsPerSubquery.get(topDocIndex)).add(new ShardResultPerSubQuery(scoreDocs[scoreDocIndex], referenceShardId));
                }
            }
        }
        HashMap<Integer, Map<String, Integer>> globallySortedDocIdMap = new HashMap<Integer, Map<String, Integer>>();
        for (Map.Entry entry : scoreDocsPerSubquery.entrySet()) {
            int subQueryNumber = (Integer)entry.getKey();
            globallySortedDocIdMap.putIfAbsent(subQueryNumber, new HashMap());
            PriorityQueue sortedScoreDocsAcrossAllShards = (PriorityQueue)entry.getValue();
            int rank = 0;
            while (!sortedScoreDocsAcrossAllShards.isEmpty()) {
                ShardResultPerSubQuery shardResultPerSubQuery = (ShardResultPerSubQuery)sortedScoreDocsAcrossAllShards.poll();
                ((Map)globallySortedDocIdMap.get(subQueryNumber)).put(shardResultPerSubQuery.scoreDoc.doc + "_" + shardResultPerSubQuery.referenceShardId, rank++);
            }
        }
        return globallySortedDocIdMap;
    }

    @Override
    public String describe() {
        return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, this.rankConstant);
    }

    @Override
    public String techniqueName() {
        return TECHNIQUE_NAME;
    }

    @Override
    public Map<DocIdAtSearchShard, ExplanationDetails> explain(ExplainDTO explainDTO) {
        List<CompoundTopDocs> queryTopDocs = explainDTO.getQueryTopDocs();
        HashMap<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<DocIdAtSearchShard, List<Float>>();
        Map<Integer, Map<String, Integer>> sortedDocIdsPerSubqueryByGlobalRank = explainDTO.isSingleShard() ? Map.of() : this.sortDocumentsAsPerGlobalRankInIndividualQuery(queryTopDocs);
        for (int referenceShardId = 0; referenceShardId < queryTopDocs.size(); ++referenceShardId) {
            CompoundTopDocs compoundQueryTopDocs = queryTopDocs.get(referenceShardId);
            if (Objects.isNull(compoundQueryTopDocs)) continue;
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            int numberOfSubQueries = topDocsPerSubQuery.size();
            this.processTopDocs(compoundQueryTopDocs, (TriConsumer<DocIdAtSearchShard, Float, Integer>)((TriConsumer)(docId, score, subQueryIndex) -> ScoreNormalizationUtil.setNormalizedScore(normalizedScores, docId, subQueryIndex, numberOfSubQueries, score.floatValue())), sortedDocIdsPerSubqueryByGlobalRank, referenceShardId);
        }
        return ExplanationUtils.getDocIdAtQueryForNormalization(normalizedScores, this);
    }

    private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, TriConsumer<DocIdAtSearchShard, Float, Integer> scoreProcessor, Map<Integer, Map<String, Integer>> sortedDocIdMapPerSubqueryByGlobalRank, int referenceShardId) {
        if (Objects.isNull(compoundQueryTopDocs)) {
            return;
        }
        List<TopDocs> topDocsList = compoundQueryTopDocs.getTopDocs();
        SearchShard searchShard = compoundQueryTopDocs.getSearchShard();
        for (int topDocsIndex = 0; topDocsIndex < topDocsList.size(); ++topDocsIndex) {
            Map<String, Integer> docIdToRankMap = sortedDocIdMapPerSubqueryByGlobalRank.isEmpty() ? Map.of() : sortedDocIdMapPerSubqueryByGlobalRank.get(topDocsIndex);
            this.processTopDocsEntry(topDocsList.get(topDocsIndex), searchShard, topDocsIndex, scoreProcessor, docIdToRankMap, referenceShardId);
        }
    }

    private void processTopDocsEntry(@NonNull TopDocs topDocs, SearchShard searchShard, int topDocsIndex, TriConsumer<DocIdAtSearchShard, Float, Integer> scoreProcessor, Map<String, Integer> docIdToRankMap, int referenceShardId) {
        Objects.requireNonNull(topDocs, "topDocs is marked non-null but is null");
        for (int position = 0; position < topDocs.scoreDocs.length; ++position) {
            ScoreDoc scoreDoc = topDocs.scoreDocs[position];
            Integer rank = docIdToRankMap.isEmpty() ? position : docIdToRankMap.get(scoreDoc.doc + "_" + referenceShardId);
            if (Objects.isNull(rank) || rank < 0) {
                throw new IllegalStateException("Document not found in global ranking map: doc=" + scoreDoc.doc + ", shard=" + referenceShardId);
            }
            float normalizedScore = this.calculateNormalizedScore(rank);
            DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, searchShard);
            scoreProcessor.apply((Object)docIdAtSearchShard, (Object)Float.valueOf(normalizedScore), (Object)topDocsIndex);
            scoreDoc.score = normalizedScore;
        }
    }

    private float calculateNormalizedScore(int position) {
        return BigDecimal.ONE.divide(BigDecimal.valueOf(this.rankConstant + position + 1), 10, RoundingMode.HALF_UP).floatValue();
    }

    private int getRankConstant(Map<String, Object> params) {
        if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_RANK_CONSTANT)) {
            return 60;
        }
        int rankConstant = RRFNormalizationTechnique.getParamAsInteger(params, PARAM_NAME_RANK_CONSTANT);
        this.validateRankConstant(rankConstant);
        return rankConstant;
    }

    private void validateRankConstant(int rankConstant) {
        if (!RANK_CONSTANT_RANGE.contains((Object)rankConstant)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "rank constant must be in the interval between 1 and 10000, submitted rank constant: %d", rankConstant));
        }
    }

    private static int getParamAsInteger(Map<String, Object> parameters, String fieldName) {
        try {
            return NumberUtils.createInteger((String)String.valueOf(parameters.get(fieldName)));
        }
        catch (NumberFormatException e) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "parameter [%s] must be an integer", fieldName));
        }
    }

    @Generated
    public String toString() {
        return "RRFNormalizationTechnique(TECHNIQUE_NAME=rrf, rankConstant=" + this.rankConstant + ")";
    }

    private record ShardResultPerSubQuery(ScoreDoc scoreDoc, int referenceShardId) {
    }
}

