Source code for gnes.score_fn.doc

#  Tencent is pleased to support the open source community by making GNES available.
#
#  Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from .base import get_unary_score, CombinedScoreFn
import json


class WeightedDocScoreFn(CombinedScoreFn):
    """Combine an incoming score with the weight stored on the document.

    The document's ``weight`` field is wrapped as a unary score named
    ``'doc weight'`` and handed, together with the incoming score, to
    ``CombinedScoreFn.__call__``. How the two are merged is decided by
    that base class.
    """

    def __call__(self,
                 last_score: 'gnes_pb2.Response.QueryResponse.ScoredResult.Score',
                 doc: 'gnes_pb2.Document',
                 *args, **kwargs):
        """Return ``last_score`` combined with the weight of ``doc``.

        :param last_score: the score computed so far for this result
        :param doc: the document whose ``weight`` and ``doc_id`` are read
        :return: whatever ``CombinedScoreFn.__call__`` returns for the pair
        """
        # Extra positional/keyword arguments are accepted but not used here.
        weight_score = get_unary_score(
            value=doc.weight,
            name='doc weight',
            doc_id=doc.doc_id,
        )
        return super().__call__(last_score, weight_score)
class CoordDocScoreFn(CombinedScoreFn):
    """Combine an incoming score with a query-coordination factor.

    query_coordination: #chunks recalled / #chunks in this doc

    The factor is wrapped as a unary score named ``'query coordination'``
    and passed, together with the incoming score, to
    ``CombinedScoreFn.__call__``. The original author's note describes the
    result as ``score = score * query_coordination``; the actual
    combination rule is defined by the base class.
    """

    def __call__(self,
                 last_score: 'gnes_pb2.Response.QueryResponse.ScoredResult.Score',
                 doc: 'gnes_pb2.Document',
                 *args, **kwargs):
        """Return ``last_score`` combined with the coordination factor of ``doc``.

        :param last_score: the score computed so far; its ``explained``
            field is parsed as JSON and must contain an ``'operands'`` entry
        :param doc: the document whose ``chunks`` are counted
        :return: whatever ``CombinedScoreFn.__call__`` returns for the pair
        """
        n_total = len(doc.chunks)

        # The number of recalled chunks is taken to be the number of
        # operands recorded in the explanation of the incoming score.
        explanation = json.loads(last_score.explained)
        n_recalled = len(explanation['operands'])

        if n_total == 0:
            # A document without chunks gets a neutral factor instead of
            # triggering a division by zero.
            coordination = 1
        else:
            coordination = n_recalled / n_total

        coord_score = get_unary_score(value=coordination, name='query coordination')
        return super().__call__(last_score, coord_score)