/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.io.PathUtils;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNQueryResult;
import org.opensearch.knn.index.query.KNNScorer;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;

public class KNNWeight
extends Weight {
    @Generated
    private static final Logger log = LogManager.getLogger(KNNWeight.class);
    private static ModelDao modelDao;
    private final KNNQuery knnQuery;
    private final float boost;
    private final NativeMemoryCacheManager nativeMemoryCacheManager;
    private final Weight filterWeight;

    public KNNWeight(KNNQuery query, float boost) {
        super((Query)query);
        this.knnQuery = query;
        this.boost = boost;
        this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
        this.filterWeight = null;
    }

    public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
        super((Query)query);
        this.knnQuery = query;
        this.boost = boost;
        this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
        this.filterWeight = filterWeight;
    }

    public static void initialize(ModelDao modelDao) {
        KNNWeight.modelDao = modelDao;
    }

    public Explanation explain(LeafReaderContext context, int doc) {
        return Explanation.match((Number)Float.valueOf(1.0f), (String)"No Explanation", (Explanation[])new Explanation[0]);
    }

    public Scorer scorer(LeafReaderContext context) throws IOException {
        int[] filterIdsArray = this.getFilterIdsArray(context);
        if (this.filterWeight != null && filterIdsArray.length == 0) {
            return KNNScorer.emptyScorer(this);
        }
        HashMap<Integer, Float> docIdsToScoreMap = new HashMap<Integer, Float>();
        if (this.filterWeight != null && filterIdsArray.length <= this.knnQuery.getK()) {
            docIdsToScoreMap.putAll(this.doExactSearch(context, filterIdsArray));
        } else {
            Map<Integer, Float> annResults = this.doANNSearch(context, filterIdsArray);
            if (annResults == null) {
                return null;
            }
            docIdsToScoreMap.putAll(annResults);
        }
        if (docIdsToScoreMap.isEmpty()) {
            return KNNScorer.emptyScorer(this);
        }
        return this.convertSearchResponseToScorer(docIdsToScoreMap);
    }

    private BitSet getFilteredDocsBitSet(LeafReaderContext ctx, Weight filterWeight) throws IOException {
        Bits liveDocs = ctx.reader().getLiveDocs();
        int maxDoc = ctx.reader().maxDoc();
        Scorer scorer = filterWeight.scorer(ctx);
        if (scorer == null) {
            return new FixedBitSet(0);
        }
        return this.createBitSet(scorer.iterator(), liveDocs, maxDoc);
    }

    private BitSet createBitSet(DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException {
        if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) {
            return ((BitSetIterator)filteredDocIdsIterator).getBitSet();
        }
        FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator){

            protected boolean match(int doc) {
                return liveDocs == null || liveDocs.get(doc);
            }
        };
        return BitSet.of((DocIdSetIterator)filterIterator, (int)maxDoc);
    }

    private int[] getFilterIdsArray(LeafReaderContext context) throws IOException {
        if (this.filterWeight == null) {
            return new int[0];
        }
        BitSet filteredDocsBitSet = this.getFilteredDocsBitSet(context, this.filterWeight);
        int[] filteredIds = new int[filteredDocsBitSet.cardinality()];
        int filteredIdsIndex = 0;
        int docId = 0;
        while (docId < filteredDocsBitSet.length() && (docId = filteredDocsBitSet.nextSetBit(docId)) != Integer.MAX_VALUE && docId + 1 != Integer.MAX_VALUE) {
            log.debug("Docs in filtered docs id set is : {}", (Object)docId);
            filteredIds[filteredIdsIndex] = docId++;
            ++filteredIdsIndex;
        }
        return filteredIds;
    }

    private Map<Integer, Float> doANNSearch(LeafReaderContext context, int[] filterIdsArray) throws IOException {
        KNNQueryResult[] results;
        NativeMemoryAllocation indexAllocation;
        SpaceType spaceType;
        KNNEngine knnEngine;
        SegmentReader reader = (SegmentReader)FilterLeafReader.unwrap((LeafReader)context.reader());
        String directory = ((FSDirectory)FilterDirectory.unwrap((Directory)reader.directory())).getDirectory().toString();
        FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(this.knnQuery.getField());
        if (fieldInfo == null) {
            log.debug("[KNN] Field info not found for {}:{}", (Object)this.knnQuery.getField(), (Object)reader.getSegmentName());
            return null;
        }
        String modelId = fieldInfo.getAttribute("model_id");
        if (modelId != null) {
            ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
            if (modelMetadata == null) {
                throw new RuntimeException("Model \"" + modelId + "\" does not exist.");
            }
            knnEngine = modelMetadata.getKnnEngine();
            spaceType = modelMetadata.getSpaceType();
        } else {
            String engineName = fieldInfo.attributes().getOrDefault("engine", KNNEngine.NMSLIB.getName());
            knnEngine = KNNEngine.getEngine(engineName);
            String spaceTypeName = fieldInfo.attributes().getOrDefault("spaceType", SpaceType.L2.getValue());
            spaceType = SpaceType.getSpace(spaceTypeName);
        }
        Object engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() ? knnEngine.getExtension() + "c" : knnEngine.getExtension();
        String engineSuffix = this.knnQuery.getField() + (String)engineExtension;
        List engineFiles = reader.getSegmentInfo().files().stream().filter(fileName -> fileName.endsWith(engineSuffix)).collect(Collectors.toList());
        if (engineFiles.isEmpty()) {
            log.debug("[KNN] No engine index found for field {} for segment {}", (Object)this.knnQuery.getField(), (Object)reader.getSegmentName());
            return null;
        }
        Path indexPath = PathUtils.get((String)directory, (String[])new String[]{(String)engineFiles.get(0)});
        KNNCounter.GRAPH_QUERY_REQUESTS.increment();
        try {
            indexAllocation = this.nativeMemoryCacheManager.get(new NativeMemoryEntryContext.IndexEntryContext(indexPath.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), IndexUtil.getParametersAtLoading(spaceType, knnEngine, this.knnQuery.getIndexName()), this.knnQuery.getIndexName()), true);
        }
        catch (ExecutionException e) {
            KNNCounter.GRAPH_QUERY_ERRORS.increment();
            throw new RuntimeException(e);
        }
        indexAllocation.readLock();
        try {
            if (indexAllocation.isClosed()) {
                throw new RuntimeException("Index has already been closed");
            }
            results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), this.knnQuery.getQueryVector(), this.knnQuery.getK(), knnEngine.getName(), filterIdsArray);
        }
        catch (Exception e) {
            KNNCounter.GRAPH_QUERY_ERRORS.increment();
            throw new RuntimeException(e);
        }
        finally {
            indexAllocation.readUnlock();
        }
        if (results.length == 0) {
            log.debug("[KNN] Query yielded 0 results");
            return null;
        }
        return Arrays.stream(results).collect(Collectors.toMap(KNNQueryResult::getId, result -> Float.valueOf(knnEngine.score(result.getScore(), spaceType))));
    }

    private Map<Integer, Float> doExactSearch(LeafReaderContext leafReaderContext, int[] filterIdsArray) {
        SegmentReader reader = (SegmentReader)FilterLeafReader.unwrap((LeafReader)leafReaderContext.reader());
        FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(this.knnQuery.getField());
        float[] queryVector = this.knnQuery.getQueryVector();
        try {
            BinaryDocValues values = DocValues.getBinary((LeafReader)leafReaderContext.reader(), (String)fieldInfo.getName());
            SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute("spaceType"));
            HitQueue queue = new HitQueue(this.knnQuery.getK(), true);
            ScoreDoc topDoc = (ScoreDoc)queue.top();
            HashMap<Integer, Float> docToScore = new HashMap<Integer, Float>();
            for (int filterId : filterIdsArray) {
                int docId = values.advance(filterId);
                BytesRef value = values.binaryValue();
                ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
                KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
                float[] vector = vectorSerializer.byteToFloatArray(byteStream);
                float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector);
                if (!(score > topDoc.score)) continue;
                topDoc.score = score;
                topDoc.doc = docId;
                topDoc = (ScoreDoc)queue.updateTop();
            }
            while (queue.size() > 0 && ((ScoreDoc)queue.top()).score < 0.0f) {
                queue.pop();
            }
            while (queue.size() > 0) {
                ScoreDoc doc = (ScoreDoc)queue.pop();
                docToScore.put(doc.doc, Float.valueOf(doc.score));
            }
            return docToScore;
        }
        catch (Exception e) {
            log.error("Error while getting the doc values to do the k-NN Search for query : {}", (Object)this.knnQuery, (Object)e);
            return Collections.emptyMap();
        }
    }

    private Scorer convertSearchResponseToScorer(Map<Integer, Float> docsToScore) throws IOException {
        int maxDoc = Collections.max(docsToScore.keySet()) + 1;
        DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);
        DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(docsToScore.size());
        docsToScore.keySet().forEach(arg_0 -> ((DocIdSetBuilder.BulkAdder)setAdder).add(arg_0));
        DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator();
        return new KNNScorer(this, docIdSetIter, docsToScore, this.boost);
    }

    public boolean isCacheable(LeafReaderContext context) {
        return true;
    }

    public static float normalizeScore(float score) {
        if (score >= 0.0f) {
            return 1.0f / (1.0f + score);
        }
        return -score + 1.0f;
    }
}

