%load_ext autoreload
%autoreload 2
%matplotlib inline
from lgm import *
# path = URLs.LOCAL_PATH/'data'
# !wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-{version}-v1.zip -P {path}
# !unzip -q -n {path}/wikitext-{version}-v1.zip -d {path}
# !mv {path}/wikitext-{version}/wiki.train.tokens {path}/wikitext-{version}/train.txt
# !mv {path}/wikitext-{version}/wiki.valid.tokens {path}/wikitext-{version}/valid.txt
# !mv {path}/wikitext-{version}/wiki.test.tokens {path}/wikitext-{version}/test.txt
We split the Wikipedia texts into articles:
# Directory holding the WikiText-103 text files ('train.txt', 'valid.txt')
# and the cached pickle that the cells below read from.
path = URLs.LOCAL_PATH/'data'/'wikitext-103'
def istitle(line):
    """Return True when `line` is a top-level article title line.

    A title line looks like ``' = Some Title = '``, optionally followed by a
    single trailing newline (``$`` also matches just before a final newline).
    The title text may not contain ``=``, so section headings such as
    ``' = = Gameplay = = '`` are rejected.
    """
    # re.search stops at the first hit instead of building a list of matches
    # just to test whether it is empty.
    return re.search(r'^ = [^=]* = $', line) is not None
def read_wiki(filename):
    """Read a WikiText file and split it into one string per article.

    The file is read as UTF-8. An article ends at line ``i`` when line
    ``i+1`` is exactly ``' \\n'`` and line ``i+2`` is a title according to
    `istitle`; that blank line and the title then open the next article.
    Every ``'<unk>'`` in an article is replaced by `UNK`.

    Returns a list of article strings. The text left after the last split is
    always appended, so the list has at least one element (a single empty
    string for an empty file).
    """
    with open(filename, encoding='utf8') as f:
        lines = f.readlines()

    def _finish(parts):
        # Join once per article rather than growing a string line by line,
        # and keep the <unk> normalisation in a single place.
        return ''.join(parts).replace('<unk>', UNK)

    articles = []
    pending = []
    # A boundary check looks two lines ahead, so it is only valid this far.
    last_boundary = len(lines) - 2
    for i, line in enumerate(lines):
        pending.append(line)
        if i < last_boundary and lines[i+1] == ' \n' and istitle(lines[i+2]):
            articles.append(_finish(pending))
            pending = []
    articles.append(_finish(pending))
    return articles
Create train and validation sets:
# There is also a 'test.txt' set, but we won't worry about it.
# Each split is read article by article and wrapped in a TextList.
train = TextList(read_wiki(path/'train.txt'), path=path)
valid = TextList(read_wiki(path/'valid.txt'), path=path)
len(train), len(valid)
# One-off preprocessing (tokenize + numericalize), kept for reference; its
# result was dumped to the pickle that is loaded below.
# splitdata = SplitData(train, valid)
# proc_tok = TokenizeProcessor()
# proc_num = NumericalizeProcessor()
# labeled_list = label_by_func(splitdata, lambda x: 0, processor_x = [proc_tok, proc_num])
import pickle
# pickle.dump(labeled_list, open(path/'labeled_list_wikipedia.pkl', 'wb'))
# NOTE: unpickling can execute arbitrary code; only load a file you created
# yourself. The `with` block closes the file handle once loading is done.
with open(path/'labeled_list_wikipedia.pkl', 'rb') as f:
    labeled_list = pickle.load(f)
# Batch size and the third argument of lm_databunchify (presumably the
# back-propagation-through-time window length — confirm in lgm).
batch_size, bptt = 64, 70
data = lm_databunchify(labeled_list, batch_size, bptt)

# The vocabulary lives on the last x-processor of the training split.
vocab = labeled_list.train.processor_x[-1].vocab
len(vocab)
print(vocab[:100])
vocab.index(PAD)
The hyper-params and the model:
# import numpy as np
# NOTE(review): `np` is presumably provided by `from lgm import *` — confirm.
# Five dropout rates, each scaled by 0.2, unpacked positionally into
# get_language_model below.
dropout_probs = 0.2 * np.array([0.1, 0.15, 0.25, 0.02, 0.2])
tok_pad = vocab.index(PAD)
emb_dim = hidden_dim = 300
n_layers = 2  # number of stacked LSTM layers
model = get_language_model(
    len(vocab), emb_dim, hidden_dim, n_layers, tok_pad, *dropout_probs
)
Additional training constraints (gradient clipping and activation regularizations):
# Callback factories handed to the Learner: gradient clipping (clip=0.1) and
# the RNNTrainer regularisation terms (alpha=2., beta=1.).
callback_funcs = [
    partial(GradientClipping, clip=0.1),
    partial(RNNTrainer, alpha=2., beta=1.),
]
learn = Learner(
    model,
    data,
    cross_entropy_flat,
    adam_opt(),
    metrics=accuracy_flat,
    callback_funcs=callback_funcs,
)
We find a good learning rate:
# Single fit call with the LRFinder callback attached (the first argument is
# presumably the number of epochs — confirm against Learner.fit).
learn.fit(1, callbacks=LRFinder())
We set up the hyper-param scheduler (learning rate and momenta for Adam):
# Learning rate passed to sched_1cycle (presumably the peak of the cycle —
# confirm against lgm).
lr = 5e-3
# Schedule callbacks for the learning rate and momentum; only consumed by the
# commented-out `learn.fit(1, callbacks=callback_sched)` call further down.
callback_sched = sched_1cycle(lr, pct_start=0.3, mom_start=0.8, mom_mid=0.7, mom_end=0.8)
Finally, we train the language model on Wikipedia:
# learn.fit(1, callbacks=callback_sched)
# torch.save(learn.model.state_dict(), path/'pretrained.pth')
# pickle.dump(vocab, open(path/'vocab.pkl', 'wb'))
# %load_ext autoreload
# %autoreload 2
# %matplotlib inline
# !pip3 install --upgrade lgm
# from lgm import *
Mounting the Google Drive directory (when running on Colab):
# from google.colab import drive
# drive.mount('/content/drive')
# !nvidia-smi
# torch.cuda.get_device_name(torch.cuda.current_device())
# URLs.LOCAL_PATH = Path('/content/drive/My Drive/lgm_colab_exps'); print(URLs.LOCAL_PATH)
# path = URLs.LOCAL_PATH/'data'/'wikitext-103'
# import pickle
# ll = pickle.load(open(path/'labeled_list_wikipedia.pkl', 'rb'))
# bs, bptt = 64, 70
# data = lm_databunchify(ll, bs, bptt)
# vocab = ll.train.processor_x[-1].vocab
# pickle.dump(vocab, open(path/'vocab.pkl', 'wb'))
# len(vocab)
# print(vocab)
# emb_sz, nh, nl = 300, 300, 2
# dps = np.array([0.1, 0.15, 0.25, 0.02, 0.2]) * 0.2
# tok_pad = vocab.index(PAD)
# model = get_language_model(len(vocab), emb_sz, nh, nl, tok_pad, *dps)
# callback_funcs = [partial(GradientClipping, clip=0.1),
# partial(RNNTrainer, alpha=2., beta=1.)]
# learn = Learner(model, data, cross_entropy_flat, adam_opt(),
# metrics=accuracy_flat, lr=5e-3, callback_funcs=callback_funcs,
# model_name='pretrained_colab')
# learn.model.load_state_dict(torch.load(URLs.LOCAL_PATH/'pretrained_colab_final_1.pth'))
# learn.fit(1, callbacks=LRFinder())
# lr = 5e-3
# hyper_sched = sched_1cycle(lr, pct_start=0.3, mom_start=0.8, mom_mid=0.7, mom_end=0.8)
# learn.fit(3, callbacks=hyper_sched)
# torch.save(learn.model.state_dict(), URLs.LOCAL_PATH/'pretrained_colab_final_1.pth')
# learn.plotter.plot_train_stats()
# learn.plotter.plot_valid_stats()
# learn.plotter.plot_lr()