gnes.encoder.text.gpt module

class gnes.encoder.text.gpt.GPT2Encoder(model_dir: str, use_cuda: bool = False, pooling_strategy: str = 'REDUCE_MEAN', *args, **kwargs)

Bases: gnes.encoder.text.gpt.GPTEncoder

train(*args, **kwargs)

Train the model. Subclasses need to override this method.

class gnes.encoder.text.gpt.GPTEncoder(model_dir: str, use_cuda: bool = False, pooling_strategy: str = 'REDUCE_MEAN', *args, **kwargs)

Bases: gnes.encoder.base.BaseTextEncoder
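Both constructors default pooling_strategy to 'REDUCE_MEAN'. The exact pooling code in GNES is not shown here; the following is a minimal pure-Python sketch of what mean pooling typically means, assuming the model yields one vector per token and that padded positions are excluded by a 0/1 mask. The helper name reduce_mean is hypothetical and is not part of the GNES API.

```python
from typing import List


def reduce_mean(token_vectors: List[List[float]], mask: List[int]) -> List[float]:
    """Hypothetical helper: average the token vectors whose mask entry is 1."""
    kept = [vec for vec, m in zip(token_vectors, mask) if m]
    if not kept:
        raise ValueError("mask selects no tokens")
    dim = len(kept[0])
    # Sum each dimension over the kept tokens, then divide by the token count.
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


# Three tokens; the last one is padding and is ignored by the mask.
print(reduce_mean([[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]], [1, 1, 0]))  # [2.0, 3.0]
```

Masking matters because a plain average over a padded sequence would let padding vectors pull every sentence embedding toward the same point.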

batch_size = 64

encode(text: List[str], *args, **kwargs) → numpy.ndarray

is_trained = True

post_init()

Declare class attributes/members that cannot be serialized in the standard way.
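A common way to handle members that cannot be pickled is to drop them before serialization and rebuild them in a post-init hook after deserialization. The sketch below illustrates that general pattern under stated assumptions; ToyEncoder and its lambda "model" are hypothetical stand-ins, and this is not a claim about how GNES wires post_init internally.

```python
import pickle
from typing import Any, Dict


class ToyEncoder:
    """Hypothetical stand-in for an encoder holding a non-picklable member."""

    def __init__(self, model_dir: str):
        self.model_dir = model_dir
        self.post_init()

    def post_init(self) -> None:
        # Build members that cannot be pickled; a lambda stands in for a loaded model.
        self.model = lambda text: len(text)

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state.pop("model", None)  # drop the non-serializable member before pickling
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self.post_init()  # rebuild the dropped member after unpickling


enc = pickle.loads(pickle.dumps(ToyEncoder("/tmp/gpt")))
print(enc.model_dir, enc.model("hello"))  # /tmp/gpt 5
```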

train(*args, **kwargs)

Train the model. Subclasses need to override this method.
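GPTEncoder declares batch_size = 64, and encode maps a list of strings to a single numpy.ndarray. This reference does not say how encode uses batch_size; the sketch below assumes it splits the input into chunks of at most batch_size texts and concatenates the per-chunk results in input order. The names batched, encode_in_batches and encode_batch are hypothetical, and plain lists stand in for the ndarray rows.

```python
from typing import Callable, Iterator, List


def batched(texts: List[str], batch_size: int = 64) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size texts."""
    for start in range(0, len(texts), batch_size):
        yield texts[start:start + batch_size]


def encode_in_batches(
    texts: List[str],
    encode_batch: Callable[[List[str]], List[List[float]]],
    batch_size: int = 64,
) -> List[List[float]]:
    """Hypothetical driver: encode each chunk and keep one row per input text."""
    out: List[List[float]] = []
    for chunk in batched(texts, batch_size):
        out.extend(encode_batch(chunk))
    return out


# Toy encode_batch: a 1-dimensional "embedding" holding each text's length.
docs = ["a"] * 150
vectors = encode_in_batches(docs, lambda chunk: [[float(len(t))] for t in chunk])
print(len(vectors), [len(c) for c in batched(docs)])  # 150 [64, 64, 22]
```

Chunking bounds peak memory on the GPU or CPU while the output still has exactly one row per input string.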