{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "*Copyright (c) Microsoft Corporation. All rights reserved.*\n", "\n", "*Licensed under the MIT License.*\n", "\n", "# Text Classification of Multilingual Datasets Using a Transformer Model" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import scrapbook as sb\n", "import pandas as pd\n", "import torch\n", "import numpy as np\n", "\n", "from tempfile import TemporaryDirectory\n", "from utils_nlp.common.timer import Timer\n", "from sklearn.metrics import classification_report\n", "from utils_nlp.models.transformers.sequence_classification import SequenceClassifier\n", "\n", "from utils_nlp.dataset import multinli\n", "from utils_nlp.dataset import dac\n", "from utils_nlp.dataset import bbc_hindi" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "In this notebook, we fine-tune and evaluate a pretrained Transformer model based on the BERT architecture on three datasets in different languages:\n", "\n", "- [MultiNLI dataset](https://www.nyu.edu/projects/bowman/multinli/): The Multi-Genre NLI corpus, in English\n", "- [DAC dataset](https://data.mendeley.com/datasets/v524p5dhpj/2): DataSet for Arabic Classification corpus, in Arabic\n", "- [BBC Hindi dataset](https://github.com/NirantK/hindi2vec/releases/tag/bbc-hindi-v0.1): BBC Hindi News corpus, in Hindi\n", "\n", "If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and for a smaller number of epochs. You can also choose one of the three supported datasets (**`MultiNLI`**, **`DAC`**, and **`BBC Hindi`**) to experiment with. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Running Time\n", "\n", "The table below provides reference running times on the different datasets. 
\n", "\n", "|Dataset|QUICK_RUN|Machine Configurations|Running time|\n", "|:------|:---------|:----------------------|:------------|\n", "|MultiNLI|True|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 8 minutes |\n", "|MultiNLI|False|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 5.7 hours |\n", "|DAC|True|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 13 minutes |\n", "|DAC|False|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 5.6 hours |\n", "|BBC Hindi|True|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 1 minute |\n", "|BBC Hindi|False|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 14 minutes |\n", "\n", "If you run into a CUDA out-of-memory error, or the Jupyter kernel dies repeatedly, try reducing `batch_size` and `max_len` in `CONFIG`, but note that model performance may be compromised. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "# Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n", "QUICK_RUN = True\n", "\n", "# The dataset to use; valid values are: \"multinli\", \"dac\", \"bbc-hindi\"\n", "USE_DATASET = \"dac\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Several pretrained models have been made available by [Hugging Face](https://github.com/huggingface/transformers). The following pretrained models are supported for text classification." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "| | model_name |\n", "|---|---|\n", "| 0 | bert-base-uncased |\n", "| 1 | bert-large-uncased |\n", "| 2 | bert-base-cased |\n", "| 3 | bert-large-cased |\n", "| 4 | bert-base-multilingual-uncased |\n", "| 5 | bert-base-multilingual-cased |\n", "| 6 | bert-base-chinese |\n", "| 7 | bert-base-german-cased |\n", "| 8 | bert-large-uncased-whole-word-masking |\n", "| 9 | bert-large-cased-whole-word-masking |\n", "| 10 | bert-large-uncased-whole-word-masking-finetune... |\n", "| 11 | bert-large-cased-whole-word-masking-finetuned-... |\n", "| 12 | bert-base-cased-finetuned-mrpc |\n", "| 13 | bert-base-german-dbmdz-cased |\n", "| 14 | bert-base-german-dbmdz-uncased |\n", "| 15 | bert-base-japanese |\n", "| 16 | bert-base-japanese-whole-word-masking |\n", "| 17 | bert-base-japanese-char |\n", "| 18 | bert-base-japanese-char-whole-word-masking |\n", "| 19 | bert-base-finnish-cased-v1 |\n", "| 20 | bert-base-finnish-uncased-v1 |\n", "| 21 | roberta-base |\n", "| 22 | roberta-large |\n", "| 23 | roberta-large-mnli |\n", "| 24 | distilroberta-base |\n", "| 25 | roberta-base-openai-detector |\n", "| 26 | roberta-large-openai-detector |\n", "| 27 | xlnet-base-cased |\n", "| 28 | xlnet-large-cased |\n", "| 29 | distilbert-base-uncased |\n", "| 30 | distilbert-base-uncased-distilled-squad |\n", "| 31 | distilbert-base-german-cased |\n", "| 32 | distilbert-base-multilingual-cased |\n", "| 33 | albert-base-v1 |\n", "| 34 | albert-large-v1 |\n", "| 35 | albert-xlarge-v1 |\n", "| 36 | albert-xxlarge-v1 |\n", "| 37 | albert-base-v2 |\n", "| 38 | albert-large-v2 |\n", "| 39 | albert-xlarge-v2 |\n", "| 40 | albert-xxlarge-v2 |\n", "
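The checkpoint names above can be paired with the notebook's parameters. Below is a minimal sketch of how `USE_DATASET` and `QUICK_RUN` might drive checkpoint selection and the training configuration; the `DATASET_MODEL` mapping, the `get_config` helper, and all parameter values are illustrative assumptions, not part of `utils_nlp`.

```python
# Hypothetical helper showing how QUICK_RUN and USE_DATASET could select a
# pretrained checkpoint and a training configuration. Values are illustrative.

# Multilingual BERT covers Arabic and Hindi; English can use the base model.
DATASET_MODEL = {
    "multinli": "bert-base-uncased",
    "dac": "bert-base-multilingual-cased",
    "bbc-hindi": "bert-base-multilingual-cased",
}

def get_config(use_dataset, quick_run):
    """Return an illustrative training configuration dict."""
    if use_dataset not in DATASET_MODEL:
        raise ValueError(f"unknown dataset: {use_dataset}")
    return {
        "model_name": DATASET_MODEL[use_dataset],
        # A quick run trains on a small data subset for a single epoch;
        # a full run uses all the data and more epochs.
        "train_sample_ratio": 0.1 if quick_run else 1.0,
        "num_epochs": 1 if quick_run else 3,
        # Reduce batch_size and max_len if you hit CUDA out-of-memory errors.
        "batch_size": 16,
        "max_len": 150,
    }

config = get_config("dac", quick_run=True)
print(config["model_name"])  # bert-base-multilingual-cased
```

In the notebook itself, the selected `model_name` would then be passed to `SequenceClassifier` for fine-tuning on the chosen dataset.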