/*
 * Decompiled with CFR 0.152.
 */
package ghidra.features.bsim.query.file;

import generic.concurrent.ConcurrentQ;
import generic.concurrent.ConcurrentQBuilder;
import generic.concurrent.GThreadPool;
import generic.concurrent.QResult;
import generic.lsh.vector.LSHVector;
import generic.lsh.vector.VectorCompare;
import ghidra.features.bsim.query.BSimServerInfo;
import ghidra.features.bsim.query.FunctionDatabase;
import ghidra.features.bsim.query.LSHException;
import ghidra.features.bsim.query.client.AbstractSQLFunctionDatabase;
import ghidra.features.bsim.query.client.BSimSqlClause;
import ghidra.features.bsim.query.client.Configuration;
import ghidra.features.bsim.query.description.DescriptionManager;
import ghidra.features.bsim.query.description.FunctionDescription;
import ghidra.features.bsim.query.description.SignatureRecord;
import ghidra.features.bsim.query.description.VectorResult;
import ghidra.features.bsim.query.elastic.Base64VectorFactory;
import ghidra.features.bsim.query.file.BSimH2FileDBConnectionManager;
import ghidra.features.bsim.query.file.BSimVectorStoreManager;
import ghidra.features.bsim.query.file.H2VectorTable;
import ghidra.features.bsim.query.file.VectorStore;
import ghidra.features.bsim.query.file.VectorStoreEntry;
import ghidra.features.bsim.query.protocol.AdjustVectorIndex;
import ghidra.features.bsim.query.protocol.BSimQuery;
import ghidra.features.bsim.query.protocol.PasswordChange;
import ghidra.features.bsim.query.protocol.PrewarmRequest;
import ghidra.features.bsim.query.protocol.QueryNearest;
import ghidra.features.bsim.query.protocol.QueryNearestVector;
import ghidra.features.bsim.query.protocol.QueryResponseRecord;
import ghidra.features.bsim.query.protocol.ResponseNearest;
import ghidra.features.bsim.query.protocol.ResponseNearestVector;
import ghidra.features.bsim.query.protocol.SimilarityResult;
import ghidra.features.bsim.query.protocol.SimilarityVectorResult;
import ghidra.util.task.TaskMonitor;
import java.net.URL;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * BSim function database backed by an H2 database file.
 * <p>
 * Signature vectors are stored through an {@link H2VectorTable}. That table is constructed
 * around a {@link VectorStore} obtained from {@link BSimVectorStoreManager}. Nearest-neighbor
 * queries are answered by scanning every entry of that store rather than by an SQL index.
 * Requests this backend does not support (prewarm, password change, vector index adjustment)
 * are answered with "not supported" responses instead of failing.
 */
public class H2FileFunctionDatabase
extends AbstractSQLFunctionDatabase<Base64VectorFactory> {
    public static final int OVERVIEW_FUNCS_PER_STAGE = 1024;
    public static final int QUERY_FUNCS_PER_STAGE = 256;
    private static final String H2_THREADPOOL_NAME = "H2_BSIM_THREADPOOL";
    public static final int LAYOUT_VERSION = 1;
    /** Per-function vector limit substituted when a query specifies {@code vectormax == 0}. */
    private static final int DEFAULT_VECTOR_MAX = 2000000;

    private final BSimH2FileDBConnectionManager.BSimH2FileDataSource fileDs;
    private final H2VectorTable vectorTable;
    private final VectorStore vectorStore;

    /**
     * Opens the H2 file database identified by a BSim URL.
     *
     * @param bsimURL URL resolved to a data source by {@link BSimH2FileDBConnectionManager}
     */
    public H2FileFunctionDatabase(URL bsimURL) {
        this(BSimH2FileDBConnectionManager.getDataSource(bsimURL));
    }

    /**
     * Opens the H2 file database described by {@code serverInfo}.
     *
     * @param serverInfo server description resolved by {@link BSimH2FileDBConnectionManager}
     */
    public H2FileFunctionDatabase(BSimServerInfo serverInfo) {
        this(BSimH2FileDBConnectionManager.getDataSource(serverInfo));
    }

    private H2FileFunctionDatabase(BSimH2FileDBConnectionManager.BSimH2FileDataSource ds) {
        super(ds, new Base64VectorFactory(), LAYOUT_VERSION);
        this.fileDs = ds;
        this.vectorStore = BSimVectorStoreManager.getVectorStore(this.getServerInfo());
        this.vectorTable = new H2VectorTable((Base64VectorFactory) this.vectorFactory, this.vectorStore);
    }

    /**
     * Closes the vector table, then the underlying database. The parent close runs even if
     * closing the vector table throws.
     */
    @Override
    public void close() {
        try {
            this.vectorTable.close();
        }
        finally {
            super.close();
        }
    }

    /** Hands the connection to the vector table as well as to the tables managed by the parent. */
    @Override
    protected void setConnectionOnTables(Connection db) {
        this.vectorTable.setConnection(db);
        super.setConnectionOnTables(db);
    }

    /**
     * Opens the connection. Fails early if the database is not already {@code Ready} and its
     * file does not exist.
     *
     * @throws SQLException if the database file is missing or the connection cannot be made
     */
    @Override
    protected Connection initConnection() throws SQLException {
        if (this.getStatus() != FunctionDatabase.Status.Ready && !this.fileDs.exists()) {
            throw new SQLException("Database does not exist: " + this.fileDs.getServerInfo().getDBName());
        }
        return super.initConnection();
    }

    /**
     * Creates the raw (empty) database by opening and immediately closing a connection.
     *
     * @throws SQLException if the database already exists or the connection fails
     */
    @Override
    protected void generateRawDatabase() throws SQLException {
        BSimServerInfo serverInfo = this.fileDs.getServerInfo();
        if (this.fileDs.exists()) {
            throw new SQLException("Database already exists: " + serverInfo.getDBName());
        }
        // The connection is opened only for its side effect (presumably creating the H2 file).
        // try-with-resources also tolerates a null connection.
        try (Connection c = this.fileDs.getConnection()) {
            // nothing further to do
        }
    }

    /**
     * Creates the standard BSim schema via the parent, then creates the vector table.
     *
     * @param config database configuration passed through to the parent
     * @throws SQLException wrapping (with cause) any failure during creation
     */
    @Override
    protected void createDatabase(Configuration config) throws SQLException {
        try {
            super.createDatabase(config);
            Connection db = this.initConnection();
            try (Statement st = db.createStatement()) {
                this.vectorTable.create(st);
            }
        }
        catch (SQLException err) {
            throw new SQLException("Could not create database: " + err.getMessage(), err);
        }
    }

    /**
     * Reads all vectors from the vector table.
     *
     * @return map returned by {@link H2VectorTable#readVectors()}
     * @throws SQLException on database error
     */
    public Map<Long, VectorStoreEntry> readVectorMap() throws SQLException {
        return this.vectorTable.readVectors();
    }

    @Override
    protected int deleteVectors(long id, int countdiff) throws SQLException {
        return this.vectorTable.deleteVector(id, countdiff);
    }

    /**
     * Answers prewarm, password-change and vector-index requests with "unsupported"
     * responses. All other queries are delegated to the parent.
     */
    @Override
    public QueryResponseRecord doQuery(BSimQuery<?> query, Connection c) throws SQLException, LSHException, FunctionDatabase.DatabaseNonFatalException {
        if (query instanceof PrewarmRequest preWarmRequest) {
            preWarmRequest.buildResponseTemplate();
            preWarmRequest.prewarmresponse.operationSupported = false;
        }
        else if (query instanceof PasswordChange passwordChangeRequest) {
            passwordChangeRequest.buildResponseTemplate();
            passwordChangeRequest.passwordResponse.changeSuccessful = false;
            passwordChangeRequest.passwordResponse.errorMessage = "Unsupported operation for H2 backend";
        }
        else if (query instanceof AdjustVectorIndex adjustRequest) {
            adjustRequest.buildResponseTemplate();
            adjustRequest.adjustresponse.operationSupported = false;
        }
        else {
            return super.doQuery(query, c);
        }
        return query.getResponse();
    }

    /**
     * Looks up a vector by row id.
     *
     * @throws SQLException if the vector table has no row with this id
     */
    @Override
    protected VectorResult queryVectorId(long id) throws SQLException {
        VectorResult rowres = this.vectorTable.queryVectorById(id);
        if (rowres == null) {
            throw new SQLException("Bad vector table rowid");
        }
        return rowres;
    }

    /** Stores the record's vector through {@link H2VectorTable#updateVector} with a count argument of 1. */
    @Override
    protected long storeSignatureRecord(SignatureRecord sigrec) throws SQLException {
        return this.vectorTable.updateVector(sigrec.getLSHVector(), 1);
    }

    /**
     * Scans every entry of the vector store for vectors similar to {@code vec}.
     * <p>
     * An entry is kept when all of the following hold:
     * <ul>
     * <li>its self-significance is at least {@code sigthresh};</li>
     * <li>its cosine similarity with {@code vec} is greater than {@code simthresh};</li>
     * <li>its comparison significance is greater than {@code sigthresh}.</li>
     * </ul>
     * Kept entries are ordered by similarity, highest first. At most {@code max} of them
     * are appended to {@code resultset}.
     *
     * @return the size of {@code resultset} after appending
     */
    @Override
    protected int queryNearestVector(List<VectorResult> resultset, LSHVector vec, double simthresh, double sigthresh, int max) throws SQLException {
        Base64VectorFactory factory = (Base64VectorFactory) this.vectorFactory;
        List<VectorResult> candidates = new ArrayList<>();
        for (VectorStoreEntry entry : this.vectorStore) {
            if (entry.selfSig() < sigthresh) {
                continue;
            }
            VectorCompare comp = new VectorCompare();
            vec.compare(entry.vec(), comp);
            double cosine = comp.dotproduct / (vec.getLength() * entry.vec().getLength());
            // A zero denominator can yield NaN. NaN slips past "<=" and would then sort ahead
            // of every real result under Double.compare, so reject it explicitly.
            if (Double.isNaN(cosine) || cosine <= simthresh) {
                continue;
            }
            double sig = factory.calculateSignificance(comp);
            if (sig <= sigthresh) {
                continue;
            }
            candidates.add(new VectorResult(entry.id(), entry.count(), cosine, sig, entry.vec()));
        }
        candidates.sort((r1, r2) -> Double.compare(r2.sim, r1.sim));
        int maxResults = Math.min(max, candidates.size());
        if (maxResults > 0) {
            resultset.addAll(candidates.subList(0, maxResults));
        }
        return resultset.size();
    }

    /**
     * Runs a nearest-vector search for every sufficiently significant function in the query.
     * The searches run on the shared H2 thread pool, and the response counters and results
     * are filled in from them.
     *
     * @throws SQLException if any per-function search fails, or if waiting is interrupted
     *             (the thread's interrupt status is restored)
     */
    @Override
    protected void queryNearestVector(QueryNearestVector query) throws SQLException {
        ResponseNearestVector response = query.nearresponse;
        response.totalvec = 0;
        response.totalmatch = 0;
        response.uniquematch = 0;
        int vectormax = query.vectormax == 0 ? DEFAULT_VECTOR_MAX : query.vectormax;
        List<FunctionDescription> toQuery = this.getFuncsToQuery(query.manage, query.signifthresh);
        response.totalvec = toQuery.size();

        GThreadPool threadPool = GThreadPool.getSharedThreadPool(H2_THREADPOOL_NAME);
        ConcurrentQ<FunctionDescription, SimilarityVectorResult> evalQ =
            new ConcurrentQBuilder<FunctionDescription, SimilarityVectorResult>()
                    .setThreadPool(threadPool)
                    .setCollectResults(true)
                    .setMonitor(TaskMonitor.DUMMY)
                    .build((fd, m) -> {
                        List<VectorResult> resultset = new ArrayList<>();
                        this.queryNearestVector(resultset, fd.getSignatureRecord().getLSHVector(),
                            query.thresh, query.signifthresh, vectormax);
                        if (resultset.isEmpty()) {
                            return null;
                        }
                        SimilarityVectorResult simres = new SimilarityVectorResult(fd);
                        simres.addNotes(resultset);
                        return simres;
                    });
        evalQ.add(toQuery);
        try {
            Collection<QResult<FunctionDescription, SimilarityVectorResult>> results =
                evalQ.waitForResults();
            for (QResult<FunctionDescription, SimilarityVectorResult> result : results) {
                // getResult() rethrows any exception raised by the worker for this item
                SimilarityVectorResult simres = result.getResult();
                if (simres == null) {
                    continue;
                }
                response.totalmatch += simres.getTotalCount();
                if (simres.getTotalCount() == 1) {
                    ++response.uniquematch;
                }
                response.result.add(simres);
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new SQLException("Interrupted while querying nearest vectors", e);
        }
        catch (SQLException e) {
            throw e;
        }
        catch (Exception e) {
            throw new SQLException("Nearest vector query failed: " + e.getMessage(), e);
        }
    }

    /**
     * Answers a {@link QueryNearest} for the functions produced by {@code iter}.
     * <p>
     * A nearest-vector query is run first. For each matching vector, function descriptions
     * are then retrieved until {@code query.max} functions have been collected for that
     * base function.
     *
     * @return the number of distinct vector hashes among the queried functions
     */
    @Override
    public int queryFunctions(QueryNearest query, BSimSqlClause filter, ResponseNearest response, DescriptionManager descMgr, Iterator<FunctionDescription> iter) throws SQLException, LSHException {
        QueryNearestVector qnv = new QueryNearestVector();
        qnv.manage.transferSettings(query.manage);
        qnv.signifthresh = query.signifthresh;
        qnv.thresh = query.thresh;
        qnv.vectormax = query.vectormax;
        qnv.buildResponseTemplate();
        HashSet<Long> distinctVecIds = new HashSet<Long>();
        while (iter.hasNext()) {
            FunctionDescription fd = iter.next();
            if (fd.getSignatureRecord() == null) {
                continue;
            }
            distinctVecIds.add(fd.getSignatureRecord().getLSHVector().calcUniqueHash());
            qnv.manage.transferFunction(fd, true);
        }
        this.queryNearestVector(qnv);
        ResponseNearestVector rnv = qnv.nearresponse;
        response.totalfunc = rnv.totalvec;
        for (SimilarityVectorResult simVecRes : rnv.result) {
            FunctionDescription base = simVecRes.getBase();
            Iterator<VectorResult> vecResults = simVecRes.iterator();
            SimilarityResult sim = new SimilarityResult(base);
            sim.setTotalCount(simVecRes.getTotalCount());
            int funcsForVec = 0;
            while (vecResults.hasNext() && funcsForVec < query.max) {
                VectorResult dresult = vecResults.next();
                funcsForVec += this.retrieveFuncDescFromVectors(dresult, descMgr, funcsForVec, query, filter, sim);
            }
            if (sim.size() == 0) {
                continue;
            }
            ++response.totalmatch;
            if (sim.size() == 1) {
                ++response.uniquematch;
            }
            response.result.add(sim);
            sim.transfer(response.manage, true);
        }
        return distinctVecIds.size();
    }

    /**
     * Collects the functions in {@code manager} that have a signature record and whose vector
     * self-significance is not below {@code sigBound}.
     */
    private List<FunctionDescription> getFuncsToQuery(DescriptionManager manager, double sigBound) {
        Base64VectorFactory factory = (Base64VectorFactory) this.vectorFactory;
        Iterator<FunctionDescription> iter = manager.listAllFunctions();
        List<FunctionDescription> toQuery = new ArrayList<>();
        while (iter.hasNext()) {
            FunctionDescription frec = iter.next();
            SignatureRecord srec = frec.getSignatureRecord();
            if (srec == null || factory.getSelfSignificance(srec.getLSHVector()) < sigBound) {
                continue;
            }
            toQuery.add(frec);
        }
        return toQuery;
    }

    /** Returns the H2 expression {@code BITAND(v1,v2)}. */
    @Override
    public String formatBitAndSQL(String v1, String v2) {
        return "BITAND(" + v1 + "," + v2 + ")";
    }

    @Override
    public int getQueriedFunctionsPerStage() {
        return QUERY_FUNCS_PER_STAGE;
    }

    @Override
    public int getOverviewFunctionsPerStage() {
        return OVERVIEW_FUNCS_PER_STAGE;
    }
}

