Source code for gnes.encoder.base

#  Tencent is pleased to support the open source community by making GNES available.
#
#  Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.


from typing import List, Any, Tuple, Union

import numpy as np

from ..base import TrainableBase, CompositionalTrainableBase


[docs]class BaseEncoder(TrainableBase):
[docs] def encode(self, data: Any, *args, **kwargs) -> Any: pass
def _copy_from(self, x: 'BaseEncoder') -> None: pass
[docs]class BaseImageEncoder(BaseEncoder):
[docs] def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray: pass
[docs]class BaseVideoEncoder(BaseEncoder):
[docs] def encode(self, data: List['np.ndarray'], *args, **kwargs) -> Union[np.ndarray, List['np.ndarray']]: pass
[docs]class BaseTextEncoder(BaseEncoder):
[docs] def encode(self, text: List[str], *args, **kwargs) -> Union[Tuple, np.ndarray]: pass
[docs]class BaseNumericEncoder(BaseEncoder): """Note that all NumericEncoder can not be used as the first encoder of the pipeline"""
[docs] def encode(self, data: np.ndarray, *args, **kwargs) -> np.ndarray: pass
[docs]class BaseAudioEncoder(BaseEncoder):
[docs] def encode(self, data: List['np.ndarray'], *args, **kwargs) -> np.ndarray: pass
[docs]class BaseBinaryEncoder(BaseEncoder):
[docs] def encode(self, data: np.ndarray, *args, **kwargs) -> bytes: if data.dtype != np.uint8: raise ValueError('data must be np.uint8 but received %s' % data.dtype) return data.tobytes()
[docs]class PipelineEncoder(CompositionalTrainableBase):
[docs] def encode(self, data: Any, *args, **kwargs) -> Any: if not self.components: raise NotImplementedError for be in self.components: data = be.encode(data, *args, **kwargs) return data
[docs] def train(self, data, *args, **kwargs): if not self.components: raise NotImplementedError for idx, be in enumerate(self.components): if not be.is_trained: be.train(data, *args, **kwargs) if idx + 1 < len(self.components): data = be.encode(data, *args, **kwargs)