/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.search.collector;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.neuralsearch.query.HybridQueryScorer;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;

public class HybridTopScoreDocCollector
implements HybridSearchCollector {
    @Generated
    private static final Logger log = LogManager.getLogger(HybridTopScoreDocCollector.class);
    private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0L, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
    private int docBase;
    private final HitsThresholdChecker hitsThresholdChecker;
    private TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
    private int totalHits;
    private int[] collectedHitsPerSubQuery;
    private final int numOfHits;
    private PriorityQueue<ScoreDoc>[] compoundScores;
    private float maxScore = 0.0f;

    public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThresholdChecker) {
        this.numOfHits = numHits;
        this.hitsThresholdChecker = hitsThresholdChecker;
    }

    public LeafCollector getLeafCollector(LeafReaderContext context) {
        this.docBase = context.docBase;
        return new LeafCollector(){
            HybridQueryScorer compoundQueryScorer;

            public void setScorer(Scorable scorer) throws IOException {
                if (scorer instanceof HybridQueryScorer) {
                    log.debug("passed scorer is of type HybridQueryScorer, saving it for collecting documents and scores");
                    this.compoundQueryScorer = (HybridQueryScorer)scorer;
                } else {
                    this.compoundQueryScorer = this.getHybridQueryScorer(scorer);
                    if (Objects.isNull((Object)this.compoundQueryScorer)) {
                        log.error(String.format(Locale.ROOT, "cannot find scorer of type HybridQueryScorer in a hierarchy of scorer %s", scorer));
                    }
                }
            }

            private HybridQueryScorer getHybridQueryScorer(Scorable scorer) throws IOException {
                if (scorer == null) {
                    return null;
                }
                if (scorer instanceof HybridQueryScorer) {
                    return (HybridQueryScorer)scorer;
                }
                for (Scorable.ChildScorable childScorable : scorer.getChildren()) {
                    HybridQueryScorer hybridQueryScorer = this.getHybridQueryScorer(childScorable.child);
                    if (!Objects.nonNull((Object)hybridQueryScorer)) continue;
                    log.debug(String.format(Locale.ROOT, "found hybrid query scorer, it's child of scorer %s", childScorable.child.getClass().getSimpleName()));
                    return hybridQueryScorer;
                }
                return null;
            }

            public void collect(int doc) throws IOException {
                int i;
                if (Objects.isNull((Object)this.compoundQueryScorer)) {
                    throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query");
                }
                float[] subScoresByQuery = this.compoundQueryScorer.hybridScores();
                if (HybridTopScoreDocCollector.this.compoundScores == null) {
                    HybridTopScoreDocCollector.this.compoundScores = new PriorityQueue[subScoresByQuery.length];
                    for (i = 0; i < subScoresByQuery.length; ++i) {
                        HybridTopScoreDocCollector.this.compoundScores[i] = new HitQueue(HybridTopScoreDocCollector.this.numOfHits, false);
                    }
                    HybridTopScoreDocCollector.this.collectedHitsPerSubQuery = new int[subScoresByQuery.length];
                }
                ++HybridTopScoreDocCollector.this.totalHits;
                HybridTopScoreDocCollector.this.hitsThresholdChecker.incrementHitCount();
                for (i = 0; i < subScoresByQuery.length; ++i) {
                    float score = subScoresByQuery[i];
                    if (score == 0.0f) continue;
                    if (HybridTopScoreDocCollector.this.hitsThresholdChecker.isThresholdReached() && HybridTopScoreDocCollector.this.totalHitsRelation == TotalHits.Relation.EQUAL_TO) {
                        log.info("hit count threshold reached: total hits={}, threshold={}, action=updating_results", (Object)HybridTopScoreDocCollector.this.totalHits, (Object)HybridTopScoreDocCollector.this.hitsThresholdChecker.getTotalHitsThreshold());
                        HybridTopScoreDocCollector.this.totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
                    }
                    int n = i;
                    HybridTopScoreDocCollector.this.collectedHitsPerSubQuery[n] = HybridTopScoreDocCollector.this.collectedHitsPerSubQuery[n] + 1;
                    PriorityQueue<ScoreDoc> pq = HybridTopScoreDocCollector.this.compoundScores[i];
                    ScoreDoc currentDoc = new ScoreDoc(doc + HybridTopScoreDocCollector.this.docBase, score);
                    HybridTopScoreDocCollector.this.maxScore = Math.max(currentDoc.score, HybridTopScoreDocCollector.this.maxScore);
                    pq.insertWithOverflow((Object)currentDoc);
                }
            }
        };
    }

    public ScoreMode scoreMode() {
        return this.hitsThresholdChecker.scoreMode();
    }

    public List<TopDocs> topDocs() {
        if (this.compoundScores == null) {
            return new ArrayList<TopDocs>();
        }
        ArrayList<TopDocs> topDocs = new ArrayList<TopDocs>();
        for (int i = 0; i < this.compoundScores.length; ++i) {
            topDocs.add(this.topDocsPerQuery(0, Math.min(this.collectedHitsPerSubQuery[i], this.compoundScores[i].size()), this.compoundScores[i], this.collectedHitsPerSubQuery[i]));
        }
        return topDocs;
    }

    private TopDocs topDocsPerQuery(int start, int howMany, PriorityQueue<ScoreDoc> pq, int totalHits) {
        if (howMany < 0) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Number of hits requested must be greater than 0 but value was %d", howMany));
        }
        if (start < 0) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Expected value of starting position is between 0 and %d, got %d", howMany, start));
        }
        if (start >= howMany || howMany == 0) {
            return EMPTY_TOPDOCS;
        }
        int size = howMany - start;
        ScoreDoc[] results = new ScoreDoc[size];
        this.populateResults(results, size, pq);
        return new TopDocs(new TotalHits((long)totalHits, this.totalHitsRelation), results);
    }

    protected void populateResults(ScoreDoc[] results, int howMany, PriorityQueue<ScoreDoc> pq) {
        for (int i = howMany - 1; i >= 0 && pq.size() > 0; --i) {
            if (i >= results.length) continue;
            results[i] = (ScoreDoc)pq.pop();
        }
    }

    @Override
    @Generated
    public int getTotalHits() {
        return this.totalHits;
    }

    @Override
    @Generated
    public float getMaxScore() {
        return this.maxScore;
    }
}

