Source code for gnes.preprocessor.base

#  Tencent is pleased to support the open source community by making GNES available.
#  Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  See the License for the specific language governing permissions and
#  limitations under the License.

import io

import numpy as np

from ..base import TrainableBase, CompositionalTrainableBase
from ..proto import gnes_pb2, array2blob

[docs]class BasePreprocessor(TrainableBase): doc_type = gnes_pb2.Document.UNKNOWN def __init__(self, uniform_doc_weight: bool = True, *args, **kwargs): super().__init__(*args, **kwargs) self.uniform_doc_weight = uniform_doc_weight
[docs] def apply(self, doc: 'gnes_pb2.Document') -> None: doc.doc_type = self.doc_type if not doc.weight and self.uniform_doc_weight: doc.weight = 1.0
[docs]class BaseTextPreprocessor(BasePreprocessor): doc_type = gnes_pb2.Document.TEXT
[docs]class BaseAudioPreprocessor(BasePreprocessor): doc_type = gnes_pb2.Document.AUDIO
[docs]class BaseImagePreprocessor(BasePreprocessor): doc_type = gnes_pb2.Document.IMAGE
[docs]class BaseVideoPreprocessor(BasePreprocessor): doc_type = gnes_pb2.Document.VIDEO
[docs]class PipelinePreprocessor(CompositionalTrainableBase):
[docs] def apply(self, doc: 'gnes_pb2.Document') -> None: if not self.components: raise NotImplementedError for be in self.components: be.apply(doc)
[docs] def train(self, data, *args, **kwargs): if not self.components: raise NotImplementedError for idx, be in enumerate(self.components): be.train(data, *args, **kwargs) if idx + 1 < len(self.components): data = be.apply(data, *args, **kwargs)
[docs]class UnaryPreprocessor(BasePreprocessor): is_trained = True def __init__(self, doc_type: int, *args, **kwargs): super().__init__(*args, **kwargs) self.doc_type = doc_type
[docs] def apply(self, doc: 'gnes_pb2.Document'): super().apply(doc) c = doc.chunks.add() c.doc_id = doc.doc_id c.offset = 0 c.weight = 1. if doc.raw_bytes: self.raw_to_chunk(c, doc.raw_bytes) else: self.logger.error('bad document: "raw_bytes" is empty!')
[docs] def raw_to_chunk(self, chunk: 'gnes_pb2.Chunk', raw_bytes: bytes): if self.doc_type == gnes_pb2.Document.TEXT: chunk.text = raw_bytes.decode() elif self.doc_type == gnes_pb2.Document.IMAGE: from PIL import Image img = np.array( chunk.blob.CopyFrom(array2blob(img)) elif self.doc_type == gnes_pb2.Document.VIDEO: raise NotImplementedError else: raise NotImplementedError
[docs]class RawChunkPreprocessor(BasePreprocessor): @staticmethod def _parse_chunk(chunk: 'gnes_pb2.Chunk', *args, **kwargs): raise NotImplementedError
[docs] def apply(self, doc: 'gnes_pb2.Document') -> None: if doc.raw_bytes: for chunk in doc.chunks: chunk.raw = self._parse_chunk(chunk) else: self.logger.error('bad document: "raw_bytes" is empty!')