
plugins.hf_seq2seq.implementation

HfSeq2SeqPlugin Objects

class HfSeq2SeqPlugin(Plugin)

Plugin for text sequence-to-sequence generation using Hugging Face models.

plugin.setup() bootstraps the entire pipeline and returns a fully set-up trainer.

Example:

trainer = plugin.setup()
trainer.train()
trainer.validate()

Alternatively, you can run setup_datainterface(), setup_module(), and setup_trainer() individually.

Example:

plugin.setup_datainterface()
plugin.setup_module()
trainer = plugin.setup_trainer()

__init__

def __init__(config=None)

Accepts an optional config dictionary. CustomArgParser parses the YAML config file whose path is given by the command-line argument --config_path. If --config_path is not provided, the YAML file is assumed to be named config.yaml and to be present in the working directory.

Instantiates dataclasses:
  self.data_args (arguments.DataInterfaceArguments): Data Interface arguments
  self.module_args (arguments.ModuleInterfaceArguments): Module Interface arguments

Sets properties:
  self.datainterface: data_interface.DataInterface [HfSeq2SeqData] object
  self.moduleinterface: module_interface.ModuleInterface [HfSeq2SeqModule] object
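The config-file lookup rule described for __init__ can be sketched with the standard library's argparse. This is a minimal illustration, assuming the rule is exactly "use --config_path if given, otherwise ./config.yaml"; the helper name resolve_config_path is hypothetical and is not CustomArgParser's actual implementation.

```python
import argparse
from pathlib import Path
from typing import Optional, Sequence

DEFAULT_CONFIG = "config.yaml"  # default file name described above


def resolve_config_path(argv: Optional[Sequence[str]] = None) -> Path:
    """Return --config_path if provided, else config.yaml in the working directory.

    Hypothetical helper that mimics the documented lookup rule; it is not
    the plugin's CustomArgParser.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--config_path", default=None)
    # parse_known_args tolerates any other command-line flags the run may use
    args, _unknown = parser.parse_known_args(argv)
    if args.config_path is not None:
        return Path(args.config_path)
    # Fall back to config.yaml in the current working directory
    return Path.cwd() / DEFAULT_CONFIG
```

For example, resolve_config_path(["--config_path", "configs/run.yaml"]) returns Path("configs/run.yaml"), while an empty argument list falls back to config.yaml in the current directory.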

setup_datainterface

def setup_datainterface()

Calls datainterface.setup_datasets(train_data, val_data).

Assumptions: Training and validation files are placed in separate directories. Expected file format: line-by-line source and target text in data_args.data_dir/{train,val}.{source,targets}.
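To make the documented file layout concrete, the following sketch reads one split as (source, target) pairs. It assumes that line i of the .source file corresponds to line i of the .targets file; the function load_split is hypothetical and only illustrates the layout, not how the plugin's datainterface actually reads data.

```python
from pathlib import Path
from typing import List, Tuple


def load_split(data_dir: str, split: str) -> List[Tuple[str, str]]:
    """Read <data_dir>/<split>.source and <data_dir>/<split>.targets as aligned pairs.

    Hypothetical helper; assumes one example per line and line-by-line alignment
    between the source and target files.
    """
    base = Path(data_dir)
    sources = (base / f"{split}.source").read_text(encoding="utf-8").splitlines()
    targets = (base / f"{split}.targets").read_text(encoding="utf-8").splitlines()
    # Mismatched line counts would silently misalign examples, so fail early
    if len(sources) != len(targets):
        raise ValueError(
            f"{split}: {len(sources)} source lines but {len(targets)} target lines"
        )
    return list(zip(sources, targets))
```

Calling load_split(data_dir, "train") and load_split(data_dir, "val") would then yield the two splits passed on as train_data and val_data.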

setup_module

def setup_module()

Sets the HfSeq2SeqModule.data property to the datainterface, which contains the processed datasets. An AssertionError is raised if the datainterface returns no train or val data, which indicates that the datainterface has not been set up with processed data.

Then sets the HfSeq2SeqModule.model property after initializing the weights in one of two ways:
  Option 1: Load weights from the files specified in the YAML config under model: model_config_path, model_config_file, model_path, model_file.
  Option 2: Load from the Hugging Face model hub by specifying the model name string in the YAML config as model: hf_model.
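The two weight-loading options can be told apart by the shape of the model entry. The sketch below assumes a hypothetical parsed-config shape: model is either a plain string (Option 2, a hub model name) or a mapping containing the four local-file keys (Option 1). The function choose_weight_source is illustrative only and does not load any weights.

```python
from typing import Dict, Tuple, Union

# Keys listed for Option 1 in the documentation above
LOCAL_KEYS = ("model_config_path", "model_config_file", "model_path", "model_file")


def choose_weight_source(
    model_cfg: Union[str, Dict[str, str]]
) -> Tuple[str, Union[str, Dict[str, str]]]:
    """Classify the `model` entry of a parsed YAML config.

    Hypothetical helper, assuming `model` is either a hub model name string
    (Option 2) or a mapping with the four local-file keys (Option 1).
    """
    if isinstance(model_cfg, str):
        # Option 2: a Hugging Face model hub identifier
        return ("hub", model_cfg)
    missing = [key for key in LOCAL_KEYS if key not in model_cfg]
    if missing:
        raise KeyError(f"model config is missing local-file keys: {missing}")
    # Option 1: local model configuration and weight files
    return ("local", {key: model_cfg[key] for key in LOCAL_KEYS})
```

With the transformers library, a hub name such as "t5-small" is typically passed to AutoModelForSeq2SeqLM.from_pretrained; whether this plugin uses that exact call is not stated here.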

setup

def setup()

Runs all the setup methods required to create a trn.Trainer object. The Trainer requires the moduleinterface, and its backend is specified by self.trainer_args.backend.
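The ordering that setup() relies on, and why calling setup_module() before setup_datainterface() fails, can be illustrated with a toy stand-in. ToyPlugin is hypothetical and is not the real HfSeq2SeqPlugin: it only records call order and mirrors the documented assertion that train and val data must be present.

```python
from typing import Dict, List, Optional


class ToyPlugin:
    """Hypothetical stand-in that records the order of the setup calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.datasets: Optional[Dict[str, List[str]]] = None

    def setup_datainterface(self) -> None:
        self.calls.append("setup_datainterface")
        # Placeholder data standing in for the processed train/val datasets
        self.datasets = {"train": ["example"], "val": ["example"]}

    def setup_module(self) -> None:
        self.calls.append("setup_module")
        # Mirrors the documented guard: train and val data must already exist
        assert self.datasets and self.datasets.get("train") and self.datasets.get("val"), \
            "datainterface has not been set up with processed data"

    def setup_trainer(self) -> str:
        self.calls.append("setup_trainer")
        return "trainer"

    def setup(self) -> str:
        # setup() chains the three steps in dependency order
        self.setup_datainterface()
        self.setup_module()
        return self.setup_trainer()
```

Running ToyPlugin().setup() calls the three steps in order, whereas calling setup_module() on a fresh instance raises an AssertionError because no data has been prepared yet.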