plugins.hf_seq_classification.implementation

HfSeqClassificationPlugin Objects#

class HfSeqClassificationPlugin(Plugin)

Plugin for Text Sequence Classification using Huggingface models.

plugin.setup() bootstraps the entire pipeline and returns a fully set-up trainer. Example::

    trainer = plugin.setup()
    trainer.train()
    trainer.validate()

Alternatively, you can run setup_datainterface, setup_module, and setup_trainer individually. Example::

    plugin.setup_datainterface()
    plugin.setup_module()
    trainer = plugin.setup_trainer()

__init__#

def __init__(config: Optional[Dict] = None)

CustomArgParser parses YAML config located at cmdline --config_path. If --config_path is not provided, assumes YAML file is named config.yaml and present in working directory.

Instantiates dataclasses:

self.data_args (arguments.DataInterfaceArguments): Instantiated dataclass containing args.

self.module_args (arguments.ModuleInterfaceArguments): Instantiated dataclass containing args required to initialize HfSeqClassificationModule class.

self.distill_args (arguments.DistillationArguments): Instantiated dataclass required to initialize DistillHfModule. Set self.distill_args.enable = True in config file to do knowledge distillation instead of regular training.

Sets properties:

self.datainterface: data_interface.DataInterface [HfSeqClassificationDataInterface] object

self.module: module_interface.ModuleInterface [HfSeqClassificationModule] object

This is used to initialize a Marlin trainer.
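The config-path fallback described above can be sketched with the standard library's argparse. This is a minimal illustration under stated assumptions, not Marlin's CustomArgParser: the function name resolve_config_path is hypothetical, and only the --config_path flag and the config.yaml default are taken from the description.

```python
import argparse
from typing import List, Optional


def resolve_config_path(argv: Optional[List[str]] = None) -> str:
    """Return the YAML config path, falling back to config.yaml.

    Hypothetical helper for illustration; not part of Marlin's CustomArgParser.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--config_path", default="config.yaml")
    # parse_known_args ignores unrelated command-line flags instead of failing.
    args, _unknown = parser.parse_known_args(argv)
    return args.config_path
```

Passing an explicit list (for example resolve_config_path(["--config_path", "exp/seq.yaml"])) avoids reading sys.argv, which is convenient in notebooks and tests.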

setup_datainterface#

def setup_datainterface()

Calls datainterface.setup_datasets(train_data, val_data).

Assumptions: Training and validation files are placed in separate directories.

Accepted file formats: source/target text lines in data_args.data_dir/{train,val}.{source,targets}

setup_module#

def setup_module()

Sets the HfSeqClassificationModule.data property to datainterface, which contains the processed datasets. An assertion error is thrown if datainterface retrieves no train or val data, indicating that datainterface hasn't been set up with processed data.

Sets the HfSeqClassificationModule.model property after initializing weights:

Option 1: Load weights from the files specified in the YAML config::

    model:
        model_config_path
        model_config_file
        model_path
        model_file

Option 2: Load from the Huggingface model hub by specifying a string in the YAML config as::

    model:
        hf_model
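As a rough sketch of how the two weight-loading options might be told apart from a parsed model section of the config: the function choose_weight_source and its precedence rule (local files win when all four file keys are set) are assumptions made for illustration; the plugin's actual logic is not documented here. Only the key names come from the description above.

```python
from typing import Any, Dict, Tuple

# Keys named in the description for loading weights from local files.
FILE_KEYS = ("model_config_path", "model_config_file", "model_path", "model_file")


def choose_weight_source(model_cfg: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Decide between local files (Option 1) and the model hub (Option 2).

    Assumption: local files win when all four file keys are set. The real
    plugin's precedence may differ.
    """
    if all(model_cfg.get(key) for key in FILE_KEYS):
        return "files", {key: model_cfg[key] for key in FILE_KEYS}
    if model_cfg.get("hf_model"):
        return "hub", {"hf_model": model_cfg["hf_model"]}
    raise ValueError("model config needs either all file keys or hf_model")
```

For example, choose_weight_source({"hf_model": "bert-base-uncased"}) selects the hub branch, while a config with none of the keys raises ValueError early instead of failing later during model construction.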

setup#

def setup()

Executes all the setup methods required to create a trn.Trainer object. The trainer needs the module interface, and its backend is specified by self.trainer_args.backend.
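The order in which setup() chains the individual steps can be pictured with a small stand-in class. This is only a sketch: SketchPlugin is hypothetical, its methods record calls instead of building real objects, and the ordering is inferred from the description (the module needs a prepared datainterface, and the trainer needs the module).

```python
from typing import List


class SketchPlugin:
    """Stand-in that mimics only the call order of the plugin's setup steps."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def setup_datainterface(self) -> None:
        self.calls.append("setup_datainterface")

    def setup_module(self) -> None:
        self.calls.append("setup_module")

    def setup_trainer(self) -> str:
        self.calls.append("setup_trainer")
        return "trainer"  # placeholder for a trn.Trainer object

    def setup(self) -> str:
        # Data first, then the module that consumes it, then the trainer.
        self.setup_datainterface()
        self.setup_module()
        return self.setup_trainer()
```

Calling SketchPlugin().setup() runs the three steps in that order and returns the placeholder trainer, mirroring the one-call and step-by-step usage shown in the class description.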