# Source code for gnes.indexer.vector.annoy

#  Tencent is pleased to support the open source community by making GNES available.
#
#  Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import os
from typing import List, Tuple, Any

import numpy as np

from ..base import BaseVectorIndexer
from ..key_only import ListKeyIndexer


class AnnoyIndexer(BaseVectorIndexer):
    """Vector indexer backed by an ``annoy.AnnoyIndex``.

    Vectors live in the Annoy index, while the key/weight record belonging to
    each vector is stored in a parallel :class:`ListKeyIndexer`. Both share the
    same integer id: the position at which the vector was added.
    """

    def __init__(self, num_dim: int, data_path: str, metric: str = 'angular', n_trees: int = 10, *args, **kwargs):
        """
        :param num_dim: dimensionality of the indexed vectors
        :param data_path: file path the Annoy index is loaded from and saved to
        :param metric: distance metric name passed to ``AnnoyIndex``
        :param n_trees: number of trees passed to ``AnnoyIndex.build``
        """
        super().__init__(*args, **kwargs)
        self.num_dim = num_dim
        self.data_path = data_path
        self.metric = metric
        self.n_trees = n_trees
        self._key_info_indexer = ListKeyIndexer()

    def post_init(self):
        """Create the Annoy index and try to load a saved one from ``data_path``.

        Loading is best-effort. If the path is missing, is a directory, or cannot
        be loaded, a warning with the reason is logged and the empty index is kept.
        """
        from annoy import AnnoyIndex
        self._index = AnnoyIndex(self.num_dim, self.metric)
        try:
            if not os.path.exists(self.data_path):
                raise FileNotFoundError('"data_path" does not exist')
            if os.path.isdir(self.data_path):
                raise IsADirectoryError('"data_path" must be a file path, not a directory')
            self._index.load(self.data_path)
        except Exception as ex:
            # deliberate best-effort load: keep the empty index, but say why;
            # catching Exception (not a bare except) lets Ctrl-C / SystemExit through
            self.logger.warning('fail to load model from %s, will create an empty one: %r' % (self.data_path, ex))

    def _build_if_needed(self):
        """Build the Annoy forest unless it already has trees (built earlier or loaded)."""
        # Annoy refuses to build an index twice or to build a loaded index, and it
        # can only save a built one, so build exactly once, when no trees exist yet.
        if self._index.get_n_trees() == 0:
            self._index.build(self.n_trees)

    def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
        """Append ``vectors`` to the Annoy index and ``keys``/``weights`` to the key indexer.

        Ids continue from the key indexer's current size. The i-th vector of this
        call is therefore stored under id ``size + i``, matching the key record
        added alongside it.

        :raises ValueError: if ``vectors`` and ``keys`` differ in length, or
            ``vectors`` is not float32
        """
        if len(vectors) != len(keys):
            raise ValueError('vectors length should be equal to doc_ids')
        if vectors.dtype != np.float32:
            raise ValueError("vectors should be ndarray of float32")
        # NOTE(review): Annoy may reject add_item on an index that was already
        # built (by a query) or loaded from disk -- confirm the intended workflow.
        last_idx = self._key_info_indexer.size
        for offset, vec in enumerate(vectors):
            self._index.add_item(last_idx + offset, vec)
        self._key_info_indexer.add(keys, weights)

    def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> List[List[Tuple]]:
        """Find the ``top_k`` nearest indexed vectors for each row of ``keys``.

        :param keys: query vectors; must have dtype float32
        :param top_k: number of neighbours to return per query vector
        :return: one list per query row. Each entry is the key-indexer record of
            a neighbour, extended with its normalized relevance score.
        :raises ValueError: if ``keys`` is not float32
        """
        if keys.dtype != np.float32:
            raise ValueError('vectors should be ndarray of float32')
        self._build_if_needed()
        res = []
        for k in keys:
            ret, distances = self._index.get_nns_by_vector(k, top_k, include_distances=True)
            scores = self.normalize_score(distances, self.metric)
            chunk_info = self._key_info_indexer.query(ret)
            res.append([(*r, s) for r, s in zip(chunk_info, scores)])
        return res

    def normalize_score(self, score: List[float], metrics: str, *args, **kwargs) -> List[float]:
        """Map Annoy distances to relevance scores in (0, 1]; smaller distance -> higher score.

        :param score: distances returned by Annoy
        :param metrics: the metric the distances were computed with
        :raises NotImplementedError: for ``'dot'``
        :raises ValueError: for an unrecognized metric
        """
        if metrics in ('angular', 'hamming'):
            return [1 / (1 + x) for x in score]
        if metrics == 'euclidean':
            import math
            # NOTE(review): if Annoy already reports non-squared euclidean
            # distances, this sqrt is applied on top of that. Ranking is unaffected
            # (the mapping is monotonic); confirm the intended score scale.
            return [1 / (1 + math.sqrt(x) / self.num_dim) for x in score]
        if metrics == 'manhattan':
            return [1 / (1 + x / self.num_dim) for x in score]
        if metrics == 'dot':
            raise NotImplementedError('score normalization for "dot" metric is not implemented')
        raise ValueError('unknown metric: %r' % metrics)

    @property
    def size(self):
        """Number of items currently stored in the Annoy index."""
        return self._index.get_n_items()

    def __getstate__(self):
        """Persist the Annoy index to ``data_path`` and return the picklable state."""
        d = super().__getstate__()
        # saving requires a built forest; build now if no query has done so yet
        self._build_if_needed()
        self._index.save(self.data_path)
        return d