Source code for gnes.base

#  Tencent is pleased to support the open source community by making GNES available.
#
#  Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.


import inspect
import os
import pickle
import tempfile
import uuid
from functools import wraps
from typing import Dict, Any, Union, TextIO, TypeVar, Type, List, Callable

import ruamel.yaml.constructor

from ..helper import set_logger, profiling, yaml, parse_arg, load_contrib_module

__all__ = ['TrainableBase', 'CompositionalTrainableBase']

T = TypeVar('T', bound='TrainableBase')


def register_all_class(cls2file_map: Dict, module_name: str):
    import importlib
    for k, v in cls2file_map.items():
        try:
            getattr(importlib.import_module('gnes.%s.%s' % (module_name, v)), k)
        except ImportError as ex:
            default_logger = set_logger('GNES')
            default_logger.warning('failed to register %s due to "%s"; you will not be able to use this model' % (k, ex))
    load_contrib_module()
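
# Illustrative sketch, not part of the module: a sub-package such as gnes.encoder is
# assumed to expose a "_cls2file_map" dict and call register_all_class() at import time.
# The mapping below is hypothetical and only shows the expected shape.
#
#   _cls2file_map = {'BaseEncoder': 'base', 'PipelineEncoder': 'base'}
#   register_all_class(_cls2file_map, 'encoder')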


def import_class_by_str(name: str):
    def _import(module_name, class_name):
        import importlib

        cls2file = getattr(importlib.import_module('gnes.%s' % module_name), '_cls2file_map')
        if class_name in cls2file:
            return getattr(importlib.import_module('gnes.%s.%s' % (module_name, cls2file[class_name])), class_name)

    search_modules = ['encoder', 'indexer', 'preprocessor', 'router', 'score_fn']

    for m in search_modules:
        r = _import(m, name)
        if r:
            return r
    else:
        raise ImportError('cannot locate any class with the name "%s", is it misspelled?' % name)
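
# Illustrative sketch, not part of the module: resolving a registered class by name.
# 'PipelineEncoder' is used here as an assumed example of a class listed in one of the
# searched sub-modules; an unknown name raises ImportError.
#
#   cls = import_class_by_str('PipelineEncoder')
#   encoder = cls()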


class TrainableType(type):
    default_gnes_config = {
        'is_trained': False,
        'batch_size': None,
        'work_dir': os.environ.get('GNES_VOLUME', os.getcwd()),
        'name': None,
        'on_gpu': False,
        'warn_unnamed': True
    }

    def __new__(cls, *args, **kwargs):
        _cls = super().__new__(cls, *args, **kwargs)
        return cls.register_class(_cls)

    def __call__(cls, *args, **kwargs):
        # do _preload_package
        getattr(cls, '_pre_init', lambda *x: None)()

        if 'gnes_config' in kwargs:
            gnes_config = kwargs.pop('gnes_config')
        else:
            gnes_config = {}

        obj = type.__call__(cls, *args, **kwargs)

        # set attributes, filling in anything not already set on the instance or class:
        # gnes_config in YAML > default_gnes_config (existing attributes are kept as-is)
        for k, v in TrainableType.default_gnes_config.items():
            if k in gnes_config:
                v = gnes_config[k]
            v = _expand_env_var(v)
            if not hasattr(obj, k):
                if k == 'is_trained' and isinstance(obj, CompositionalTrainableBase):
                    continue
                setattr(obj, k, v)

        getattr(obj, '_post_init_wrapper', lambda *x: None)()
        return obj
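
    # Illustrative sketch, not part of the module: "MyEncoder" is a hypothetical
    # TrainableBase subclass. Keys passed under gnes_config fill in attributes the
    # instance does not set itself; everything else falls back to default_gnes_config.
    #
    #   encoder = MyEncoder(gnes_config={'name': 'my-encoder', 'work_dir': '/tmp'})
    #   assert encoder.name == 'my-encoder'
    #   assert encoder.is_trained is False      # from default_gnes_config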

    @staticmethod
    def register_class(cls):
        # print('try to register class: %s' % cls.__name__)
        reg_cls_set = getattr(cls, '_registered_class', set())
        if cls.__name__ not in reg_cls_set:
            # print('reg class: %s' % cls.__name__)
            cls.__init__ = TrainableType._store_init_kwargs(cls.__init__)
            if os.environ.get('GNES_PROFILING', False):
                for f_name in ['train', 'encode', 'add', 'query', 'index']:
                    if getattr(cls, f_name, None):
                        setattr(cls, f_name, profiling(getattr(cls, f_name)))

            if getattr(cls, 'train', None):
                # print('registered train func of %s'%cls)
                setattr(cls, 'train', TrainableType._as_train_func(getattr(cls, 'train')))

            reg_cls_set.add(cls.__name__)
            setattr(cls, '_registered_class', reg_cls_set)
        yaml.register_class(cls)
        return cls

    @staticmethod
    def _as_train_func(func):
        @wraps(func)
        def arg_wrapper(self, *args, **kwargs):
            if self.is_trained:
                self.logger.warning('"%s" has been trained already, '
                                    'training it again will override the previous training' % self.__class__.__name__)
            f = func(self, *args, **kwargs)
            if not isinstance(self, CompositionalTrainableBase):
                self.is_trained = True
            return f

        return arg_wrapper
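
    # Illustrative sketch, not part of the module: the effect of this wrapper on a
    # hypothetical subclass. The first train() call flips is_trained to True; a second
    # call only logs a warning before running the training again.
    #
    #   encoder = MyEncoder()
    #   encoder.train(train_data)    # is_trained becomes True
    #   encoder.train(train_data)    # warns that the previous training will be overridden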

    @staticmethod
    def _store_init_kwargs(func):
        @wraps(func)
        def arg_wrapper(self, *args, **kwargs):
            taboo = {'self', 'args', 'kwargs'}
            taboo.update(TrainableType.default_gnes_config.keys())
            all_pars = inspect.signature(func).parameters
            tmp = {k: v.default for k, v in all_pars.items() if k not in taboo}
            tmp_list = [k for k in all_pars.keys() if k not in taboo]
            # set args by aligning tmp_list with arg values
            for k, v in zip(tmp_list, args):
                tmp[k] = v
            # set kwargs
            for k, v in kwargs.items():
                if k in tmp:
                    tmp[k] = v

            if self.store_args_kwargs:
                if args: tmp['args'] = args
                if kwargs: tmp['kwargs'] = {k: v for k, v in kwargs.items() if k not in taboo}

            if getattr(self, '_init_kwargs_dict', None):
                self._init_kwargs_dict.update(tmp)
            else:
                self._init_kwargs_dict = tmp
            f = func(self, *args, **kwargs)
            return f

        return arg_wrapper
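
    # Illustrative sketch, not part of the module: what this wrapper records. For a
    # hypothetical "MyEncoder(dim=128)" whose __init__ is "def __init__(self, dim=64, **kwargs)",
    # the wrapper stores {'dim': 128} in self._init_kwargs_dict, which
    # _dump_instance_to_yaml() later writes back under the "parameters" section.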


class TrainableBase(metaclass=TrainableType):
    """
    The base class for preprocessor, encoder, indexer and router
    """
    store_args_kwargs = False

    def __init__(self, *args, **kwargs):
        self.verbose = 'verbose' in kwargs and kwargs['verbose']
        self.logger = set_logger(self.__class__.__name__, self.verbose)
        self._post_init_vars = set()

    def _post_init_wrapper(self):
        if not getattr(self, 'name', None) and os.environ.get('GNES_WARN_UNNAMED_COMPONENT', '1') == '1':
            _id = str(uuid.uuid4()).split('-')[0]
            _name = '%s-%s' % (self.__class__.__name__, _id)
            if self.warn_unnamed:
                self.logger.warning(
                    'this object is not named ("name" is not found under "gnes_config" in the YAML config), '
                    'I will call it "%s". '
                    'Naming the object is important as it provides a unique identifier when '
                    'serializing/deserializing this object.' % _name)
            setattr(self, 'name', _name)

        _before = set(list(self.__dict__.keys()))
        self.post_init()
        self._post_init_vars = {k for k in self.__dict__ if k not in _before}

    def post_init(self):
        """
        Declare class attributes/members that cannot be serialized in the standard way
        """
        pass

    @classmethod
    def pre_init(cls):
        pass

    @property
    def dump_full_path(self):
        """
        Get the binary dump path

        :return:
        """
        return os.path.join(self.work_dir, '%s.bin' % self.name)

    @property
    def yaml_full_path(self):
        """
        Get the file path of the yaml config

        :return:
        """
        return os.path.join(self.work_dir, '%s.yml' % self.name)

    def __getstate__(self):
        d = dict(self.__dict__)
        del d['logger']
        for k in self._post_init_vars:
            del d[k]
        return d

    def __setstate__(self, d):
        self.__dict__.update(d)
        self.logger = set_logger(self.__class__.__name__, self.verbose)
        try:
            self._post_init_wrapper()
        except ImportError as ex:
            self.logger.warning('ImportError is often caused by a missing component, '
                                'which can often be solved by "pip install"-ing the relevant package. %s' % ex,
                                exc_info=True)

    def train(self, *args, **kwargs):
        """
        Train the model; needs to be overridden
        """
        pass

    @profiling
    def dump(self, filename: str = None) -> None:
        """
        Serialize the object to a binary file

        :param filename: file path of the serialized file, if not given then :py:attr:`dump_full_path` is used
        """
        f = filename or self.dump_full_path
        if not f:
            f = tempfile.NamedTemporaryFile('w', delete=False, dir=os.environ.get('GNES_VOLUME', None)).name
        with open(f, 'wb') as fp:
            pickle.dump(self, fp)
        self.logger.critical('model is serialized to %s' % f)

    @profiling
    def dump_yaml(self, filename: str = None) -> None:
        """
        Serialize the object to a yaml file

        :param filename: file path of the yaml file, if not given then :py:attr:`yaml_full_path` is used
        """
        f = filename or self.yaml_full_path
        if not f:
            f = tempfile.NamedTemporaryFile('w', delete=False, dir=os.environ.get('GNES_VOLUME', None)).name
        with open(f, 'w', encoding='utf8') as fp:
            yaml.dump(self, fp)
        self.logger.info('model\'s yaml config is dumped to %s' % f)

    @classmethod
    def load_yaml(cls: Type[T], filename: Union[str, TextIO]) -> T:
        if not filename:
            raise FileNotFoundError
        if isinstance(filename, str):
            with open(filename, encoding='utf8') as fp:
                return yaml.load(fp)
        else:
            with filename:
                return yaml.load(filename)

    @staticmethod
    @profiling
    def load(filename: str = None) -> T:
        if not filename:
            raise FileNotFoundError
        with open(filename, 'rb') as fp:
            return pickle.load(fp)
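
    # Illustrative sketch, not part of the module: a typical dump/load round trip,
    # assuming "encoder" is an instance of a hypothetical TrainableBase subclass.
    #
    #   encoder.dump('/tmp/my-encoder.bin')          # pickle the trained object
    #   encoder.dump_yaml('/tmp/my-encoder.yml')     # write its YAML config
    #   same_encoder = TrainableBase.load('/tmp/my-encoder.bin')
    #   rebuilt = TrainableBase.load_yaml('/tmp/my-encoder.yml')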

    def close(self):
        """
        Release the resources as the model is destroyed
        """
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @staticmethod
    def _get_tags_from_node(node):
        def node_recurse_generator(n):
            if n.tag.startswith('!'):
                yield n.tag.lstrip('!')
            for nn in n.value:
                if isinstance(nn, tuple):
                    for k in nn:
                        yield from node_recurse_generator(k)
                elif isinstance(nn, ruamel.yaml.nodes.Node):
                    yield from node_recurse_generator(nn)

        return list(set(list(node_recurse_generator(node))))

    @classmethod
    def to_yaml(cls, representer, data):
        tmp = data._dump_instance_to_yaml(data)
        return representer.represent_mapping('!' + cls.__name__, tmp)

    @classmethod
    def from_yaml(cls, constructor, node, stop_on_import_error=False):
        return cls._get_instance_from_yaml(constructor, node, stop_on_import_error)[0]

    @classmethod
    def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
        try:
            for c in cls._get_tags_from_node(node):
                import_class_by_str(c)
        except ImportError as ex:
            if stop_on_import_error:
                raise RuntimeError('cannot import module, "pip install" may be required') from ex

        if node.tag in {'!PipelineEncoder', '!CompositionalTrainableBase'}:
            os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '0'

        data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
            constructor, node, deep=True)

        _gnes_config = data.get('gnes_config', {})
        for k, v in _gnes_config.items():
            _gnes_config[k] = _expand_env_var(v)
        if _gnes_config:
            data['gnes_config'] = _gnes_config

        dump_path = cls._get_dump_path_from_config(data.get('gnes_config', {}))
        load_from_dump = False
        if dump_path:
            obj = cls.load(dump_path)
            obj.logger.critical('restore %s from %s' % (cls.__name__, dump_path))
            load_from_dump = True
        else:
            cls.init_from_yaml = True

            if cls.store_args_kwargs:
                p = data.get('parameters', {})  # type: Dict[str, Any]
                a = p.pop('args') if 'args' in p else ()
                k = p.pop('kwargs') if 'kwargs' in p else {}
                # maybe there are some hanging kwargs in "parameters"
                tmp_a = (_expand_env_var(v) for v in a)
                tmp_p = {kk: _expand_env_var(vv) for kk, vv in {**k, **p}.items()}
                obj = cls(*tmp_a, **tmp_p, gnes_config=data.get('gnes_config', {}))
            else:
                tmp_p = {kk: _expand_env_var(vv) for kk, vv in data.get('parameters', {}).items()}
                obj = cls(**tmp_p, gnes_config=data.get('gnes_config', {}))

            obj.logger.critical('initialize %s from a yaml config' % cls.__name__)
            cls.init_from_yaml = False

        if node.tag in {'!PipelineEncoder', '!CompositionalTrainableBase'}:
            os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1'

        return obj, data, load_from_dump

    @staticmethod
    def _get_dump_path_from_config(gnes_config: Dict):
        if 'name' in gnes_config:
            dump_path = os.path.join(gnes_config.get('work_dir', os.getcwd()), '%s.bin' % gnes_config['name'])
            if os.path.exists(dump_path):
                return dump_path

    @staticmethod
    def _dump_instance_to_yaml(data):
        # note: we only dump non-default properties for the sake of clarity
        p = {k: getattr(data, k) for k, v in TrainableType.default_gnes_config.items() if getattr(data, k) != v}
        a = {k: v for k, v in data._init_kwargs_dict.items() if k not in TrainableType.default_gnes_config}
        r = {}
        if a:
            r['parameters'] = a
        if p:
            r['gnes_config'] = p
        return r

    def _copy_from(self, x: 'TrainableBase') -> None:
        pass
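
# Illustrative sketch, not part of the module: the YAML layout that
# _get_instance_from_yaml() expects. "MyEncoder" and its "dim" parameter are
# hypothetical; "parameters" feeds the constructor and "gnes_config" feeds the
# metaclass defaults above.
#
#   !MyEncoder
#   parameters:
#     dim: 128
#   gnes_config:
#     name: my-encoder
#     work_dir: ${GNES_VOLUME}
#     is_trained: true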


class CompositionalTrainableBase(TrainableBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._components = None  # type: List[T]

    @property
    def is_trained(self):
        return self.components and all(c.is_trained for c in self.components)

    @property
    def components(self) -> Union[List[T], Dict[str, T]]:
        return self._components

    @property
    def is_pipeline(self):
        return isinstance(self.components, list)

    @components.setter
    def components(self, comps: Callable[[], Union[list, dict]]):
        if not callable(comps):
            raise TypeError('components must be a callable function that returns '
                            'a List[BaseEncoder]')
        if not getattr(self, 'init_from_yaml', False):
            self._components = comps()
        else:
            self.logger.info('components is omitted from construction, '
                             'as it is initialized from yaml config')

    def close(self):
        super().close()
        # pipeline
        if isinstance(self.components, list):
            for be in self.components:
                be.close()
        # no topology, i.e. a plain dict of components
        elif isinstance(self.components, dict):
            for be in self.components.values():
                be.close()
        elif self.components is None:
            pass
        else:
            raise TypeError('components must be dict or list, received %s' % type(self.components))

    def _copy_from(self, x: T):
        if isinstance(self.components, list):
            for be1, be2 in zip(self.components, x.components):
                be1._copy_from(be2)
        elif isinstance(self.components, dict):
            for k, v in self.components.items():
                v._copy_from(x.components[k])
        else:
            raise TypeError('components must be dict or list, received %s' % type(self.components))

    @classmethod
    def to_yaml(cls, representer, data):
        tmp = super()._dump_instance_to_yaml(data)
        tmp['components'] = data.components
        return representer.represent_mapping('!' + cls.__name__, tmp)

    @classmethod
    def from_yaml(cls, constructor, node):
        obj, data, from_dump = super()._get_instance_from_yaml(constructor, node)
        if not from_dump and 'components' in data:
            obj.components = lambda: data['components']
        return obj
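
    # Illustrative sketch, not part of the module: a compositional YAML config.
    # "PipelineEncoder" exists in GNES, but the two components listed here are
    # hypothetical; from_yaml() wires the "components" list back via the setter above.
    #
    #   !PipelineEncoder
    #   components:
    #     - !MyFirstEncoder
    #       parameters:
    #         dim: 128
    #     - !MySecondEncoder {}
    #   gnes_config:
    #     name: my-pipeline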


def _expand_env_var(v: str) -> str:
    if isinstance(v, str):
        return parse_arg(os.path.expandvars(v))
    else:
        return v
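
# Illustrative sketch, not part of the module: string values written in YAML, such as
# "${GNES_VOLUME}/model", are expanded against the environment and then passed through
# parse_arg, which (as assumed here) turns simple literals into typed values.
#
#   os.environ['GNES_VOLUME'] = '/data'
#   _expand_env_var('${GNES_VOLUME}/model')   # -> '/data/model'
#   _expand_env_var(128)                      # non-strings are returned unchanged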