Multitask training with disjoint datasets

In this tutorial, we will jointly train a classification task with a language modeling task in a multitask setting. The models will share the embedding and representation layers.

We will use the following datasets:

  1. Binarized Stanford Sentiment Treebank (SST-2), which is part of the GLUE benchmark. This dataset contains segments from movie reviews labeled with their binary sentiment.
  2. WikiText-2, a medium-size language modeling dataset with text extracted from Wikipedia.

1. Fetch and prepare the datasets

Download the datasets into a local directory. We will refer to this directory as base_dir in the following sections.

$ curl "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip" -o wikitext-2-v1.zip
$ unzip wikitext-2-v1.zip
$ curl "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8" -o SST-2.zip
$ unzip SST-2.zip

Remove headers from SST-2 data:

$ cd base_dir/SST-2
$ sed -i '1d' train.tsv
$ sed -i '1d' dev.tsv

Remove empty lines from WikiText:

$ cd base_dir/wikitext-2
$ sed -i '/^\s*$/d' wiki.train.tokens
$ sed -i '/^\s*$/d' wiki.valid.tokens
$ sed -i '/^\s*$/d' wiki.test.tokens
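The same cleanup can be done without GNU sed, for example on systems where sed -i behaves differently. The following is a minimal sketch, not part of PyText; the function names drop_header and drop_blank_lines are made up for illustration, and both functions read the whole file into memory, which is fine for datasets of this size.

```python
def drop_header(path):
    # Equivalent of `sed -i '1d'`: rewrite the file without its first line.
    with open(path, encoding="utf-8") as f:
        lines = f.readlines()
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines[1:])


def drop_blank_lines(path):
    # Equivalent of `sed -i '/^\s*$/d'`: drop empty or whitespace-only lines.
    with open(path, encoding="utf-8") as f:
        lines = f.readlines()
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(line for line in lines if line.strip())
```

For example, drop_header("base_dir/SST-2/train.tsv") removes the header row of the SST-2 training file, and drop_blank_lines can be called once per WikiText file.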

2. Train a base model

Prepare the configuration file for training. A sample config file for the base document classification model can be found in your PyText repository at demo/configs/sst2.json. If you haven’t set up PyText yet, follow the Installation instructions first. Then make the following changes in the config:

  • Set train_path to base_dir/SST-2/train.tsv.
  • Set eval_path to base_dir/SST-2/dev.tsv.
  • Set test_path to base_dir/SST-2/dev.tsv.
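The path changes above can also be scripted. The sketch below is an illustration only: it assumes that train_path, eval_path and test_path appear as plain keys somewhere inside the nested JSON config, and it overwrites them wherever they occur. The helper names set_keys, sst2_paths and patch_config are hypothetical. Both eval_path and test_path are pointed at dev.tsv, on the assumption that dev.tsv is the only labeled held-out file available locally.

```python
import json


def set_keys(node, updates):
    # Walk a nested JSON structure and overwrite every key listed in `updates`.
    if isinstance(node, dict):
        for key, value in node.items():
            if key in updates:
                node[key] = updates[key]
            else:
                set_keys(value, updates)
    elif isinstance(node, list):
        for item in node:
            set_keys(item, updates)
    return node


def sst2_paths(base_dir):
    # dev.tsv is used for both evaluation and testing (assumption, see lead-in).
    return {
        "train_path": f"{base_dir}/SST-2/train.tsv",
        "eval_path": f"{base_dir}/SST-2/dev.tsv",
        "test_path": f"{base_dir}/SST-2/dev.tsv",
    }


def patch_config(src, dst, base_dir):
    # Read the sample config, patch the data paths, and write a new file.
    with open(src, encoding="utf-8") as f:
        config = json.load(f)
    set_keys(config, sst2_paths(base_dir))
    with open(dst, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
```

Writing the result to a new file, rather than editing demo/configs/sst2.json in place, keeps the sample config available for comparison.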

The test set labels for this task are not openly available, so we will use the dev set instead. Train the model using the command below.

(pytext) $ pytext train < demo/configs/sst2.json

The output will look like:

Stage.EVAL
loss: 0.472868
Accuracy: 85.67

3. Configure for multitasking

The example configuration for this tutorial is at demo/configs/multitask_sst_lm.json. The main configuration is under tasks, a dictionary that maps each task name to its task config:

"tasks": {
  "SST2": {
    "DocClassificationTask": { ... }
  },
  "LM": {
    "LMTask": { ... }
  }
}

The sub-tasks can be configured as you would in a single task setting, with the exception of changes described in the next sections.

Also configure epoch_size under the parent task’s data handler:

"data_handler": {
  "epoch_size": 2000
}
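Because the two datasets are disjoint and differ in size, there is no single point at which both are exhausted, so the length of an epoch has to be stated explicitly. The sketch below illustrates the idea and is not PyText's implementation: it assumes that epoch_size counts batches, that tasks alternate in round-robin order, and that a task whose data runs out simply restarts. The function name round_robin_epoch is hypothetical.

```python
from itertools import cycle, islice


def round_robin_epoch(task_batches, epoch_size):
    # task_batches maps a task name to a list of batches for that task.
    # Each task's batches are cycled, so a smaller dataset is revisited
    # while a larger one is still being consumed.
    iterators = {name: cycle(batches) for name, batches in task_batches.items()}
    # Alternate over task names and stop after exactly epoch_size steps.
    for name in islice(cycle(task_batches), epoch_size):
        yield name, next(iterators[name])
```

With two SST2 batches, three LM batches and epoch_size set to 5, the generator yields SST2, LM, SST2, LM, SST2, restarting the SST2 data on the fifth step.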

4. Specify which parameters to share

Parameter sharing is specified at module level with the shared_module_key parameter, which is an arbitrary string. Modules with identical shared_module_key share parameters.

Here we will share the BiLSTM module. Under the SST2 task, we set

"representation": {
  "BiLSTMDocAttention": {
    "lstm": {
      "shared_module_key": "SHARED_LSTM"
    }
  }
}

Under the LM task, we set

"representation": {
  "shared_module_key": "SHARED_LSTM"
},

In this case, BiLSTMDocAttention.lstm of DocClassificationTask and representation of LMTask are both of type BiLSTM, therefore parameter sharing is possible.
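One way to picture shared_module_key is as a lookup into a registry of already-built modules. The following is a conceptual sketch only and does not reproduce PyText's code: SharedModuleRegistry, BiLSTM and WordEmbedding are hypothetical stand-in classes with no real parameters. It shows the two properties described above: modules created with the same key are the same object, and a key can only be reused by a module of the same type.

```python
class BiLSTM:
    # Stand-in for a real BiLSTM module (hypothetical, holds no parameters).
    pass


class WordEmbedding:
    # Stand-in for a real embedding module (hypothetical).
    pass


class SharedModuleRegistry:
    def __init__(self):
        self._modules = {}

    def create(self, module_cls, shared_module_key=None):
        # Without a key, every call builds a fresh, unshared module.
        if shared_module_key is None:
            return module_cls()
        existing = self._modules.get(shared_module_key)
        if existing is None:
            # First use of this key: build the module and remember it.
            existing = self._modules[shared_module_key] = module_cls()
        elif type(existing) is not module_cls:
            # Sharing only makes sense between modules of the same type.
            raise TypeError(
                f"{shared_module_key!r} already holds a {type(existing).__name__}"
            )
        return existing
```

In this picture, the SST2 task and the LM task both call create(BiLSTM, "SHARED_LSTM") and receive the same object, so gradients from both tasks update the same parameters.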

5. Share the embedding layer

The embedding is also a module, and can be similarly shared. This is configured under the features section. However, we need to ensure that we use the same vocabulary for both tasks, by specifying a pre-built vocabulary file. First create the vocabulary from the classification task data:

$ cd base_dir/SST-2
$ cat train.tsv dev.tsv | tr ' ' '\n' | sort | uniq > sst_vocab.txt
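The vocabulary can also be built in Python. This is a minimal sketch under one assumption: each line has the form sentence, tab, label, with the sentence in the first column. It keeps only that column, so label values do not end up in the vocabulary. The function name build_vocab and the text_column parameter are made up for illustration.

```python
def build_vocab(paths, text_column=0):
    # Collect the unique whitespace-separated tokens of the text column.
    vocab = set()
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line in f:
                columns = line.rstrip("\n").split("\t")
                vocab.update(columns[text_column].split())
    return sorted(vocab)


def write_vocab(tokens, path):
    # One token per line, the same layout the shell pipeline produces.
    with open(path, "w", encoding="utf-8") as f:
        for token in tokens:
            f.write(token + "\n")
```

For example, write_vocab(build_vocab(["train.tsv", "dev.tsv"]), "sst_vocab.txt") run from base_dir/SST-2 produces the vocabulary file. The token order can differ from that of sort, which depends on the locale.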

Then point to this file in configuration:

"features": {
    "shared_module_key": "SHARED_EMBEDDING",
    "word_feat": {
      "vocab_file": "base_dir/SST-2/sst_vocab.txt",
      "vocab_size": 15000,
      "vocab_from_train_data": false
    }
}

6. Train the model

You can train the model with

(pytext) $ pytext train < demo/configs/multitask_sst_lm.json

The output will look like:

Stage.EVAL
loss: 0.455871
Accuracy: 86.12

The improvement is modest (86.12 versus 85.67 accuracy for the base model), but we used a very primitive language modeling task (bidirectional, with no masking) for the purposes of this tutorial. Happy multitasking!