ULMFiT

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline
In [2]:
from lgm import *

Data

In [3]:
path = untar_data(URLs.IMDB)
In [4]:
# path
In [5]:
import pickle
labeled_list = pickle.load(open(path/'labeled_list_lm.pkl', 'rb'))
In [6]:
batch_size = 64
bptt = 70
data = lm_databunchify(labeled_list, batch_size, bptt)
In [7]:
vocab = labeled_list.train.processor_x[1].vocab
In [8]:
len(vocab)
Out[8]:
60006
In [9]:
print(vocab[:100])
['_UNK_', '_PAD_', '_BOS_', '_EOS_', '_REP_', '_WREP_', '_ALLCAPS_', '_CAP_', 'the', '.', ',', 'and', 'a', 'of', 'to', 'is', 'it', 'in', 'i', 'this', 'that', '"', "'s", '-', '\n\n', 'was', 'as', 'with', 'for', 'movie', 'but', 'film', 'you', ')', 'on', "n't", '(', 'not', 'are', 'he', 'his', 'have', 'one', 'be', 'all', 'at', 'they', 'by', 'an', 'who', '!', 'from', 'so', 'like', 'there', 'or', 'her', 'just', 'do', 'about', 'has', 'out', "'", 'if', 'what', 'some', '?', 'good', 'when', 'more', 'very', 'she', 'up', 'would', 'no', 'time', 'even', 'my', 'can', 'their', 'which', 'only', 'story', 'really', 'see', 'had', 'were', 'did', 'well', 'me', 'we', 'does', '...', 'than', 'much', 'could', ':', 'bad', 'been', 'get']

Finetuning the LM

Before tackling the classification task, we have to finetune our language model on the IMDB corpus.

We start with a model pretrained on Wikipedia (training done on Google Colab) and its associated Wikipedia-based vocabulary, both of which can be downloaded here:

In [10]:
# !wget https://abrsvn.github.io/files/pretrained.pth -P {path}
# !wget https://abrsvn.github.io/files/vocab.pkl -P {path}
In [11]:
dropout_probs = tensor([0.1, 0.15, 0.25, 0.02, 0.2]) * 0.5
# print(dropout_probs)
tok_pad = vocab.index(PAD)
emb_dim = 300
hidden_dim = 300
n_layers = 2

model = get_language_model(len(vocab), emb_dim, hidden_dim, n_layers, tok_pad, *dropout_probs)

old_wgts  = torch.load(path/'pretrained.pth')
old_vocab = pickle.load(open(path/'vocab.pkl', 'rb'))

The ids in the IMDB-based vocabulary are very unlikely to correspond to those in the Wikipedia-based vocabulary we used to pretrain the model:

  • in each vocabulary, the tokens are sorted by their frequency in that corpus (apart from the special tokens, which were prepended to the list), so the same word usually ends up with a different id.
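
A toy sketch of why the ids diverge, assuming a vocab built as described above (special tokens first, then words by descending frequency). The two "corpora" and the shortened SPECIAL list are made up for illustration; the real vocab is built by lgm's processors, whose tie-breaking we haven't checked.

```python
from collections import Counter

# Hypothetical, shortened list of special tokens, prepended to every vocab
SPECIAL = ['_UNK_', '_PAD_', '_BOS_', '_EOS_']

def build_vocab(tokens):
    # most_common() sorts by descending count; ties keep first-seen order
    counts = Counter(t for t in tokens if t not in SPECIAL)
    return SPECIAL + [w for w, _ in counts.most_common()]

# Two made-up "corpora" with different word frequencies
wiki_tokens = "the house of the king is the house of the state".split()
imdb_tokens = "the movie was bad but the house in the movie was good".split()

wiki_vocab, imdb_vocab = build_vocab(wiki_tokens), build_vocab(imdb_tokens)
print(wiki_vocab.index('house'), imdb_vocab.index('house'))  # 5 9
print(imdb_vocab[5])                                         # movie
```

If we loaded the pretrained weights unchanged, the row learned for 'house' (id 5) would be used for 'movie' in this toy IMDB vocab.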

Let's look at the word 'house':

In [12]:
house_idx_imdb = vocab.index('house')
print("index of 'house' in the IMDB vocab:", house_idx_imdb)
house_idx_wikipedia = old_vocab.index('house')
print("index of 'house' in the Wikipedia vocab:", house_idx_wikipedia)
index of 'house' in the IMDB vocab: 349
index of 'house' in the Wikipedia vocab: 232

We need to match our pretrained model weights to the new vocab:

  • we do this by rearranging the word embeddings (which are tied to the decoder weights) and the decoder biases, since both are indexed by vocabulary id;
  • the IMDB vocab may also contain words that are not in the pretrained Wikipedia vocab; we assign those words the mean of the pretrained embeddings (and the mean decoder bias).
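
In PyTorch, weight tying typically means the decoder reuses the embedding's Parameter object, while the decoder bias stays a separate vector with one entry per vocab id. The sketch below uses toy sizes (not the notebook's model) to show this, and also that state_dict lists the shared matrix under both module names, which is presumably why old_wgts has both '0.emb.weight' and '1.decoder.weight'.

```python
import torch
import torch.nn as nn

vocab_size, dim = 6, 3                    # toy sizes
emb = nn.Embedding(vocab_size, dim)
dec = nn.Linear(dim, vocab_size)
dec.weight = emb.weight                   # tie: the decoder reuses the embedding Parameter
toy_lm = nn.Sequential(emb, dec)

sd = toy_lm.state_dict()
print(dec.weight is emb.weight)           # True
print(sorted(sd.keys()))                  # ['0.weight', '1.bias', '1.weight']
print(tuple(sd['1.bias'].shape))          # (6,): one untied bias per vocab id
```

Both weight entries hold the same values, so remapping the embedding rows and writing the result under both keys keeps them tied, while the bias needs its own remapping.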
In [13]:
old_wgts.keys()
Out[13]:
odict_keys(['0.emb.weight', '0.emb_dropout.emb.weight', '0.rnns.0.weight_hh_l0_raw', '0.rnns.0.inner_module.weight_ih_l0', '0.rnns.0.inner_module.weight_hh_l0', '0.rnns.0.inner_module.bias_ih_l0', '0.rnns.0.inner_module.bias_hh_l0', '0.rnns.1.weight_hh_l0_raw', '0.rnns.1.inner_module.weight_ih_l0', '0.rnns.1.inner_module.weight_hh_l0', '0.rnns.1.inner_module.bias_ih_l0', '0.rnns.1.inner_module.bias_hh_l0', '1.decoder.weight', '1.decoder.bias'])
In [14]:
house_wgt  = old_wgts['0.emb.weight'][house_idx_wikipedia]
house_bias = old_wgts['1.decoder.bias'][house_idx_wikipedia]
In [15]:
# def match_embeds(old_wgts, old_vocab, new_vocab):
#     """
#     Matches embeddings from an old_vocab to a new_vocab when transfer learning:
#     -- old_vocab is the vocab associated with the pretrained model
#     -- new_vocab is the vocab associated with the new corpus
#     -- old_wgts are the weights from the old pretrained model (a state dict)
#     We end up with embeddings for the new_vocab that are the same as the old
#     ones whenever an item is both in the new_vocab and in the old_vocab. When an
#     item in the new_vocab is missing from the old_vocab, it is assigned an
#     average embedding.
#     The old_wgts are updated with respect to the relevant layers. The parameters
#     of the other layers are kept the same. The updated old_wgts are returned in
#     full so that they can be loaded into the new model.
#     """
#     wgts = old_wgts['0.emb.weight']
#     bias = old_wgts['1.decoder.bias']
#     # compute mean weights; we'll assign them to new vocab items
#     wgts_m, bias_m = wgts.mean(dim=0), bias.mean()
#     # initialize new weights
#     new_wgts = wgts.new_zeros(len(new_vocab), wgts.size(1))
#     new_bias = bias.new_zeros(len(new_vocab))
#     # reverse old vocab so that we can index into the old weights
#     otoi = {v:k for k,v in enumerate(old_vocab)}
#     # we check every item in the new vocab
#     for i,w in enumerate(new_vocab):
#         # if the item is in the old_vocab, we transfer the old weights
#         if w in otoi:
#             idx = otoi[w]
#             new_wgts[i], new_bias[i] = wgts[idx], bias[idx]
#         # if the item is not in the old_vocab, we give it average weights
#         else: new_wgts[i], new_bias[i] = wgts_m, bias_m
#     old_wgts['0.emb.weight']        = new_wgts
#     old_wgts['0.emb_dropout.emb.weight'] = new_wgts
#     old_wgts['1.decoder.weight']    = new_wgts
#     old_wgts['1.decoder.bias']      = new_bias
#     return old_wgts
In [16]:
wgts = match_embeds(old_wgts, old_vocab, vocab)

Now let's check that the word "house" was properly converted.

In [17]:
test_near(wgts['0.emb.weight'][house_idx_imdb], house_wgt)
test_near(wgts['1.decoder.bias'][house_idx_imdb], house_bias)

We can now load the pretrained weights into our model before we begin training.

In [18]:
model.load_state_dict(wgts)
Out[18]:
<All keys matched successfully>

If we want to apply discriminative learning rates, we need to split our model into different layer groups. Let's look at our model:

In [19]:
model
Out[19]:
SequentialRNN(
  (0): AWD_LSTM(
    (emb): Embedding(60006, 300, padding_idx=1)
    (emb_dropout): EmbeddingDropout(
      (emb): Embedding(60006, 300, padding_idx=1)
    )
    (rnns): ModuleList(
      (0): WeightDropout(
        (inner_module): LSTM(300, 300, batch_first=True)
      )
      (1): WeightDropout(
        (inner_module): LSTM(300, 300, batch_first=True)
      )
    )
    (input_dropout): RNNDropout()
    (hidden_dropouts): ModuleList(
      (0): RNNDropout()
      (1): RNNDropout()
    )
  )
  (1): LinearDecoder(
    (output_dropout): RNNDropout()
    (decoder): Linear(in_features=300, out_features=60006, bias=True)
  )
)
In [20]:
# def lm_splitter(model):
#     """
#     Splits the language model provided by the get_language_model into multiple
#     param groups to do transfer learning (e.g., from Wikipedia to IMDB):
#     -- we have one group for each rnn + corresponding dropout, for a
#     total of 2 if we had n_layers = 2 in the get_language_model call;
#     -- we have one final group that contains the embeddings/decoder.
#     The final group needs to be trained the most (new embedding vectors).
#     """
#     groups = []
#     for i in range(len(model[0].rnns)):
#         groups.append(nn.Sequential(model[0].rnns[i], model[0].hidden_dropouts[i]))
#     groups += [nn.Sequential(model[0].emb, model[0].emb_dropout, model[0].input_dropout, model[1])]
#     return [list(group.parameters()) for group in groups]

First we train with the RNNs frozen:

In [21]:
for rnn in model[0].rnns:
    for param in rnn.parameters(): param.requires_grad_(False)
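
Freezing the RNNs does not stop gradients from flowing through them: the embeddings and the decoder still get updated, which is what we want while the new IMDB embedding vectors settle in. A minimal sketch with a toy embedding, LSTM and decoder (arbitrary sizes, not the notebook's model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in for the language model: embedding -> LSTM -> decoder
emb = nn.Embedding(10, 4)
rnn = nn.LSTM(4, 4, batch_first=True)
dec = nn.Linear(4, 10)

# Freeze the RNN only, as in the cell above
for param in rnn.parameters():
    param.requires_grad_(False)

n_trainable = sum(p.numel() for m in (emb, rnn, dec)
                  for p in m.parameters() if p.requires_grad)
print(n_trainable)  # 90 = 10*4 (embedding) + 4*10 + 10 (decoder)

toks = torch.randint(0, 10, (2, 5))
hidden_states, _ = rnn(emb(toks))
dec(hidden_states).sum().backward()
print(all(p.grad is None for p in rnn.parameters()))      # True: frozen RNN gets no gradient
print(all(p.grad is not None for p in emb.parameters()))  # True: gradient still reaches the embedding
```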
In [22]:
callback_funcs = [partial(GradientClipping, clip=0.1),
                  partial(RNNTrainer, alpha=2., beta=1.)]
In [23]:
learn = Learner(model, data, cross_entropy_flat, adam_opt(),
                metrics=accuracy_flat, callback_funcs=callback_funcs,
                splitter=lm_splitter)
In [24]:
learn.fit(1, callbacks=LRFinder())
In [25]:
lr = 2e-2
callback_sched = sched_1cycle([lr], pct_start=0.5, mom_start=0.8, mom_mid=0.7, mom_end=0.8)
In [26]:
# learn.fit(1, callbacks=callback_sched)
In [27]:
# torch.save(learn.model.state_dict(), path/'finetuned_top_layer.pth')

We then train the whole model with discriminative learning rates:

In [28]:
learn.model.load_state_dict(torch.load(path/'finetuned_top_layer.pth'))
Out[28]:
<All keys matched successfully>
In [29]:
for rnn in model[0].rnns:
    for param in rnn.parameters(): param.requires_grad_(True)
In [30]:
learn.fit(1, callbacks=LRFinder())
In [31]:
lr = 5e-3
# note: we have 3 learning rates because we have 3 layer groups: the 2 RNN+dropout groups and the embedding/decoder group
callback_sched = sched_1cycle([lr/2., lr/2., lr], pct_start=0.5,
                              mom_start=0.8, mom_mid=0.7, mom_end=0.8)
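
At the optimizer level, discriminative learning rates come down to one param group per layer group, each with its own learning rate. A minimal sketch in plain PyTorch, using three hypothetical stand-in modules in place of lm_splitter's groups; we haven't checked how lgm's Learner and sched_1cycle wire this up internally, so treat it as the underlying mechanism only.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the three layer groups lm_splitter returns:
# RNN layer 0, RNN layer 1, and the embedding/decoder group
layer_groups = [nn.LSTM(4, 4), nn.LSTM(4, 4), nn.Linear(4, 8)]
base_lr = 5e-3
group_lrs = [base_lr / 2., base_lr / 2., base_lr]

# One optimizer param group per layer group, each with its own learning rate
opt = torch.optim.Adam([{'params': g.parameters(), 'lr': group_lr}
                        for g, group_lr in zip(layer_groups, group_lrs)])
print([pg['lr'] for pg in opt.param_groups])  # [0.0025, 0.0025, 0.005]
```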
In [32]:
# learn.fit(10, callbacks=callback_sched)
epoch  train_loss  train_accuracy_flat  valid_loss  valid_accuracy_flat  time
0      4.483242    0.246619             4.390373    0.257883             09:03
1      4.396963    0.254965             4.334490    0.263652             09:04
2      4.333394    0.260960             4.292365    0.267467             09:04
3      4.285608    0.265108             4.264836    0.269886             09:04
4      4.246603    0.268383             4.238844    0.271598             09:04
5      4.211725    0.271098             4.213812    0.274362             09:04
6      4.174984    0.274097             4.190312    0.276666             09:05
7      4.138976    0.277254             4.166958    0.279480             09:04
8      4.109499    0.279853             4.155312    0.280798             09:02
9      4.092328    0.281279             4.152958    0.281090             09:04

For the classification task, we only need to save:

  • the encoder (the first part of the model, i.e., the embeddings and the RNNs);
  • the vocabulary we used.

We need to use the same vocab, and we don't need the top layer (the decoder), since it will be replaced by a layer that does binary sentiment classification.

In [33]:
# torch.save(learn.model[0].state_dict(), path/'finetuned_enc.pth')
In [34]:
# pickle.dump(vocab, open(path/'vocab_lm.pkl', 'wb'))

But we also save the full model just in case:

In [35]:
# torch.save(learn.model.state_dict(), path/'finetuned.pth')

Getting the data ready for the classifier

We have to reprocess the data for classification because the classifier must use the same vocab (i.e., the same token-to-id mapping) as the finetuned language model.
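
To see why the mapping must be shared, here is a hypothetical numericalization step that maps tokens to ids with a fixed token-to-id dictionary and sends unknown tokens to the _UNK_ id (0, as in the vocab printed above). It is a sketch with a made-up toy vocab, not lgm's NumericalizeProcessor.

```python
# Toy vocab (made up); index 0 is _UNK_, as in the real vocab
toy_vocab = ['_UNK_', '_PAD_', 'the', 'movie', 'house']
stoi = {w: i for i, w in enumerate(toy_vocab)}

def numericalize(tokens, stoi, unk_idx=0):
    # Hypothetical helper: unknown tokens fall back to the _UNK_ id
    return [stoi.get(t, unk_idx) for t in tokens]

print(numericalize(['the', 'house', 'stargate'], stoi))  # [2, 4, 0]
```

With a different vocab, the same ids would point at different rows of the finetuned embedding matrix.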

In [36]:
# vocab = pickle.load(open(path/'vocab_lm.pkl', 'rb'))

# proc_tok = TokenizeProcessor()
# proc_num = NumericalizeProcessor(vocab=vocab) # this is where we use the language-model vocab we saved
# proc_cat = CategoryProcessor()
In [37]:
# textlist = TextList.from_files(path, include=['train', 'test'])
# splitdata = SplitData.split_by_func(textlist, partial(grandparent_splitter, valid_name='test'))
# labeled_list = label_by_func(splitdata, parent_labeler, processor_x = [proc_tok, proc_num], processor_y=proc_cat)
In [38]:
# pickle.dump(labeled_list, open(path/'labeled_list_clas.pkl', 'wb'))
In [39]:
labeled_list = pickle.load(open(path/'labeled_list_clas.pkl', 'rb'))
vocab = pickle.load(open(path/'vocab_lm.pkl', 'rb'))
In [40]:
batch_size = 64
bptt = 70
data = clas_databunchify(labeled_list, batch_size)

Ignore padding

Recall that for classification, we need to feed in batches of documents that are padded (at the end) so that the batch can be fit into a tensor.

Computing on the padding is just a waste. Worse, since the padding comes after the review, the information that is useful for classification (the actual movie review) recedes further and further into the past with each pad token and gets weaker and weaker, even with LSTMs.

We use two PyTorch utility functions to ignore the padding in the inputs.

In [41]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

Let's see how this works. We grab a batch from the training set; actually, we grab the second batch, since the first one has reviews of widely varying lengths (it includes the longest review).

In [42]:
test_iter = iter(data.train_dl)
x, y = next(test_iter) # first batch
x, y = next(test_iter) # second batch

These are the reviews, i.e., the predictors:

In [43]:
x.size()
Out[43]:
torch.Size([64, 185])

Here is the second review in the batch, converted back to tokens (since the batch is only 185 tokens wide, x[1][:200] covers the whole row):

In [44]:
' '.join(vocab[idx] for idx in x[1][:200])
Out[44]:
"_BOS_ _CAP_ stargate is the best show ever . _CAP_ all the actors are absolutely perfect for there roles . i love the connection between the characters . _CAP_ if you have not seen this show i very highly recommend it . _CAP_ although this program is compared to _CAP_ star trek a lot of the time it actually ca n't be because it is completely different . i am a star trek fan but i would definitely rate this show well above any of the star treks . _CAP_ unfortunately i live in _CAP_ new _CAP_ zealand and we do not get _CAP_ stargate on our tv so if i want to see it i have to buy the dvds and season 10 is not out here yet so i can not see it for quite some time ( which is highly depressing ) . _CAP_ however this program is very very good and is a must see , but be warned it is highly addictive . _CAP_ so in summery i _CAP_ love _CAP_ stargate ( and _CAP_ amanda _CAP_ tapping ) ."

These are the labels (positive / negative sentiment), i.e., the response:

In [45]:
y.size()
Out[45]:
torch.Size([64])
In [46]:
y
Out[46]:
tensor([1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1,
        1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
        0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1])

The utility functions need the length of each review. They are applied after the embedding layer, where we can no longer tell which positions are padding, so we compute the lengths from the token ids beforehand (the pad id is 1):

In [47]:
lengths = x.size(1) - (x == 1).sum(1)
lengths[:5]
Out[47]:
tensor([185, 185, 185, 185, 185])
In [48]:
len(vocab)
Out[48]:
60006
In [49]:
test_emb = nn.Embedding(len(vocab), 300)
In [50]:
test_emb(x).shape
Out[50]:
torch.Size([64, 185, 300])

We create a PackedSequence object that contains all of our unpadded sequences:

In [51]:
packed = pack_padded_sequence(test_emb(x), lengths, batch_first=True)
In [52]:
packed
Out[52]:
PackedSequence(data=tensor([[ 1.8622, -0.4052, -0.6893,  ..., -0.4676,  0.6183,  0.5811],
        [ 1.8622, -0.4052, -0.6893,  ..., -0.4676,  0.6183,  0.5811],
        [ 1.8622, -0.4052, -0.6893,  ..., -0.4676,  0.6183,  0.5811],
        ...,
        [-1.8628,  0.8184,  0.7891,  ...,  0.1025, -1.0998,  0.5846],
        [ 1.2274,  1.2710,  1.5473,  ...,  3.4062,  0.4924,  0.0693],
        [-1.8628,  0.8184,  0.7891,  ...,  0.1025, -1.0998,  0.5846]],
       grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 54, 43, 29, 15]), sorted_indices=None, unsorted_indices=None)
In [53]:
packed.data.shape
Out[53]:
torch.Size([11725, 300])
In [54]:
len(packed.batch_sizes)
Out[54]:
185
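
The batch_sizes tensor and the shape of packed.data follow directly from the review lengths. This pure-Python sketch (not how PyTorch computes them internally) reproduces the numbers above from the per-review lengths of this batch, which the unpacked lengths further down confirm:

```python
# Per-review lengths in this batch: 15 reviews of 185 tokens, 14 of 184,
# 14 of 183, 11 of 182 and 10 of 181
review_lengths = [185] * 15 + [184] * 14 + [183] * 14 + [182] * 11 + [181] * 10

def packed_batch_sizes(lengths):
    # batch_sizes[t] = number of sequences still running at time step t,
    # i.e. with length > t (the sequences are sorted by decreasing length)
    return [sum(1 for n in lengths if n > t) for t in range(max(lengths))]

batch_sizes = packed_batch_sizes(review_lengths)
print(len(batch_sizes))   # 185 time steps
print(batch_sizes[-5:])   # [64, 54, 43, 29, 15]
print(sum(batch_sizes))   # 11725 rows in packed.data
```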

This object can be passed directly to any PyTorch RNN module (nn.RNN, nn.LSTM, nn.GRU) while retaining the speed of cuDNN.

In [55]:
test = nn.LSTM(300, 300, 2)
In [56]:
y, h = test(packed)
In [57]:
y
Out[57]:
PackedSequence(data=tensor([[-0.0064, -0.0219, -0.0363,  ...,  0.0400, -0.0412, -0.0229],
        [-0.0064, -0.0219, -0.0363,  ...,  0.0400, -0.0412, -0.0229],
        [-0.0064, -0.0219, -0.0363,  ...,  0.0400, -0.0412, -0.0229],
        ...,
        [-0.0101,  0.0592,  0.0369,  ...,  0.0360, -0.0305,  0.0359],
        [ 0.0020, -0.0373, -0.0216,  ..., -0.0100, -0.0465, -0.0121],
        [-0.0060,  0.0568,  0.0263,  ...,  0.0531,  0.0021,  0.0256]],
       grad_fn=<CatBackward>), batch_sizes=tensor([64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
        64, 54, 43, 29, 15]), sorted_indices=None, unsorted_indices=None)

We can then turn the output back into a regular padded tensor for other modules with pad_packed_sequence, which returns the padded tensor together with the lengths:

In [58]:
unpack = pad_packed_sequence(y, batch_first=True)
In [59]:
unpack[0].shape
Out[59]:
torch.Size([64, 185, 300])
In [60]:
unpack[1]
Out[60]:
tensor([185, 185, 185, 185, 185, 185, 185, 185, 185, 185, 185, 185, 185, 185,
        185, 184, 184, 184, 184, 184, 184, 184, 184, 184, 184, 184, 184, 184,
        184, 183, 183, 183, 183, 183, 183, 183, 183, 183, 183, 183, 183, 183,
        183, 182, 182, 182, 182, 182, 182, 182, 182, 182, 182, 182, 181, 181,
        181, 181, 181, 181, 181, 181, 181, 181])
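
The whole pack, run, unpack round trip fits in a few lines on a toy batch (made-up token ids and sizes, pad id 1 as in the notebook). It also shows that the final hidden state returned for a packed input sits at each sequence's last real step, not at the last column:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
pad_id = 1
# Three toy "reviews", padded at the end and sorted by decreasing length
toks = torch.tensor([[5, 6, 7],
                     [8, 9, pad_id],
                     [4, pad_id, pad_id]])
toy_lengths = toks.size(1) - (toks == pad_id).sum(1)   # tensor([3, 2, 1])

toy_emb = nn.Embedding(10, 4, padding_idx=pad_id)
toy_lstm = nn.LSTM(4, 5, batch_first=True)

packed_toy = pack_padded_sequence(toy_emb(toks), toy_lengths, batch_first=True)
packed_out, (h_n, c_n) = toy_lstm(packed_toy)
padded_out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

print(packed_toy.batch_sizes)   # tensor([3, 2, 1])
print(padded_out.shape)         # torch.Size([3, 3, 5])
print(out_lengths)              # tensor([3, 2, 1])
# h_n holds each sequence's state at its last real step
print(torch.allclose(h_n[-1], padded_out[torch.arange(3), toy_lengths - 1]))  # True
```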

We need to modify our AWD-LSTM encoder slightly to use this.

In [61]:
# class AWD_LSTM1(nn.Module):
#     """
#     AWD-LSTM inspired by https://arxiv.org/abs/1708.02182,
#     updated to deal with pad_packed_sequence and pack_padded_sequence.
#     """
#     initrange=0.1

#     def __init__(self, vocab_size, emb_size, n_hid, n_layers, pad_token,
#                  hidden_prob=0.2, input_prob=0.6, embed_prob=0.1, weight_prob=0.5):
#         super().__init__()
#         self.batch_size = 1
#         self.emb_size = emb_size
#         self.n_hid = n_hid
#         self.n_layers = n_layers
#         self.pad_token = pad_token
#         self.emb = nn.Embedding(vocab_size, emb_size, padding_idx=pad_token)
#         self.emb_dropout = EmbeddingDropout(self.emb, embed_prob)
#         # we create n_layers of LSTMs
#         self.rnns = [nn.LSTM(emb_size if l == 0 else n_hid,
#                              (n_hid if l != n_layers - 1 else emb_size),
#                              1, batch_first=True)
#                      for l in range(n_layers)]
#         self.rnns = nn.ModuleList([WeightDropout(rnn, weight_prob)
#                                    for rnn in self.rnns])
#         self.emb.weight.data.uniform_(-self.initrange, self.initrange)
#         self.input_dropout = RNNDropout(input_prob)
#         self.hidden_dropouts = nn.ModuleList([RNNDropout(hidden_prob)
#                                               for l in range(n_layers)])

#     def forward(self, input):
#         batch_size, seq_len = input.size()
#         if batch_size != self.batch_size:
#             self.batch_size = batch_size
#             self.reset()
#         mask = (input == self.pad_token)
#         lengths = seq_len - mask.long().sum(1)
#         n_empty = (lengths == 0).sum()
#         if n_empty > 0:
#             input = input[:-n_empty]
#             lengths = lengths[:-n_empty]
#             self.hidden = [(h[0][:, :input.size(0)], h[1][:, :input.size(0)])
#                            for h in self.hidden]
#         raw_output = self.input_dropout(self.emb_dropout(input))
#         new_hidden, raw_outputs, outputs = [], [], []
#         for l, (rnn, hid_dropout) in enumerate(zip(self.rnns, self.hidden_dropouts)):
#             # take data of different lengths and shape it to pass to RNN
#             raw_output = pack_padded_sequence(raw_output, lengths, batch_first=True)
#             raw_output, new_h = rnn(raw_output, self.hidden[l])
#             # this is where the padding actually happens
#             raw_output = pad_packed_sequence(raw_output, batch_first=True)[0]
#             raw_outputs.append(raw_output)
#             # we do hidden dropout for all layers but the last one
#             if l != self.n_layers - 1: raw_output = hid_dropout(raw_output)
#             outputs.append(raw_output)
#             new_hidden.append(new_h)
#         self.hidden = to_detach(new_hidden)
#         return raw_outputs, outputs, mask

#     def _one_hidden(self, l):
#         "Return one hidden state."
#         nh = self.n_hid if l != self.n_layers - 1 else self.emb_size
#         return next(self.parameters()).new(1, self.batch_size, nh).zero_()

#     def reset(self):
#         "Reset the hidden states."
#         self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)]

Concat pooling

We will use three things for the classification head of the model: the last hidden state, the average of all the hidden states, and the maximum of all the hidden states. The trick, once again, is to ignore the padding when taking the last element, the average, and the maximum.
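
A tiny numeric sketch of why the masking matters, on a made-up batch of two sequences with 2-dimensional hidden states. The second sequence has negative values followed by two zero-filled pad steps: without the mask, its last state, max and mean would all be pulled toward those zeros.

```python
import torch

def concat_pool(output, mask):
    """output: (bs, seq_len, n_hid) RNN outputs; mask: (bs, seq_len), True at pads."""
    lengths = output.size(1) - mask.long().sum(dim=1)
    # last real (non-pad) hidden state of each sequence
    last = output[torch.arange(output.size(0)), lengths - 1]
    # max over real positions only: pads are set to -inf so they never win
    max_pool = output.masked_fill(mask[:, :, None], float('-inf')).max(dim=1)[0]
    # mean over real positions only: zero the pads, then divide by the true length
    avg_pool = output.masked_fill(mask[:, :, None], 0.).sum(dim=1) / lengths[:, None].to(output.dtype)
    return torch.cat([last, max_pool, avg_pool], dim=1)

toy_output = torch.tensor([[[1., 0.], [2., 5.], [3., 1.], [4., 2.]],      # 4 real steps
                           [[-1., -2.], [-3., -4.], [0., 0.], [0., 0.]]])  # 2 real steps + 2 pads
toy_mask = torch.tensor([[False, False, False, False],
                         [False, False, True, True]])
pooled = concat_pool(toy_output, toy_mask)
# row 0 -> last [4, 2], max [4, 5], mean [2.5, 2]
# row 1 -> last [-3, -4], max [-1, -2], mean [-2, -3]
```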

In [62]:
# class Pooling(nn.Module):
#     """
#     This is just a pedagogically useful model. The actual pooling classifier is
#     provided in PoolingLinearClassifier below.
#     The LSTMs create hidden states for bptt time steps. We decide what to pass
#     to the classifier here. Following concat pooling from vision, we use three
#     things for the classification head of the model. We concatenate:
#     -- the last hidden state
#     -- the average (mean) pool of all the bptt hidden states
#     -- the max pool of all the bptt hidden states
#     We pass the resulting concatenated tensor to the classifier.
#     """
#     def forward(self, input):
#         raw_outputs, outputs, mask = input
#         # last hidden state
#         output = outputs[-1]
#         # once again, we need to ignore the padding in the last hidden state,
#         # as well as the average pool and max pool tensors
#         lengths = output.size(1) - mask.long().sum(dim=1)
#         # average pool
#         avg_pool = output.masked_fill(mask[:, :, None], 0).sum(dim=1)
#         avg_pool.div_(lengths.type(avg_pool.dtype)[:, None])
#         # max pool
#         max_pool = output.masked_fill(mask[:, :, None], -float('inf')).max(dim=1)[0]
#         # Concat pooling
#         x = torch.cat([output[torch.arange(0, output.size(0)), lengths-1],
#                        max_pool, avg_pool], 1)
#         return output, x

Let's go through an example:

In [63]:
emb_dim = 300
hidden_dim = 300
n_layers = 2
tok_pad = vocab.index(PAD)

Let's instantiate the encoder and the pooling layer:

In [64]:
enc = AWD_LSTM1(len(vocab), emb_dim, hidden_dim, n_layers, pad_token=tok_pad)
pool = Pooling()

enc.batch_size = batch_size
enc.reset()

Let's get a batch of data from the train dataloader and feed the predictors (reviews) $x$ through the encoder and the pooling layer:

In [65]:
test_iter = iter(data.train_dl)
x, y = next(test_iter) # keep the first batch because it's easiest to see padding
output, c = pool(enc(x))

We can check that every text except the first (the longest one) is padded with 1s at the end.

In [66]:
x
Out[66]:
tensor([[    2,     7,  1148,  ...,    12, 15754,    24],
        [    2,     7,   814,  ...,     1,     1,     1],
        [    2,    21,     7,  ...,     1,     1,     1],
        ...,
        [    2,     7,   283,  ...,     1,     1,     1],
        [    2,     7,    16,  ...,     1,     1,     1],
        [    2,    18,   257,  ...,     1,     1,     1]])
In [67]:
lengths = x.size(1) - (x == 1).sum(1)
lengths[:5]
Out[67]:
tensor([3310, 1553, 1364, 1352, 1332])
In [68]:
x[1]
Out[68]:
tensor([  2,   7, 814,  ...,   1,   1,   1])

PyTorch puts $0$s in the output wherever we had padding when it unpacks:

In [69]:
print(output[1])
print(output[1].shape)
tensor([[ 0.0182, -0.0076, -0.0086,  ..., -0.0004,  0.0063,  0.0256],
        [ 0.0285, -0.0161, -0.0061,  ..., -0.0023,  0.0051,  0.0303],
        [ 0.0307, -0.0238, -0.0051,  ..., -0.0012,  0.0064,  0.0335],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<SelectBackward>)
torch.Size([3310, 300])
In [70]:
print(x[1][-100:])
print(output[1][-100:])
print(output[1][-100:].shape)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<SliceBackward>)
torch.Size([100, 300])

We can actually test that padding in the input x got replaced with $0$s in the output for all samples in the batch:

In [71]:
(x==tok_pad).float()
Out[71]:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.]])
In [72]:
(output.sum(dim=2) == 0).float()
Out[72]:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.]])
In [73]:
test_near((output.sum(dim=2) == 0).float(), (x==tok_pad).float())

So for a padded review, the last hidden state is not the last element of output: we need to go back in the sequence to the last position before the padding.

Let's check that we got everything right.

In [74]:
i = 1
length = x.size(1) - (x[i]==1).long().sum(); print(length)
out_unpad = output[i, :length]
print(out_unpad[-1][:20]); print(c[i,:20])
tensor(1553)
tensor([ 0.0335, -0.0336, -0.0065,  0.0238,  0.0303, -0.0043, -0.0048, -0.0142,
         0.0011,  0.0118, -0.0074, -0.0067, -0.0186,  0.0220, -0.0059,  0.0133,
         0.0291, -0.0035,  0.0401,  0.0276], grad_fn=<SliceBackward>)
tensor([ 0.0335, -0.0336, -0.0065,  0.0238,  0.0303, -0.0043, -0.0048, -0.0142,
         0.0011,  0.0118, -0.0074, -0.0067, -0.0186,  0.0220, -0.0059,  0.0133,
         0.0291, -0.0035,  0.0401,  0.0276], grad_fn=<SliceBackward>)
In [75]:
print(out_unpad.max(0)[0][:20]); print(c[i, 300:320])
tensor([ 0.0486, -0.0076,  0.0060,  0.0358,  0.0500,  0.0138,  0.0015,  0.0061,
         0.0167,  0.0259,  0.0086,  0.0103, -0.0015,  0.0392,  0.0067,  0.0324,
         0.0349,  0.0150,  0.0525,  0.0380], grad_fn=<SliceBackward>)
tensor([ 0.0486, -0.0076,  0.0060,  0.0358,  0.0500,  0.0138,  0.0015,  0.0061,
         0.0167,  0.0259,  0.0086,  0.0103, -0.0015,  0.0392,  0.0067,  0.0324,
         0.0349,  0.0150,  0.0525,  0.0380], grad_fn=<SliceBackward>)
In [76]:
print(out_unpad.mean(0)[:20]); print(c[i,600:620])
tensor([ 0.0360, -0.0374, -0.0052,  0.0212,  0.0324, -0.0028, -0.0111, -0.0062,
         0.0015,  0.0128, -0.0045, -0.0035, -0.0207,  0.0265, -0.0062,  0.0161,
         0.0211, -0.0001,  0.0376,  0.0262], grad_fn=<SliceBackward>)
tensor([ 0.0360, -0.0374, -0.0052,  0.0212,  0.0324, -0.0028, -0.0111, -0.0062,
         0.0015,  0.0128, -0.0045, -0.0035, -0.0207,  0.0265, -0.0062,  0.0161,
         0.0211, -0.0001,  0.0376,  0.0262], grad_fn=<SliceBackward>)
In [77]:
for i in range(batch_size):
    length = x.size(1) - (x[i]==1).long().sum()
    out_unpad = output[i, :length]
    test_near(out_unpad[-1], c[i, :300])
    test_near(out_unpad.max(0)[0], c[i, 300:600])
    test_near(out_unpad.mean(0), c[i, 600:])

Our pooling layer properly ignores the padding, so now let's add it to the classifier.

In [78]:
# class PoolingLinearClassifier(nn.Module):
#     """
#     Create a linear classifier with pooling:
#     -- the concat pooling layer, followed by
#     -- a list of batchnorm + dropout + linear + activation layers
#     """
#     def __init__(self, layers, dropout_probs):
#         super().__init__()
#         modified_layers = []
#         activations = [nn.ReLU(inplace=True)] * (len(layers) - 2) + [None]
#         # list of batchnorm + dropout + linear layers
#         for n_in, n_out, dropout_prob, activation in zip(layers[:-1], layers[1:],
#                                                          dropout_probs, activations):
#             modified_layers += batchnorm_dropout_linear(n_in, n_out,
#                                                         dropout_prob=dropout_prob,
#                                                         activation=activation)
#         self.layers = nn.Sequential(*modified_layers)

#     def forward(self, input):
#         """
#         The LSTMs create hidden states for bptt time steps. We decide what to pass
#         to the classifier here. Following concat pooling from vision, we use three
#         things for the classification head of the model. We concatenate:
#         -- the last hidden state
#         -- the average (mean) pool of all the bptt hidden states
#         -- the max pool of all the bptt hidden states
#         We pass the resulting concatenated tensor to the linear classifier.
#         """
#         raw_outputs, outputs, mask = input
#         # last hidden state
#         output = outputs[-1]
#         # we need to ignore the padding in the last hidden state,
#         # as well as the average pool and max pool tensors
#         lengths = output.size(1) - mask.long().sum(dim=1)
#         # average pool
#         avg_pool = output.masked_fill(mask[:, :, None], 0).sum(dim=1)
#         avg_pool.div_(lengths.type(avg_pool.dtype)[:, None])
#         # max pool
#         max_pool = output.masked_fill(mask[:, :, None], -float('inf')).max(dim=1)[0]
#         # Concat pooling
#         x = torch.cat([output[torch.arange(0, output.size(0)), lengths-1],
#                        max_pool, avg_pool], 1)
#         # pass the concat-pooled tensor through the linear layers
#         x = self.layers(x)
#         return x

Breaking reviews into chunks of bptt length

Then we just have to feed our texts through those two blocks.

But we can't feed whole reviews to AWD_LSTM1 at once, or we might get an out-of-memory error:

  • we process them in chunks of bptt length and regularly detach the history of our hidden states.
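
Detaching only cuts the backward graph; the forward values are unchanged. A sketch with a plain nn.LSTM (toy sizes; chunk_len plays the role of bptt) checks that processing a sequence in detached chunks gives the same outputs as one long call:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, seq_len, n_in, n_hid, chunk_len = 2, 10, 4, 3, 4   # toy sizes
toy_lstm = nn.LSTM(n_in, n_hid, batch_first=True)
inp = torch.randn(bs, seq_len, n_in)

# Reference: the whole sequence in one call
full_out, _ = toy_lstm(inp)

# Chunked: carry the hidden state across chunks but detach it,
# so backprop stops at each chunk boundary
hidden, chunks = None, []
for i in range(0, seq_len, chunk_len):
    out, hidden = toy_lstm(inp[:, i:i + chunk_len], hidden)
    hidden = tuple(h.detach() for h in hidden)
    chunks.append(out)
chunked_out = torch.cat(chunks, dim=1)

print(chunked_out.shape)                                 # torch.Size([2, 10, 3])
print(torch.allclose(full_out, chunked_out, atol=1e-6))  # True: same forward values
```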
In [79]:
# def pad_tensor(t, batch_size, val=0.):
#     if t.size(0) < batch_size:
#         return torch.cat([t, val + t.new_zeros(batch_size-t.size(0), *t.shape[1:])])
#     return t

# class SentenceEncoder(nn.Module):
#     "The encoder is the AWD LSTM model that gets called on the input text."
#     def __init__(self, encoder, bptt, pad_idx=1):
#         super().__init__()
#         self.bptt = bptt
#         self.encoder = encoder
#         self.pad_idx = pad_idx

#     def concat(self, arrs, batch_size):
#         return [torch.cat([pad_tensor(l[si],batch_size) for l in arrs], dim=1)
#                 for si in range(len(arrs[0]))]

#     def forward(self, input):
#         batch_size, seq_len = input.size()
#         self.encoder.batch_size = batch_size
#         self.encoder.reset()
#         raw_outputs, outputs, masks = [], [], []
#         # We go through the input one bptt at a time
#         for i in range(0, seq_len, self.bptt):
#             # we call the RNN model on it
#             r, o, m = self.encoder(input[:,i: min(i+self.bptt, seq_len)])
#             # we keep appending the results
#             masks.append(pad_tensor(m, batch_size, 1))
#             raw_outputs.append(r)
#             outputs.append(o)
#         return self.concat(raw_outputs, batch_size), self.concat(outputs, batch_size),torch.cat(masks,dim=1)

# def get_text_classifier(vocab_sz, emb_sz, n_hid, n_layers, n_out, pad_token,
#                         bptt, output_p=0.4, hidden_p=0.2, input_p=0.6,
#                         embed_p=0.1, weight_p=0.5, layers=None, drops=None):
#     "To create a full AWD-LSTM"
#     rnn_enc = AWD_LSTM1(vocab_sz, emb_sz, n_hid=n_hid, n_layers=n_layers,
#                         pad_token=pad_token, hidden_p=hidden_p, input_p=input_p,
#                         embed_p=embed_p, weight_p=weight_p)
#     enc = SentenceEncoder(rnn_enc, bptt)
#     if layers is None:
#         layers = [50]
#     if drops is None:
#         drops = [0.1] * len(layers)
#     layers = [3 * emb_sz] + layers + [n_out]
#     drops = [output_p] + drops
#     return SequentialRNN(enc, PoolingLinearClassifier(layers, drops))
In [80]:
emb_dim = 300
hidden_dim = 300
n_layers = 2
dropout_probs = tensor([0.4, 0.3, 0.4, 0.05, 0.5]) * 0.25
# 2 output classes (IMDB sentiment); 1 is the index of the padding token
model = get_text_classifier(len(vocab), emb_dim, hidden_dim, n_layers, 2, 1, bptt, *dropout_probs)
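Because `dropout_probs` is unpacked positionally, it helps to spell out which value lands where and what sizes the classifier head ends up with. The following pure-Python check mirrors the signature of the commented-out `get_text_classifier` above. It uses plain floats instead of the notebook's float32 tensor, so the last digits can differ from what the model prints:

```python
# Positional order of the dropout arguments in get_text_classifier
names = ["output_p", "hidden_p", "input_p", "embed_p", "weight_p"]
base = [0.4, 0.3, 0.4, 0.05, 0.5]
probs = dict(zip(names, (p * 0.25 for p in base)))

# Classifier head sizes: concat pooling triples the encoder output size
emb_sz, n_out, hidden_layers = 300, 2, [50]   # `layers` defaults to [50]
layer_sizes = [3 * emb_sz] + hidden_layers + [n_out]
drops = [probs["output_p"]] + [0.1] * len(hidden_layers)  # default drops: 0.1 each

print(probs)
print(layer_sizes, drops)   # [900, 50, 2] [0.1, 0.1]
```

The `900` input features of the first `BatchNorm1d`/`Linear` in the printed model below come from this `3 * emb_sz`.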

Training

We load the encoder we finetuned on IMDB; we will freeze it once we have looked at the model.

In [81]:
model[0].encoder.load_state_dict(torch.load(path/'finetuned_enc.pth'))
Out[81]:
<All keys matched successfully>

Let's take a look at the model:

In [82]:
model
Out[82]:
SequentialRNN(
  (0): SentenceEncoder(
    (encoder): AWD_LSTM1(
      (emb): Embedding(60006, 300, padding_idx=1)
      (emb_dropout): EmbeddingDropout(
        (emb): Embedding(60006, 300, padding_idx=1)
      )
      (rnns): ModuleList(
        (0): WeightDropout(
          (inner_module): LSTM(300, 300, batch_first=True)
        )
        (1): WeightDropout(
          (inner_module): LSTM(300, 300, batch_first=True)
        )
      )
      (input_dropout): RNNDropout()
      (hidden_dropouts): ModuleList(
        (0): RNNDropout()
        (1): RNNDropout()
      )
    )
  )
  (1): PoolingLinearClassifier(
    (layers): Sequential(
      (0): BatchNorm1d(900, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Dropout(p=0.10000000149011612, inplace=False)
      (2): Linear(in_features=900, out_features=50, bias=True)
      (3): ReLU(inplace=True)
      (4): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): Dropout(p=0.1, inplace=False)
      (6): Linear(in_features=50, out_features=2, bias=True)
    )
  )
)
In [83]:
model[0]
Out[83]:
SentenceEncoder(
  (encoder): AWD_LSTM1(
    (emb): Embedding(60006, 300, padding_idx=1)
    (emb_dropout): EmbeddingDropout(
      (emb): Embedding(60006, 300, padding_idx=1)
    )
    (rnns): ModuleList(
      (0): WeightDropout(
        (inner_module): LSTM(300, 300, batch_first=True)
      )
      (1): WeightDropout(
        (inner_module): LSTM(300, 300, batch_first=True)
      )
    )
    (input_dropout): RNNDropout()
    (hidden_dropouts): ModuleList(
      (0): RNNDropout()
      (1): RNNDropout()
    )
  )
)
In [84]:
model[1]
Out[84]:
PoolingLinearClassifier(
  (layers): Sequential(
    (0): BatchNorm1d(900, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Dropout(p=0.10000000149011612, inplace=False)
    (2): Linear(in_features=900, out_features=50, bias=True)
    (3): ReLU(inplace=True)
    (4): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=50, out_features=2, bias=True)
  )
)
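The two `Dropout` layers print `p=0.10000000149011612` and `p=0.1`. The first value is `output_p`, taken from the float32 tensor `dropout_probs`. The second is the plain Python default `0.1` from `drops`. A quick check with the standard-library `struct` module reproduces the first value by round-tripping through a 4-byte float. It presumably matches what happens when the tensor element is turned back into a Python float:

```python
import struct

def to_float32(x):
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", x))[0]

output_p = to_float32(0.4) * 0.25   # 0.4 stored as float32, then scaled as in dropout_probs
print(output_p)                     # 0.10000000149011612
print(output_p == to_float32(0.1))  # True: the same value as float32(0.1)
```

So both layers use a dropout probability of 0.1; only the precision of the stored number differs.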

We freeze the encoder completely:

In [85]:
for p in model[0].parameters():
    p.requires_grad_(False)

We split the model into groups for discriminative learning rates:

In [86]:
# def class_splitter(model):
#     enc = model[0].encoder
#     groups = [nn.Sequential(enc.emb, enc.emb_dropout, enc.input_dropout)]
#     for i in range(len(enc.rnns)):
#         groups.append(nn.Sequential(enc.rnns[i], enc.hidden_dropouts[i]))
#     groups.append(model[1])
#     return [list(group.parameters()) for group in groups]
In [87]:
param_groups = class_splitter(model)
len(param_groups)
Out[87]:
4

We are now ready to train the top layers (the decoder):

In [88]:
callback_funcs = [partial(GradientClipping, clip=0.1)]
In [89]:
learn = Learner(model, data, F.cross_entropy, opt_func=adam_opt(),
                metrics=accuracy, callback_funcs=callback_funcs, splitter=class_splitter)
In [90]:
learn.fit(1, callbacks=LRFinder())
epoch train_loss train_accuracy valid_loss valid_accuracy time
In [91]:
lr = 1e-2
callback_sched = sched_1cycle(lr, mom_start=0.8, mom_mid=0.7, mom_end=0.8)
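`sched_1cycle` comes from the `lgm` library, and its exact shape is not shown in this notebook. A common form of the one-cycle policy warms the learning rate up and then anneals it, while momentum moves the opposite way: from `mom_start` down to `mom_mid`, then back to `mom_end`. The sketch below makes several assumptions borrowed from common one-cycle implementations, not from `lgm`: cosine interpolation, a 30% warm-up, a starting learning rate of `lr/10`, and a final one of `lr/1e5`. The helper names are hypothetical:

```python
import math

def cos_interp(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return start + (1 + math.cos(math.pi * (1 - pct))) * (end - start) / 2

def one_cycle(pos, lr, mom_start=0.8, mom_mid=0.7, mom_end=0.8, pct_start=0.3):
    """Return (learning rate, momentum) at training progress pos in [0, 1]."""
    if pos < pct_start:                       # warm-up phase
        p = pos / pct_start
        return cos_interp(lr / 10, lr, p), cos_interp(mom_start, mom_mid, p)
    p = (pos - pct_start) / (1 - pct_start)   # annealing phase
    return cos_interp(lr, lr / 1e5, p), cos_interp(mom_mid, mom_end, p)

for pos in (0.0, 0.3, 1.0):
    print(pos, one_cycle(pos, lr=1e-2))
```

When `sched_1cycle` gets a list of learning rates (one per parameter group, as in the cells further down), the same curve would presumably be applied to each entry.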
In [92]:
learn.fit(1, callbacks=callback_sched)
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 0.327614 0.860280 0.251765 0.895760 00:41
In [93]:
learn.plotter.plot_lr()
In [94]:
learn.plotter.plot_train_stats()

We progressively unfreeze the model to avoid catastrophic forgetting:

  • we now also unfreeze the RNN layer right before the top (decoder) layers
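The three stages used in this notebook can be summarized per parameter group. There are four groups, as produced by `class_splitter`: embeddings, first LSTM, second LSTM and classifier head. The pure-Python sketch below records, for each stage, which groups are trainable and which learning rate each one effectively gets. The stage list restates the cells of this notebook, and in the first stage we assume the single `lr` is applied to every group. The helper is hypothetical, and it assumes the optimizer skips parameters that have no gradient, as PyTorch's built-in optimizers do:

```python
GROUPS = ["embedding", "lstm_0", "lstm_1", "head"]

# (number of trainable groups, counted from the top; learning rate per group)
stages = [
    (1, [1e-2] * 4),                             # head only, one lr for every group
    (2, [5e-3 / 2, 5e-3 / 2, 5e-3 / 2, 5e-3]),   # also unfreeze the last LSTM
    (4, [1e-3 / 8, 1e-3 / 4, 1e-3 / 2, 1e-3]),   # unfreeze everything
]

def effective_lrs(n_trainable, lrs):
    """Frozen groups receive no updates, so their effective learning rate is 0."""
    first_trainable = len(GROUPS) - n_trainable
    return {g: (lr if i >= first_trainable else 0.0)
            for i, (g, lr) in enumerate(zip(GROUPS, lrs))}

for n, lrs in stages:
    print(effective_lrs(n, lrs))
```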
In [95]:
for p in model[0].encoder.rnns[-1].parameters():
    p.requires_grad_(True)

And we train the top RNN together with the decoder:

In [96]:
learn.fit(1, callbacks=LRFinder())
epoch train_loss train_accuracy valid_loss valid_accuracy time
In [97]:
lr = 5e-3
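# we have 4 learning rates because we have 4 parameter groups;
# the ones for the lower layers (closer to the input) are smaller to avoid catastrophic forgetting
# (groups 0 and 1 are still frozen, so only the last two rates matter at this stage)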
callback_sched = sched_1cycle([lr/2., lr/2., lr/2., lr], mom_start=0.8, mom_mid=0.7, mom_end=0.8)
In [98]:
learn.fit(1, callbacks=callback_sched)
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 0.262747 0.894080 0.207249 0.919000 00:51
In [99]:
learn.plotter.plot_lr()
In [100]:
learn.plotter.plot_train_stats()

We now unfreeze the entire model and train it:

In [101]:
for p in model[0].parameters():
    p.requires_grad_(True)
In [102]:
learn.fit(1, callbacks=LRFinder())
epoch train_loss train_accuracy valid_loss valid_accuracy time
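  • note how the learning rates suggested by the LR finder, and the ones we pick (1e-2, then 5e-3, then 1e-3), decrease as we progressively unfreeze more of the model
  • this is what we want: smaller steps let us finetune the pretrained layers without overwriting what they have already learned (catastrophic forgetting)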
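The per-group learning rates in the next cell, `[lr/8, lr/4, lr/2, lr]`, form a geometric sequence: each group below the head gets half the learning rate of the group above it. The ULMFiT paper suggests dividing by 2.6 per layer instead of 2. Here is a hypothetical helper that builds such a list for any number of groups and any factor:

```python
def discriminative_lrs(lr_max, n_groups, factor=2.0):
    """Learning rates for n_groups parameter groups, lowest layer first.

    The top group (the classifier head) gets lr_max; each group below it
    gets the rate of the group above divided by `factor`.
    """
    return [lr_max / factor ** (n_groups - 1 - i) for i in range(n_groups)]

print(discriminative_lrs(1e-3, 4))              # [lr/8, lr/4, lr/2, lr] with lr = 1e-3
print(discriminative_lrs(1e-3, 4, factor=2.6))  # the factor suggested in the ULMFiT paper
```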
In [103]:
lr = 1e-3
# again, 4 learning rates; the ones for the lower layers (closer to the input) are smaller
callback_sched = sched_1cycle([lr/8., lr/4., lr/2., lr], mom_start=0.8, mom_mid=0.7, mom_end=0.8)
In [104]:
learn.fit(2, callbacks=callback_sched)
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 0.216340 0.914080 0.207779 0.917400 01:03
1 0.196552 0.922920 0.199903 0.920560 01:03
In [105]:
learn.plotter.plot_lr()
In [106]:
learn.plotter.plot_train_stats()

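Predicting on the padded batch and predicting on each unpadded review individually give the same results. The padding comes after the text, so it does not affect the LSTM outputs at the real tokens, and the pooling layer masks the padded positions out. In eval mode, dropout is off and BatchNorm uses its running statistics, so a review's prediction does not depend on the rest of the batch.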

In [108]:
x, y = next(iter(data.valid_dl))
In [109]:
pred_batch = learn.model.eval()(x.cuda())
In [110]:
pred_batch.size()
Out[110]:
torch.Size([128, 2])
In [111]:
pred_batch
Out[111]:
tensor([[ 1.3535e-01,  3.2772e-01],
        [ 2.9845e+00, -1.9178e+00],
        [ 6.7030e-01, -3.2756e-01],
        [ 7.5973e-01, -1.8852e-01],
        [-2.0722e+00,  1.7040e+00],
        [-6.3991e-01,  1.4940e+00],
        [-2.2010e+00,  1.5512e+00],
        [-1.4789e+00,  1.0488e+00],
        [ 4.3468e-01,  8.8996e-02],
        [-1.1900e+00,  8.4418e-01],
        [ 2.5939e+00, -1.6209e+00],
        [-4.3401e-01,  2.6088e-01],
        [ 1.1534e+00, -8.6798e-01],
        [ 3.6046e+00, -2.0541e+00],
        [ 1.3746e+00, -4.9832e-01],
        [-1.4090e+00,  8.5005e-01],
        [ 3.5180e+00, -3.0252e+00],
        [-1.2408e-03, -9.3621e-02],
        [-2.6801e+00,  2.2704e+00],
        [-1.4588e+00,  1.0334e+00],
        [ 1.3757e+00, -9.1967e-01],
        [ 3.4891e-01, -3.1098e-01],
        [-5.4034e+00,  4.3137e+00],
        [-4.4778e-01,  5.0622e-01],
        [-5.0662e-01,  3.8293e-01],
        [ 3.0709e+00, -1.8188e+00],
        [-4.2790e-01,  9.7611e-01],
        [-4.4357e+00,  2.9466e+00],
        [-3.6061e-01,  2.0903e-01],
        [-3.7119e+00,  3.9092e+00],
        [-2.2371e+00,  1.7179e+00],
        [-3.2041e-01,  4.0591e-01],
        [ 2.3604e+00, -1.8609e+00],
        [ 1.5818e+00, -8.7654e-01],
        [-1.0812e+00,  1.0852e+00],
        [-1.2983e+00,  7.4019e-01],
        [ 2.8193e+00, -1.6907e+00],
        [-1.3065e+00,  1.2349e+00],
        [ 1.2096e+00, -5.3791e-01],
        [-1.8449e+00,  1.0396e+00],
        [ 4.5276e+00, -3.3511e+00],
        [ 2.0460e+00, -1.4679e+00],
        [ 2.0462e+00, -1.4239e+00],
        [ 3.4049e+00, -2.0932e+00],
        [-4.3105e+00,  3.8237e+00],
        [ 7.6243e-01, -5.1338e-01],
        [-3.4615e+00,  2.3647e+00],
        [-1.2723e+00,  1.5573e+00],
        [ 3.9520e-01,  3.1757e-01],
        [-3.2667e-01,  6.3606e-01],
        [ 1.4342e+00, -8.6344e-01],
        [-6.7682e-02,  3.0034e-01],
        [ 2.6884e+00, -1.5239e+00],
        [ 1.5570e+00, -9.7936e-01],
        [-4.1181e+00,  3.2673e+00],
        [-5.1375e+00,  4.4899e+00],
        [ 3.5257e+00, -1.8482e+00],
        [-1.4365e+00,  1.8067e+00],
        [-1.6306e+00,  1.9021e+00],
        [ 4.3589e+00, -2.7814e+00],
        [ 1.8153e+00, -1.1702e+00],
        [-9.6126e-01,  1.3603e+00],
        [-1.4448e+00,  1.3107e+00],
        [ 3.0942e+00, -1.8630e+00],
        [-2.0114e+00,  2.2911e+00],
        [ 4.7640e-01, -2.6475e-02],
        [-1.4076e+00,  1.7994e+00],
        [-1.0132e+00,  1.2110e+00],
        [-1.6178e+00,  6.6620e-01],
        [ 2.5505e+00, -1.3528e+00],
        [ 2.9858e+00, -1.9059e+00],
        [-1.2600e+00,  1.9788e+00],
        [-3.7660e+00,  3.9004e+00],
        [ 2.4014e+00, -1.6949e+00],
        [ 3.5573e+00, -2.5081e+00],
        [ 2.2758e+00, -9.6172e-01],
        [ 1.6873e+00, -8.0164e-01],
        [ 2.8675e+00, -1.6632e+00],
        [ 1.7039e+00, -1.1905e+00],
        [ 2.6897e+00, -1.6673e+00],
        [-2.5606e+00,  2.1530e+00],
        [ 4.4244e+00, -3.1763e+00],
        [-2.1791e-01,  9.1675e-02],
        [ 4.3978e+00, -2.8252e+00],
        [ 1.4376e+00, -6.6072e-01],
        [-3.3181e+00,  3.1639e+00],
        [-9.9983e-01,  1.2900e+00],
        [ 2.1572e+00, -1.6157e+00],
        [ 2.5197e+00, -1.5544e+00],
        [-3.1164e+00,  2.7260e+00],
        [ 3.4186e+00, -2.2749e+00],
        [-7.8259e-01,  1.2208e+00],
        [-2.1895e+00,  2.3892e+00],
        [ 1.0197e+00, -7.0633e-01],
        [ 2.9144e+00, -1.5096e+00],
        [ 1.6673e-01,  4.5312e-01],
        [ 4.3101e+00, -2.9105e+00],
        [ 2.0776e+00, -9.7179e-01],
        [-1.5423e+00,  1.4909e+00],
        [ 2.4897e+00, -1.5245e+00],
        [ 4.6749e+00, -3.0700e+00],
        [-5.7461e-01,  7.2697e-01],
        [ 1.4400e+00, -1.0327e+00],
        [-1.7944e+00,  1.3567e+00],
        [-2.2011e+00,  1.8580e+00],
        [-2.2676e+00,  1.5955e+00],
        [-1.6785e+00,  1.4949e+00],
        [ 2.7442e+00, -1.2942e+00],
        [ 4.6620e+00, -3.0061e+00],
        [-1.5535e+00,  1.7363e+00],
        [ 3.4595e+00, -2.6519e+00],
        [-2.5505e+00,  2.4932e+00],
        [ 2.0534e+00, -1.7912e+00],
        [-1.9399e+00,  1.7546e+00],
        [ 3.4902e+00, -2.1904e+00],
        [ 1.2731e+00, -5.5732e-01],
        [-3.0629e-01,  2.6602e-01],
        [ 2.5605e+00, -1.3375e+00],
        [ 1.0649e+00, -4.9703e-01],
        [ 2.0728e+00, -1.2223e+00],
        [ 1.6398e+00, -1.0408e+00],
        [ 3.5617e+00, -2.1479e+00],
        [-1.5644e+00,  1.6653e+00],
        [ 3.1227e+00, -1.9036e+00],
        [ 2.8902e+00, -1.8164e+00],
        [ 3.6213e+00, -2.1962e+00],
        [ 3.9181e+00, -2.5728e+00],
        [ 2.0978e+00, -1.5763e+00]], device='cuda:0', grad_fn=<AddmmBackward>)
In [112]:
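pred_logits = []
for review in x:
    # number of real tokens: padding (index 1) sits at the end of each review
    length = x.size(1) - (review == 1).long().sum()
    review = review[:length]
    # review[None] adds a batch dimension of size 1; eval mode turns dropout off
    pred_logits.append(learn.model.eval()(review[None].cuda()))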
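The loop above computes each review's length as the sequence length minus the number of padding tokens (index 1). That is only correct if all the padding sits at the end of the review. The following pure-Python helper is hypothetical, not part of `lgm`. It strips trailing padding and fails loudly if a padding token appears before the end:

```python
def strip_trailing_pad(tokens, pad_idx=1):
    """Return tokens without trailing padding; raise if padding appears earlier."""
    length = len(tokens) - sum(t == pad_idx for t in tokens)
    if any(t == pad_idx for t in tokens[:length]):
        raise ValueError("padding is not contiguous at the end of the sequence")
    return tokens[:length]

print(strip_trailing_pad([2, 8, 15, 9, 1, 1]))   # [2, 8, 15, 9]
```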
In [113]:
pred_logits
Out[113]:
[tensor([[0.1353, 0.3277]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.9845, -1.9178]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 0.6703, -0.3276]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 0.7597, -0.1885]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-2.0722,  1.7040]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.6399,  1.4940]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-2.2010,  1.5512]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.4789,  1.0488]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[0.4347, 0.0890]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.1900,  0.8442]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.5939, -1.6209]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.4340,  0.2609]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.1534, -0.8680]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.6046, -2.0541]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.3746, -0.4983]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.4090,  0.8500]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.5180, -3.0252]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.0012, -0.0936]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-2.6801,  2.2704]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.4588,  1.0334]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.3757, -0.9197]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 0.3489, -0.3110]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-5.4034,  4.3137]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.4478,  0.5062]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.5066,  0.3829]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.0709, -1.8188]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.4279,  0.9761]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-4.4357,  2.9466]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.3606,  0.2090]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-3.7119,  3.9092]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-2.2371,  1.7179]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.3204,  0.4059]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.3604, -1.8609]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.5818, -0.8765]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.0812,  1.0852]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.2983,  0.7402]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.8193, -1.6907]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.3065,  1.2349]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.2096, -0.5379]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.8449,  1.0396]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 4.5276, -3.3511]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.0460, -1.4679]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.0462, -1.4239]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.4049, -2.0932]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-4.3105,  3.8237]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 0.7624, -0.5134]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-3.4615,  2.3647]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.2723,  1.5573]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[0.3952, 0.3176]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.3267,  0.6361]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.4342, -0.8634]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.0677,  0.3003]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.6884, -1.5239]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.5570, -0.9794]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-4.1181,  3.2673]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-5.1375,  4.4899]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.5257, -1.8482]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.4365,  1.8067]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.6306,  1.9021]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 4.3589, -2.7814]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.8153, -1.1702]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.9613,  1.3603]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.4448,  1.3107]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.0942, -1.8630]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-2.0114,  2.2911]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 0.4764, -0.0265]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.4076,  1.7994]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.0132,  1.2110]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.6178,  0.6662]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.5505, -1.3528]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.9858, -1.9059]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.2600,  1.9788]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-3.7660,  3.9004]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.4014, -1.6949]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.5573, -2.5081]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.2758, -0.9617]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.6873, -0.8016]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.8675, -1.6632]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.7039, -1.1905]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.6897, -1.6673]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-2.5606,  2.1530]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 4.4244, -3.1763]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.2179,  0.0917]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 4.3978, -2.8252]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.4376, -0.6607]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-3.3181,  3.1639]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.9998,  1.2900]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.1572, -1.6157]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.5197, -1.5544]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-3.1164,  2.7260]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.4186, -2.2749]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.7826,  1.2208]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-2.1895,  2.3892]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.0197, -0.7063]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.9144, -1.5096]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[0.1667, 0.4531]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 4.3101, -2.9105]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.0776, -0.9718]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.5423,  1.4909]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.4897, -1.5245]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 4.6749, -3.0700]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.5746,  0.7270]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.4400, -1.0327]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.7944,  1.3567]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-2.2011,  1.8580]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-2.2676,  1.5955]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.6785,  1.4949]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.7442, -1.2942]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 4.6620, -3.0061]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.5535,  1.7363]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.4595, -2.6519]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-2.5505,  2.4932]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.0534, -1.7912]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.9399,  1.7546]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.4902, -2.1904]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.2731, -0.5573]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-0.3063,  0.2660]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.5605, -1.3375]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.0649, -0.4970]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.0728, -1.2223]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 1.6398, -1.0408]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.5617, -2.1479]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[-1.5644,  1.6653]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.1227, -1.9036]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.8902, -1.8164]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.6213, -2.1962]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 3.9181, -2.5728]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor([[ 2.0978, -1.5763]], device='cuda:0', grad_fn=<AddmmBackward>)]
In [114]:
assert near(pred_batch, torch.cat(pred_logits))
In [115]:
accuracy(pred_batch, y.cuda())
Out[115]:
tensor(0.9062, device='cuda:0')
In [117]:
accuracy(torch.cat(pred_logits), y.cuda())
Out[117]:
tensor(0.9062, device='cuda:0')
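Both evaluations agree, as expected. Since the batch holds 128 reviews, the reported 0.9062 corresponds to 116 correct predictions (116/128 = 0.90625). The sketch below shows that metric in pure Python. It assumes `lgm`'s `accuracy` is the usual arg-max accuracy. The helper is named `argmax_accuracy` so it does not shadow the notebook's `accuracy`, and the logits and labels are made-up toy values:

```python
def argmax_accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

# Toy logits for 4 reviews (2 classes) and their labels; made-up values.
logits = [[1.35, -0.2], [-2.1, 1.7], [0.4, 0.1], [-0.5, 0.3]]
labels = [0, 1, 1, 1]
print(argmax_accuracy(logits, labels))   # 0.75
print(116 / 128)                         # 0.90625
```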