*Copyright (c) Microsoft Corporation. All rights reserved.*  

*Licensed under the MIT License.*

# Natural Language Inference on MultiNLI Dataset using Transformers

## Before You Start

It takes about 4 hours to fine-tune the `bert-large-cased` model on a Standard_NC24rs_v3 Azure Data Science Virtual Machine with 4 NVIDIA Tesla V100 GPUs. 
> **Tip:** If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. 


If you run into a CUDA out-of-memory error, try reducing `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance may degrade. 

In [None]:
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = False

## Summary
In this notebook, we demonstrate fine-tuning pretrained transformer models to perform Natural Language Inference (NLI). We use the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset, and the task is to classify sentence pairs into three classes: contradiction, entailment, and neutral.  
To classify a sentence pair, we concatenate the tokens of both sentences and separate them with the special [SEP] token. A [CLS] token is prepended to the token list, and its representation is used as the aggregate sequence representation for classification. The NLI task thus becomes a sequence classification task. For example, the figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. 
<img src="https://nlpbp.blob.core.windows.net/images/bert_two_sentence.PNG">
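
The BERT-style input layout described above can be sketched in plain Python. This is an illustration only: `build_pair_input` is a hypothetical helper, the inputs are assumed to be already tokenized, and real models use subword tokenizers (XLNet also arranges its special tokens differently). It also builds the segment (token type) ids that tell BERT which sentence each token belongs to.

```python
def build_pair_input(tokens_a, tokens_b, cls_token="[CLS]", sep_token="[SEP]"):
    """Lay out a sentence pair as [CLS] A [SEP] B [SEP] (hypothetical helper).

    Segment ids are 0 for [CLS], the first sentence and its [SEP],
    and 1 for the second sentence and its trailing [SEP].
    """
    tokens = [cls_token] + tokens_a + [sep_token]
    segment_ids = [0] * len(tokens)
    tokens += tokens_b + [sep_token]
    segment_ids += [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


premise = ["the", "cat", "sat", "on", "the", "mat"]
hypothesis = ["a", "cat", "is", "sitting"]
tokens, segment_ids = build_pair_input(premise, hypothesis)
print(tokens)
# ['[CLS]', 'the', 'cat', 'sat', 'on', 'the', 'mat', '[SEP]', 'a', 'cat', 'is', 'sitting', '[SEP]']
print(segment_ids)
# [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
```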

We compare the training time and performance of `bert-large-cased` and `xlnet-large-cased`. The model to use can be set in the **Configurations** section. 

In [1]:
import sys, os
nlp_path = os.path.abspath('../../')
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

import scrapbook as sb

from tempfile import TemporaryDirectory

import numpy as np
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder

import torch

from utils_nlp.models.transformers.sequence_classification import Processor, SequenceClassifier
from utils_nlp.dataset.multinli import load_pandas_df
from utils_nlp.common.pytorch_utils import dataloader_from_dataset
from utils_nlp.common.timer import Timer

To see all the models supported by `SequenceClassifier`, call the `list_supported_models` method.  
**Note**: Although `SequenceClassifier` supports distilbert for single-sequence classification, distilbert doesn't support sentence-pair classification and cannot be used in this notebook.

In [None]:
SequenceClassifier.list_supported_models()

## Configurations

In [None]:
MODEL_NAME = "bert-large-cased"
TO_LOWER = False
BATCH_SIZE = 16

# MODEL_NAME = "xlnet-large-cased"
# TO_LOWER = False
# BATCH_SIZE = 16

TRAIN_DATA_USED_FRACTION = 1
DEV_DATA_USED_FRACTION = 1
NUM_EPOCHS = 2
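WARMUP_STEPS = 2500

if QUICK_RUN:
    TRAIN_DATA_USED_FRACTION = 0.001
    DEV_DATA_USED_FRACTION = 0.01
    NUM_EPOCHS = 1
    WARMUP_STEPS = 10

# Halve the batch size when running on CPU; integer division keeps
# BATCH_SIZE an int, which the DataLoader requires.
if not torch.cuda.is_available():
    BATCH_SIZE = BATCH_SIZE // 2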

RANDOM_SEED = 42

# model configurations
MAX_SEQ_LENGTH = 128

# optimizer configurations
LEARNING_RATE = 5e-5

# data configurations
TEXT_COL_1 = "sentence1"
TEXT_COL_2 = "sentence2"
LABEL_COL = "gold_label"
LABEL_COL_NUM = "gold_label_num"

CACHE_DIR = TemporaryDirectory().name

## Load Data
The MultiNLI dataset comes with three subsets: train, dev_matched, and dev_mismatched. The dev_matched subset is drawn from the same genres as the train subset, while the dev_mismatched subset is drawn from genres not seen in the training data.   
The `load_pandas_df` function downloads and extracts the zip files if they don't already exist in `local_cache_path` and returns the data subset specified by `file_split`.

In [None]:
train_df = load_pandas_df(local_cache_path=CACHE_DIR, file_split="train")
dev_df_matched = load_pandas_df(local_cache_path=CACHE_DIR, file_split="dev_matched")
dev_df_mismatched = load_pandas_df(local_cache_path=CACHE_DIR, file_split="dev_mismatched")

In [None]:
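# Drop pairs whose gold label is "-", i.e. pairs for which the annotators
# did not reach a consensus label.
dev_df_matched = dev_df_matched.loc[dev_df_matched[LABEL_COL] != '-']
dev_df_mismatched = dev_df_mismatched.loc[dev_df_mismatched[LABEL_COL] != '-']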

In [None]:
print("Training dataset size: {}".format(train_df.shape[0]))
print("Development (matched) dataset size: {}".format(dev_df_matched.shape[0]))
print("Development (mismatched) dataset size: {}".format(dev_df_mismatched.shape[0]))
print()
print(train_df[['gold_label', 'sentence1', 'sentence2']].head())

In [None]:
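# Subsample the data using the fractions set in Configurations (a fraction of 1
# keeps every row but shuffles them); random_state makes the sample reproducible.
train_df = train_df.sample(frac=TRAIN_DATA_USED_FRACTION, random_state=RANDOM_SEED).reset_index(drop=True)
dev_df_matched = dev_df_matched.sample(frac=DEV_DATA_USED_FRACTION, random_state=RANDOM_SEED).reset_index(drop=True)
dev_df_mismatched = dev_df_mismatched.sample(frac=DEV_DATA_USED_FRACTION, random_state=RANDOM_SEED).reset_index(drop=True)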

In [None]:
label_encoder = LabelEncoder()
train_labels = label_encoder.fit_transform(train_df[LABEL_COL])
train_df[LABEL_COL_NUM] = train_labels 
num_labels = len(np.unique(train_labels))
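
`LabelEncoder` stores the sorted unique class names in `classes_` and maps each label to its position in that list, so the MultiNLI labels are numbered alphabetically. `classifier.predict` later returns these integer ids, which is why the Evaluate section calls `inverse_transform`. The pure-Python stand-in below (the helper `fit_label_encoder` is hypothetical, not part of scikit-learn) shows the mapping on a few made-up labels.

```python
def fit_label_encoder(labels):
    """Mimic LabelEncoder: classes are the sorted unique labels."""
    classes = sorted(set(labels))
    to_index = {label: i for i, label in enumerate(classes)}
    return classes, to_index


gold = ["neutral", "contradiction", "entailment", "neutral", "entailment"]
classes, to_index = fit_label_encoder(gold)

encoded = [to_index[label] for label in gold]  # analogous to fit_transform
decoded = [classes[i] for i in encoded]        # analogous to inverse_transform

print(classes)       # ['contradiction', 'entailment', 'neutral']
print(encoded)       # [2, 0, 1, 2, 1]
print(len(classes))  # 3, i.e. num_labels
```

Because the encoder is fit on the training labels only, every dev label must appear in that set; this holds for MultiNLI once the "-" rows are dropped.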

## Tokenize and Preprocess
Before training, we tokenize and preprocess the sentence texts to convert them into the format required by the transformer model classes.  
The `dataset_from_dataframe` method of the `Processor` class performs the following preprocessing steps and returns a PyTorch `Dataset`:
* Tokenize input texts using the tokenizer of the pre-trained model specified by `model_name`. 
* Convert the tokens into token indices corresponding to the tokenizer's vocabulary.
* Pad or truncate the token lists to the specified max length.
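
The three steps above can be sketched for a single sentence pair in plain Python. This is a hypothetical illustration, not the `Processor` implementation: `encode_pair` and the toy `vocab` are made up, whitespace splitting stands in for the pretrained model's subword tokenizer, the [CLS]/[SEP] layout is BERT's, and the longest-first truncation rule is an assumption borrowed from BERT's original reference code.

```python
def encode_pair(text_a, text_b, vocab, max_len):
    """Tokenize, truncate, convert to ids and pad one sentence pair (toy version)."""
    assert max_len >= 3, "need room for [CLS] and two [SEP] tokens"
    unk_id, pad_id = vocab["[UNK]"], vocab["[PAD]"]

    # Tokenize: lowercase and split on whitespace (stand-in for a real tokenizer).
    tokens_a = text_a.lower().split()
    tokens_b = text_b.lower().split()

    # Truncate: drop tokens from the longer sentence until the pair plus
    # [CLS] and two [SEP] tokens fits in max_len.
    while len(tokens_a) + len(tokens_b) > max_len - 3:
        if len(tokens_a) >= len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]

    # Convert tokens to vocabulary indices; unknown words map to [UNK].
    input_ids = [vocab.get(tok, unk_id) for tok in tokens]

    # Pad to max_len; the attention mask marks real tokens with 1, padding with 0.
    padding = max_len - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * padding
    input_ids += [pad_id] * padding
    return input_ids, attention_mask


vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
         "the": 4, "cat": 5, "sat": 6, "a": 7, "is": 8, "sitting": 9}

ids, mask = encode_pair("The cat sat", "A cat is sitting", vocab, max_len=12)
print(ids)   # [2, 4, 5, 6, 3, 7, 5, 8, 9, 3, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
```

Calling it with `max_len=8` instead exercises truncation: the pair becomes `[CLS] the cat [SEP] a cat is [SEP]` with no padding.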

In [None]:
processor = Processor(model_name=MODEL_NAME, cache_dir=CACHE_DIR, to_lower=TO_LOWER)

train_dataset = processor.dataset_from_dataframe(
    df=train_df,
    text_col=TEXT_COL_1,
    label_col=LABEL_COL_NUM,
    text2_col=TEXT_COL_2,
    max_len=MAX_SEQ_LENGTH,
)
dev_dataset_matched = processor.dataset_from_dataframe(
    df=dev_df_matched,
    text_col=TEXT_COL_1,
    text2_col=TEXT_COL_2,
    max_len=MAX_SEQ_LENGTH,
)
dev_dataset_mismatched = processor.dataset_from_dataframe(
    df=dev_df_mismatched,
    text_col=TEXT_COL_1,
    text2_col=TEXT_COL_2,
    max_len=MAX_SEQ_LENGTH,
)

train_dataloader = dataloader_from_dataset(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
dev_dataloader_matched = dataloader_from_dataset(
    dev_dataset_matched, batch_size=BATCH_SIZE, shuffle=False
)
dev_dataloader_mismatched = dataloader_from_dataset(
    dev_dataset_mismatched, batch_size=BATCH_SIZE, shuffle=False
)

## Train and Predict

### Create Classifier

In [None]:
classifier = SequenceClassifier(
    model_name=MODEL_NAME, num_labels=num_labels, cache_dir=CACHE_DIR
)

### Train Classifier

In [None]:
with Timer() as t:
    classifier.fit(
        train_dataloader,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
    )

print("Training time : {:.3f} hrs".format(t.interval / 3600))
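
To put `WARMUP_STEPS` in context, the sketch below counts optimizer steps and evaluates a learning-rate schedule. It assumes one optimizer step per batch and a linear warmup from 0 to `LEARNING_RATE` followed by a linear decay to 0 (the shape of Hugging Face's `get_linear_schedule_with_warmup`); this notebook does not confirm that `SequenceClassifier.fit` uses exactly this schedule, and `linear_warmup_lr` is a hypothetical helper. The example size is the full MultiNLI training set (392,702 pairs) with the default `BATCH_SIZE` and `NUM_EPOCHS`.

```python
import math


def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at an optimizer step under assumed linear warmup + linear decay."""
    if step < warmup_steps:
        # Ramp up linearly from 0 to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Then decay linearly to 0 at total_steps.
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


num_examples, batch_size, num_epochs = 392_702, 16, 2
steps_per_epoch = math.ceil(num_examples / batch_size)  # last batch may be partial
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)  # 24544 49088

for step in (0, 1250, 2500, total_steps):
    print(step, f"{linear_warmup_lr(step, 5e-5, 2500, total_steps):.2e}")
```

Under these assumptions, 2,500 warmup steps cover roughly 5% of training, and the peak learning rate is reached at step 2,500.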

### Predict on Development Data

In [None]:
with Timer() as t:
    predictions_matched = classifier.predict(dev_dataloader_matched)
print("Prediction time : {:.3f} hrs".format(t.interval / 3600))

In [None]:
with Timer() as t:
    predictions_mismatched = classifier.predict(dev_dataloader_mismatched)
print("Prediction time : {:.3f} hrs".format(t.interval / 3600))

## Evaluate

In [None]:
predictions_matched = label_encoder.inverse_transform(predictions_matched)
print(classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3))

In [None]:
predictions_mismatched = label_encoder.inverse_transform(predictions_mismatched)
print(classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3))
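
The last cell of this notebook records the "weighted avg" row of `classification_report`. That row averages the per-class scores weighted by each class's support (its number of true instances). A minimal pure-Python sketch of the weighted F1 computation follows; `weighted_f1` and the toy labels are hypothetical, and per-class F1 is taken as 0 when a class has no true positives.

```python
from collections import Counter


def weighted_f1(y_true, y_pred):
    """Support-weighted average of per-class F1 scores."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for label, n_true in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        n_pred = sum(1 for p in y_pred if p == label)
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true
        f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
        score += f1 * n_true / total  # weight each class by its support
    return score


y_true = ["entailment", "entailment", "neutral", "contradiction", "contradiction", "contradiction"]
y_pred = ["entailment", "neutral", "neutral", "contradiction", "contradiction", "entailment"]
# Per-class F1: entailment 0.5, neutral 2/3, contradiction 0.8
print(f"{weighted_f1(y_true, y_pred):.3f}")  # 0.678
```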

## Compare Model Performance

|Model name|Training time|Scoring time|Matched F1|Mismatched F1|
|:--------:|:-----------:|:----------:|:--------:|:-----------:|
|xlnet-large-cased|5.15 hrs|0.11 hrs|0.887|0.890|
|bert-large-cased|4.01 hrs|0.08 hrs|0.867|0.867|

In [None]:
result_matched_dict = classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3, output_dict=True)
result_mismatched_dict = classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3, output_dict=True)
sb.glue("matched_precision", result_matched_dict["weighted avg"]["precision"])
sb.glue("matched_recall", result_matched_dict["weighted avg"]["recall"])
sb.glue("matched_f1", result_matched_dict["weighted avg"]["f1-score"])
sb.glue("mismatched_precision", result_mismatched_dict["weighted avg"]["precision"])
sb.glue("mismatched_recall", result_mismatched_dict["weighted avg"]["recall"])
sb.glue("mismatched_f1", result_mismatched_dict["weighted avg"]["f1-score"])