/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.highlight;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Query;
import org.opensearch.OpenSearchException;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.highlight.extractor.QueryTextExtractorRegistry;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.InferenceRequest;
import org.opensearch.neuralsearch.processor.highlight.SentenceHighlightingRequest;
import org.opensearch.search.fetch.subphase.highlight.FieldHighlightContext;

/**
 * Engine behind semantic highlighting.
 *
 * <p>It reads the highlighted field's text from a hit and extracts the query text. It then asks a
 * sentence-highlighting model (through {@link MLCommonsClientAccessor}) which character ranges of the
 * text are relevant, and wraps those ranges in caller-supplied pre/post tags.
 */
public class SemanticHighlighterEngine {
    @Generated
    private static final Logger log = LogManager.getLogger(SemanticHighlighterEngine.class);
    private static final String MODEL_ID_FIELD = "model_id";
    private static final String MODEL_INFERENCE_RESULT_KEY = "highlights";
    private static final String MODEL_INFERENCE_RESULT_START_KEY = "start";
    private static final String MODEL_INFERENCE_RESULT_END_KEY = "end";
    @NonNull
    private final MLCommonsClientAccessor mlCommonsClient;
    @NonNull
    private final QueryTextExtractorRegistry queryTextExtractorRegistry;

    /**
     * Reads the text of the field being highlighted from the hit's source.
     *
     * @param fieldContext highlight context identifying the hit and the field name
     * @return the field's value; never null or empty
     * @throws IllegalArgumentException if the hit or its source lookup is unavailable, or if the field is
     *     missing, not a string, or empty
     */
    public String getFieldText(FieldHighlightContext fieldContext) {
        if (fieldContext.hitContext == null || fieldContext.hitContext.sourceLookup() == null) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s is not found in the hit", fieldContext.fieldName));
        }
        Object fieldTextObject = fieldContext.hitContext.sourceLookup().extractValue(fieldContext.fieldName, null);
        if (fieldTextObject == null) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s is not found in the document", fieldContext.fieldName));
        }
        if (!(fieldTextObject instanceof String fieldTextString)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s must be a string for highlighting, but was %s", fieldContext.fieldName, fieldTextObject.getClass().getSimpleName()));
        }
        if (fieldTextString.isEmpty()) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s is empty", fieldContext.fieldName));
        }
        return fieldTextString;
    }

    /**
     * Extracts the user's query text from a Lucene query via the extractor registry.
     *
     * @param query the rewritten search query
     * @param fieldName field being highlighted; may be null, in which case a warning is logged
     * @return whatever text the registry extracts for this query/field
     */
    public String extractOriginalQuery(Query query, String fieldName) {
        if (fieldName == null) {
            log.warn("Field name is null, extraction may be less accurate");
        }
        return this.queryTextExtractorRegistry.extractQueryText(query, fieldName);
    }

    /**
     * Reads the {@code model_id} entry from the highlighter options.
     *
     * @param options highlighter options; a null map is treated like a missing {@code model_id}
     * @return the model id
     * @throws IllegalArgumentException if {@code model_id} is absent or not a string
     */
    public String getModelId(Map<String, Object> options) {
        Object modelId = options == null ? null : options.get(MODEL_ID_FIELD);
        // instanceof is false for null, so this single check covers both "missing" and "wrong type".
        if (!(modelId instanceof String modelIdString)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a non-null string, but was %s", MODEL_ID_FIELD, modelId == null ? "null" : modelId.getClass().getSimpleName()));
        }
        return modelIdString;
    }

    /**
     * Runs the highlighting model and returns {@code context} with the highlighted ranges wrapped in tags.
     *
     * @param modelId id of the sentence-highlighting model
     * @param question query text
     * @param context document text to highlight
     * @param preTag tag inserted before each highlighted range
     * @param postTag tag inserted after each highlighted range
     * @return the tagged text, or {@code null} if the model returned no results or an unusable first result
     */
    public String getHighlightedSentences(String modelId, String question, String context, String preTag, String postTag) {
        List<Map<String, Object>> results = this.fetchModelResults(modelId, question, context);
        if (results == null || results.isEmpty()) {
            return null;
        }
        return this.applyHighlighting(context, results.getFirst(), preTag, postTag);
    }

    /**
     * Sends a sentence-highlighting request to the model and waits for its response.
     *
     * <p>Builds a {@link SentenceHighlightingRequest} (an {@link InferenceRequest}). A
     * {@link PlainActionFuture} is passed as the {@link ActionListener} of the asynchronous ML call and
     * then waited on with {@code actionGet()}.
     *
     * @param modelId id of the sentence-highlighting model
     * @param question query text
     * @param context document text to highlight
     * @return the raw model results
     * @throws OpenSearchException wrapping any failure of the inference call
     */
    public List<Map<String, Object>> fetchModelResults(String modelId, String question, String context) {
        PlainActionFuture<List<Map<String, Object>>> future = PlainActionFuture.newFuture();
        SentenceHighlightingRequest request = SentenceHighlightingRequest.builder()
            .modelId(modelId)
            .question(question)
            .context(context)
            .build();
        this.mlCommonsClient.inferenceSentenceHighlighting(request, future);
        try {
            return future.actionGet();
        }
        catch (Exception e) {
            // Log only the length of the document text: logging whole documents at error level leaks
            // user data into logs and can produce very large log lines.
            log.error("Error during sentence highlighting inference - modelId: [{}], question: [{}], context length: [{}]", modelId, question, context == null ? "null" : context.length(), e);
            throw new OpenSearchException(String.format(Locale.ROOT, "Error during sentence highlighting inference from model [%s]", modelId), e);
        }
    }

    /**
     * Wraps the ranges reported in one model result with {@code preTag}/{@code postTag}.
     *
     * <p>The result is expected to hold a {@code "highlights"} list whose items are maps with numeric
     * {@code "start"} and {@code "end"} character offsets into {@code context}. Ranges must be sorted by
     * start offset. Overlapping ranges are merged, so no text is emitted twice.
     *
     * @param context document text
     * @param highlightResult one model result; may be null
     * @param preTag tag inserted before each highlighted range
     * @param postTag tag inserted after each highlighted range
     * @return the tagged text. This is {@code context} unchanged when the highlight list is empty, and
     *     {@code null} when the result is null or has no highlight list.
     * @throws OpenSearchException if an item is malformed, an offset is out of range, or ranges are unsorted
     */
    public String applyHighlighting(String context, Map<String, Object> highlightResult, String preTag, String postTag) {
        Object highlightsObj = highlightResult == null ? null : highlightResult.get(MODEL_INFERENCE_RESULT_KEY);
        if (!(highlightsObj instanceof List<?> highlightsList)) {
            log.error("No valid highlights found in model inference result, highlightsObj: {}", highlightsObj);
            return null;
        }
        if (highlightsList.isEmpty()) {
            return context;
        }
        // Flat offset list: [start0, end0, start1, end1, ...]
        List<Integer> validHighlights = new ArrayList<>(highlightsList.size() * 2);
        for (Object item : highlightsList) {
            int[] span = readHighlightSpan(item);
            this.validateHighlightPositions(span[0], span[1], context.length());
            validHighlights.add(span[0]);
            validHighlights.add(span[1]);
        }
        for (int i = 2; i < validHighlights.size(); i += 2) {
            if (validHighlights.get(i) < validHighlights.get(i - 2)) {
                log.error("Highlights are not sorted: {}", validHighlights);
                throw new OpenSearchException("Internal error while applying semantic highlight: received unsorted highlights from model");
            }
        }
        return this.constructHighlightedText(context, mergeOverlappingHighlights(validHighlights), preTag, postTag);
    }

    /**
     * Rejects a span unless {@code 0 <= start < end <= textLength}.
     */
    private void validateHighlightPositions(int start, int end, int textLength) {
        if (start < 0 || end > textLength || start >= end) {
            throw new OpenSearchException(String.format(Locale.ROOT, "Invalid highlight positions: start=%d, end=%d, textLength=%d. Positions must satisfy: 0 <= start < end <= textLength", start, end, textLength));
        }
    }

    /**
     * Merges overlapping spans in a start-sorted flat offset list.
     *
     * <p>Spans that merely touch (start equals the previous end) are kept separate.
     *
     * @param sortedHighlights flat list {@code [start0, end0, start1, end1, ...]}, sorted by start
     * @return a new flat list with no overlapping spans
     */
    private static List<Integer> mergeOverlappingHighlights(List<Integer> sortedHighlights) {
        List<Integer> merged = new ArrayList<>(sortedHighlights.size());
        for (int i = 0; i < sortedHighlights.size(); i += 2) {
            int start = sortedHighlights.get(i);
            int end = sortedHighlights.get(i + 1);
            int lastEndIndex = merged.size() - 1;
            if (lastEndIndex > 0 && start < merged.get(lastEndIndex)) {
                // Overlaps the previous span: extend it instead of emitting the shared text twice.
                merged.set(lastEndIndex, Math.max(merged.get(lastEndIndex), end));
            } else {
                merged.add(start);
                merged.add(end);
            }
        }
        return merged;
    }

    /**
     * Copies {@code text}, wrapping each span in {@code preTag}/{@code postTag}.
     *
     * @param highlights flat list {@code [start0, end0, ...]} of sorted, non-overlapping, in-range spans
     */
    private String constructHighlightedText(String text, List<Integer> highlights, String preTag, String postTag) {
        StringBuilder result = new StringBuilder();
        int currentPos = 0;
        for (int i = 0; i < highlights.size(); i += 2) {
            int start = highlights.get(i);
            int end = highlights.get(i + 1);
            if (start > currentPos) {
                result.append(text, currentPos, start);
            }
            result.append(preTag);
            result.append(text, start, end);
            result.append(postTag);
            currentPos = end;
        }
        if (currentPos < text.length()) {
            result.append(text, currentPos, text.length());
        }
        return result.toString();
    }

    /**
     * Reads the {@code start}/{@code end} offsets from one item of the model's highlight list.
     *
     * @param item expected to be a map with numeric {@code "start"} and {@code "end"} entries
     * @return {@code {start, end}}. Offsets are truncated with {@link Number#intValue()}.
     * @throws OpenSearchException if the item is not a map, or an offset is missing or not a number
     */
    private static int[] readHighlightSpan(Object item) {
        if (!(item instanceof Map<?, ?> map)) {
            throw new OpenSearchException(String.format(Locale.ROOT, "Expect item to be map of string to number, but was: %s", item));
        }
        Object start = map.get(MODEL_INFERENCE_RESULT_START_KEY);
        Object end = map.get(MODEL_INFERENCE_RESULT_END_KEY);
        if (start == null || end == null) {
            throw new OpenSearchException("Missing start or end position in highlight data");
        }
        if (!(start instanceof Number startNumber) || !(end instanceof Number endNumber)) {
            throw new OpenSearchException(String.format(Locale.ROOT, "Expect item to be map of string to number, but was: %s", item));
        }
        return new int[] { startNumber.intValue(), endNumber.intValue() };
    }

    /**
     * Package-private constructor; use {@link #builder()}.
     */
    @Generated
    SemanticHighlighterEngine(@NonNull MLCommonsClientAccessor mlCommonsClient, @NonNull QueryTextExtractorRegistry queryTextExtractorRegistry) {
        Objects.requireNonNull(mlCommonsClient, "mlCommonsClient is marked non-null but is null");
        Objects.requireNonNull(queryTextExtractorRegistry, "queryTextExtractorRegistry is marked non-null but is null");
        this.mlCommonsClient = mlCommonsClient;
        this.queryTextExtractorRegistry = queryTextExtractorRegistry;
    }

    /**
     * Returns a new builder for {@link SemanticHighlighterEngine}.
     */
    @Generated
    public static SemanticHighlighterEngineBuilder builder() {
        return new SemanticHighlighterEngineBuilder();
    }

    /**
     * Builder for {@link SemanticHighlighterEngine}. Setters reject null arguments.
     */
    @Generated
    public static class SemanticHighlighterEngineBuilder {
        @Generated
        private MLCommonsClientAccessor mlCommonsClient;
        @Generated
        private QueryTextExtractorRegistry queryTextExtractorRegistry;

        @Generated
        SemanticHighlighterEngineBuilder() {
        }

        @Generated
        public SemanticHighlighterEngineBuilder mlCommonsClient(@NonNull MLCommonsClientAccessor mlCommonsClient) {
            Objects.requireNonNull(mlCommonsClient, "mlCommonsClient is marked non-null but is null");
            this.mlCommonsClient = mlCommonsClient;
            return this;
        }

        @Generated
        public SemanticHighlighterEngineBuilder queryTextExtractorRegistry(@NonNull QueryTextExtractorRegistry queryTextExtractorRegistry) {
            Objects.requireNonNull(queryTextExtractorRegistry, "queryTextExtractorRegistry is marked non-null but is null");
            this.queryTextExtractorRegistry = queryTextExtractorRegistry;
            return this;
        }

        /**
         * Builds the engine; the constructor throws NullPointerException if either dependency was never set.
         */
        @Generated
        public SemanticHighlighterEngine build() {
            return new SemanticHighlighterEngine(this.mlCommonsClient, this.queryTextExtractorRegistry);
        }

        @Generated
        @Override
        public String toString() {
            return "SemanticHighlighterEngine.SemanticHighlighterEngineBuilder(mlCommonsClient=" + String.valueOf(this.mlCommonsClient) + ", queryTextExtractorRegistry=" + String.valueOf(this.queryTextExtractorRegistry) + ")";
        }
    }
}

