ゼロから作る Deep Learning 2/LSTM
Chapter 6: Gated RNNs
Reading notes for ゼロから作るDeep Learning (2). Chapter 5 implemented a language model with a plain RNN; Chapter 6 adds a mechanism called a gate, as in the LSTM, to build an RNN that can learn long-term dependencies.
Reference implementation
%sh
rm -rf /tmp/deep-learning-from-scratch-2
git clone https://github.com/oreilly-japan/deep-learning-from-scratch-2 /tmp/deep-learning-from-scratch-2
Cloning into '/tmp/deep-learning-from-scratch-2'...
Installing numpy and matplotlib
%sh pip3 install numpy matplotlib
Successfully installed cycler-0.10.0 kiwisolver-1.0.1 matplotlib-3.0.2 numpy-1.15.4 pyparsing-2.3.0 python-dateutil-2.7.5 six-1.11.0
%sh pip3 install seaborn
Successfully installed pandas-0.23.4 pytz-2018.7 scipy-1.1.0 seaborn-0.9.0
6.1.3: Causes of vanishing and exploding gradients
%python3
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def plot_tanh(x):
    y = np.tanh(x)
    plt.plot(x, y, label='tanh')

def plot_dydx(x):
    y = 1 / (np.cosh(x)**2)
    plt.plot(x, y, '--', label='dy/dx')

sns.set()
x = np.arange(-5, 5, 0.2)
plt.xlabel('x')
plt.ylabel('y')
plot_tanh(x)
plot_dydx(x)
plt.legend(loc='lower right', fontsize=18, borderaxespad=1)
<matplotlib.legend.Legend object at 0x7f8c13198ba8>
- The derivative of \(y=\tanh(x)\) shrinks as \(x\) moves away from 0, so every time backpropagation passes through a tanh node the emitted gradient gets smaller
- Using ReLU instead is said to mitigate this vanishing gradient
%python3
N = 2
H = 3
T = 20

dh = np.ones((N, H))
np.random.seed(3)
Wh = np.random.randn(H, H)

norm_list = []
for t in range(T):
    dh = np.dot(dh, Wh.T)
    norm = np.sqrt(np.sum(dh**2)) / N
    norm_list.append(norm)

sns.set()
plt.xlabel('time step')
plt.ylabel('norm')
plt.plot(norm_list)
[<matplotlib.lines.Line2D object at 0x7f8bfca6c5f8>]
%python3
N = 2
H = 3
T = 20

dh = np.ones((N, H))
np.random.seed(3)
Wh = np.random.randn(H, H) * 0.5

norm_list = []
for t in range(T):
    dh = np.dot(dh, Wh.T)
    norm = np.sqrt(np.sum(dh**2)) / N
    norm_list.append(norm)

sns.set()
plt.xlabel('time step')
plt.ylabel('norm')
plt.plot(norm_list)
[<matplotlib.lines.Line2D object at 0x7f8c11092e80>]
- The MatMul node behaves the same way: repeating the multiplication makes gradients explode or vanish
- Countermeasure against exploding gradients
  - Gradient clipping
    - Rescale the gradients whenever their L2 norm exceeds a threshold
- Countermeasure against vanishing gradients
  - Introduce gates
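The gradient-clipping idea above can be sketched as a small standalone function. The book's repository ships its own version in common/util.py; the function name and the 1e-6 epsilon here are illustrative choices of mine, not the book's exact code.

```python
import numpy as np

def clip_grads(grads, max_norm):
    # Combined L2 norm over all gradient arrays.
    total_norm = np.sqrt(sum((g ** 2).sum() for g in grads))
    # Rescale in place only when the norm exceeds the threshold.
    rate = max_norm / (total_norm + 1e-6)
    if rate < 1:
        for g in grads:
            g *= rate
    return grads

grads = [np.ones((2, 2)) * 10]   # L2 norm = 20
clip_grads(grads, max_norm=5.0)
print(np.sqrt((grads[0] ** 2).sum()))  # ~5.0
```

After clipping, the combined norm equals the threshold while the gradient's direction is preserved.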
Gates
- A gate controls, element by element, how much each value should contribute to the next state
- The output gate:
$$
o = \sigma(x_t W_{x}^{(o)} + h_{t-1} W_{h}^{(o)} + b^{(o)})
$$
- The output is the element-wise product (Hadamard product) of the output gate and the tanh node
- Each gate has its own weights
- When doing the matrix multiplications, the computations for all gates can be combined into one
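A small sketch of the last bullet: stacking the per-gate weights side by side gives a 4H-column matrix, so all four gate pre-activations come out of a single matmul and are recovered by slicing. The sizes and seed here are my own choices for illustration.

```python
import numpy as np

N, D, H = 2, 3, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((N, D))
h_prev = rng.standard_normal((N, H))

# Per-gate weights concatenated along the column axis: 4H columns total.
Wx = rng.standard_normal((D, 4 * H))
Wh = rng.standard_normal((H, 4 * H))
b = np.zeros(4 * H)

A = x @ Wx + h_prev @ Wh + b          # one matmul for all gates
f, g, i, o = A[:, :H], A[:, H:2*H], A[:, 2*H:3*H], A[:, 3*H:]
print(A.shape, f.shape)  # (2, 16) (2, 4)
```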
LSTM layer: implementation
%python3
import sys
sys.path.append('/tmp/deep-learning-from-scratch-2')
from common import config
config.GPU = True
from common.np import *
from common.functions import sigmoid

class LSTM:
    def __init__(self, Wx, Wh, b):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.cache = None

    def forward(self, x, h_prev, c_prev):
        Wx, Wh, b = self.params
        N, H = h_prev.shape

        A = np.dot(x, Wx) + np.dot(h_prev, Wh) + b

        f = A[:, :H]
        g = A[:, H:2*H]
        i = A[:, 2*H:3*H]
        o = A[:, 3*H:]

        f = sigmoid(f)
        g = np.tanh(g)
        i = sigmoid(i)
        o = sigmoid(o)

        c_next = f * c_prev + g * i
        h_next = o * np.tanh(c_next)

        self.cache = (x, h_prev, c_prev, i, f, g, o, c_next)
        return h_next, c_next

    def backward(self, dh_next, dc_next):
        Wx, Wh, b = self.params
        x, h_prev, c_prev, i, f, g, o, c_next = self.cache

        tanh_c_next = np.tanh(c_next)

        ds = dc_next + (dh_next * o) * (1 - tanh_c_next ** 2)

        dc_prev = ds * f

        di = ds * g
        df = ds * c_prev
        do = dh_next * tanh_c_next
        dg = ds * i

        di *= i * (1 - i)
        df *= f * (1 - f)
        do *= o * (1 - o)
        dg *= (1 - g ** 2)

        dA = np.hstack((df, dg, di, do))

        dWh = np.dot(h_prev.T, dA)
        dWx = np.dot(x.T, dA)
        db = dA.sum(axis=0)

        self.grads[0][...] = dWx
        self.grads[1][...] = dWh
        self.grads[2][...] = db

        dx = np.dot(dA, Wx.T)
        dh_prev = np.dot(dA, Wh.T)

        return dx, dh_prev, dc_prev
------------------------------------------------------------
GPU Mode (cupy)
------------------------------------------------------------
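Why this design counters vanishing gradients: the backward path through the cell state (dc_prev = ds * f in the code above) involves only an element-wise product with the forget gate, with no matrix multiplication and no tanh in between. A toy illustration with made-up numbers:

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Forget-gate values near 1 pass the gradient through almost unchanged;
# values near 0 deliberately "forget" and block it.
f = sigmoid(np.array([3.0, -3.0]))
dc_next = np.ones(2)
dc_prev = dc_next * f   # element-wise only: no repeated matmul shrinkage
print(dc_prev)
```

The network learns per element when to preserve the gradient (f near 1) and when to discard it, instead of shrinking it uniformly at every step.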
TimeLSTM layer: implementation
%python3
class TimeLSTM:
    def __init__(self, Wx, Wh, b, stateful=False):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.layers = None
        self.h, self.c = None, None
        self.dh = None
        self.stateful = stateful

    def forward(self, xs):
        Wx, Wh, b = self.params
        N, T, D = xs.shape
        H = Wh.shape[0]

        self.layers = []
        hs = np.empty((N, T, H), dtype='f')

        if not self.stateful or self.h is None:
            self.h = np.zeros((N, H), dtype='f')
        if not self.stateful or self.c is None:
            self.c = np.zeros((N, H), dtype='f')

        for t in range(T):
            layer = LSTM(*self.params)
            self.h, self.c = layer.forward(xs[:, t, :], self.h, self.c)
            hs[:, t, :] = self.h
            self.layers.append(layer)

        return hs

    def backward(self, dhs):
        Wx, Wh, b = self.params
        N, T, H = dhs.shape
        D = Wx.shape[0]

        dxs = np.empty((N, T, D), dtype='f')
        dh, dc = 0, 0

        grads = [0, 0, 0]
        for t in reversed(range(T)):
            layer = self.layers[t]
            dx, dh, dc = layer.backward(dhs[:, t, :] + dh, dc)
            dxs[:, t, :] = dx
            for i, grad in enumerate(layer.grads):
                grads[i] += grad

        for i, grad in enumerate(grads):
            self.grads[i][...] = grad
        self.dh = dh
        return dxs

    def set_state(self, h, c=None):
        self.h, self.c = h, c

    def reset_state(self):
        self.h, self.c = None, None
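What stateful=True buys us: a long sequence can be fed in chunks (truncated BPTT) while the hidden state carries across chunk boundaries. A toy recurrence of my own (not the LSTM itself) showing that carrying the state across chunks reproduces the full-sequence result exactly:

```python
# Toy recurrence h_t = 0.5 * h_{t-1} + x_t, processed as one long
# sequence vs. two chunks with the state carried over in between.
def run(xs, h=0.0):
    hs = []
    for x in xs:
        h = 0.5 * h + x
        hs.append(h)
    return hs, h

xs = list(range(8))
full, _ = run(xs)                 # whole sequence at once
first, h = run(xs[:4])            # first chunk from a zero state
second, _ = run(xs[4:], h)        # second chunk continues from h
print(full == first + second)     # True
```

Resetting the state between chunks instead (the stateful=False behavior) would restart from zero and lose everything learned about the earlier context.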
Rnnlm: implementation
%python3
import pickle
from common.time_layers import TimeSoftmaxWithLoss, TimeEmbedding, TimeAffine

class Rnnlm:
    def __init__(self, vocab_size=10000, wordvec_size=100, hidden_size=100):
        V, D, H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn

        embed_W = (rn(V, D) / 100).astype('f')
        lstm_Wx = (rn(D, 4 * H) / np.sqrt(D)).astype('f')
        lstm_Wh = (rn(H, 4 * H) / np.sqrt(H)).astype('f')
        lstm_b = np.zeros(4 * H).astype('f')
        affine_W = (rn(H, V) / np.sqrt(H)).astype('f')
        affine_b = np.zeros(V).astype('f')

        self.layers = [
            TimeEmbedding(embed_W),
            TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful=True),
            TimeAffine(affine_W, affine_b)
        ]
        self.loss_layer = TimeSoftmaxWithLoss()
        self.lstm_layer = self.layers[1]

        self.params, self.grads = [], []
        for layer in self.layers:
            self.params += layer.params
            self.grads += layer.grads

    def predict(self, xs):
        for layer in self.layers:
            xs = layer.forward(xs)
        return xs

    def forward(self, xs, ts):
        score = self.predict(xs)
        loss = self.loss_layer.forward(score, ts)
        return loss

    def backward(self, dout=1):
        dout = self.loss_layer.backward(dout)
        for layer in reversed(self.layers):
            dout = layer.backward(dout)
        return dout

    def reset_state(self):
        self.lstm_layer.reset_state()

    def save_params(self, file_name='Rnnlm.pkl'):
        with open(file_name, 'wb') as f:
            pickle.dump(self.params, f)

    def load_params(self, file_name='Rnnlm.pkl'):
        with open(file_name, 'rb') as f:
            self.params = pickle.load(f)
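Rnnlm divides the LSTM weights by np.sqrt(D) and np.sqrt(H) (Xavier-style initialization). A quick check, with sizes of my own choosing, that this keeps the pre-activation scale near 1 regardless of fan-in:

```python
import numpy as np

rng = np.random.default_rng(0)
D, H = 100, 100
# Each output element sums D products of unit-variance terms, so
# dividing the weights by sqrt(D) keeps its variance close to 1.
Wx = rng.standard_normal((D, 4 * H)) / np.sqrt(D)
x = rng.standard_normal((1000, D))
a = x @ Wx
print(a.std())  # close to 1.0
```

Without the scaling, the pre-activations would have standard deviation around sqrt(D) = 10 and saturate the sigmoid/tanh gates from the start.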
Using CUDA
%sh nvidia-smi
Thu Nov 15 13:21:21 2018
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla K80           On   | 00000000:00:1E.0 Off |                    0 |
| N/A   56C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
%sh nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2017 NVIDIA Corporation
Built on Fri_Sep__1_21:08:03_CDT_2017
Cuda compilation tools, release 9.0, V9.0.176
%sh sudo pip-3.6 install cupy-cuda90
Successfully installed cupy-cuda90-5.0.0
Running training
%python3
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from dataset import ptb

# hyper parameters
batch_size = 20
wordvec_size = 100
hidden_size = 100
time_size = 35
lr = 20.0
max_epoch = 4
max_grad = 0.25

# load dataset
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_test, _, _ = ptb.load_data('test')
vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]

# generate model
model = Rnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

trainer.fit(xs, ts, max_epoch, batch_size, time_size, max_grad, eval_interval=20)
trainer.plot(ylim=(0, 500))

model.reset_state()
ppl_test = eval_perplexity(model, corpus_test)
print('test perplexity: ', ppl_test)

model.save_params()
| epoch 1 | iter 1 / 1327 | time 12[s] | perplexity 10000.52
| epoch 1 | iter 21 / 1327 | time 18[s] | perplexity 3043.08
| epoch 1 | iter 41 / 1327 | time 24[s] | perplexity 1223.89
(log truncated)
| epoch 4 | iter 1301 / 1327 | time 1576[s] | perplexity 111.59
| epoch 4 | iter 1321 / 1327 | time 1582[s] | perplexity 110.69
Running the same thing on about 8 CPU cores
%python3
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from dataset import ptb

# hyper parameters
batch_size = 20
wordvec_size = 100
hidden_size = 100
time_size = 35
lr = 20.0
max_epoch = 4
max_grad = 0.25

# load dataset
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_test, _, _ = ptb.load_data('test')
vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]

# generate model
model = Rnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

trainer.fit(xs, ts, max_epoch, batch_size, time_size, max_grad, eval_interval=20)
trainer.plot(ylim=(0, 500))

model.reset_state()
ppl_test = eval_perplexity(model, corpus_test)
print('test perplexity: ', ppl_test)

model.save_params()
Downloading ptb.train.txt ... Done
Downloading ptb.test.txt ... Done
| epoch 1 | iter 1 / 1327 | time 0[s] | perplexity 10000.10
| epoch 1 | iter 21 / 1327 | time 7[s] | perplexity 2852.60
| epoch 1 | iter 41 / 1327 | time 13[s] | perplexity 1197.79
(log truncated)
time 419[s] | perplexity 172.00 | epoch 1 | iter 1281 / 1327 | time 426[s] | perplexity 180.13 | epoch 1 | iter 1301 / 1327 | time 433[s] | perplexity 224.04 | epoch 1 | iter 1321 / 1327 | time 439[s] | perplexity 212.56 | epoch 2 | iter 1 / 1327 | time 442[s] | perplexity 227.27 | epoch 2 | iter 21 / 1327 | time 448[s] | perplexity 205.48 | epoch 2 | iter 41 / 1327 | time 455[s] | perplexity 192.19 | epoch 2 | iter 61 / 1327 | time 462[s] | perplexity 178.33 | epoch 2 | iter 81 / 1327 | time 469[s] | perplexity 160.51 | epoch 2 | iter 101 / 1327 | time 475[s] | perplexity 153.43 | epoch 2 | iter 121 / 1327 | time 482[s] | perplexity 161.21 | epoch 2 | iter 141 / 1327 | time 489[s] | perplexity 179.45 | epoch 2 | iter 161 / 1327 | time 496[s] | perplexity 194.39 | epoch 2 | iter 181 / 1327 | time 502[s] | perplexity 203.20 | epoch 2 | iter 201 / 1327 | time 509[s] | perplexity 187.49 | epoch 2 | iter 221 / 1327 | time 516[s] | perplexity 184.57 | epoch 2 | iter 241 / 1327 | time 522[s] | perplexity 179.09 | epoch 2 | iter 261 / 1327 | time 529[s] | perplexity 187.89 | epoch 2 | iter 281 / 1327 | time 536[s] | perplexity 186.25 | epoch 2 | iter 301 / 1327 | time 543[s] | perplexity 167.91 | epoch 2 | iter 321 / 1327 | time 549[s] | perplexity 138.67 | epoch 2 | iter 341 / 1327 | time 556[s] | perplexity 172.03 | epoch 2 | iter 361 / 1327 | time 563[s] | perplexity 198.36 | epoch 2 | iter 381 / 1327 | time 570[s] | perplexity 155.07 | epoch 2 | iter 401 / 1327 | time 576[s] | perplexity 169.08 | epoch 2 | iter 421 / 1327 | time 583[s] | perplexity 158.41 | epoch 2 | iter 441 / 1327 | time 590[s] | perplexity 165.13 | epoch 2 | iter 461 / 1327 | time 596[s] | perplexity 159.82 | epoch 2 | iter 481 / 1327 | time 603[s] | perplexity 158.04 | epoch 2 | iter 501 / 1327 | time 610[s] | perplexity 171.63 | epoch 2 | iter 521 / 1327 | time 616[s] | perplexity 173.55 | epoch 2 | iter 541 / 1327 | time 623[s] | perplexity 176.24 | epoch 2 | iter 561 / 1327 | time 630[s] | 
perplexity 156.97 | epoch 2 | iter 581 / 1327 | time 637[s] | perplexity 139.85 | epoch 2 | iter 601 / 1327 | time 643[s] | perplexity 192.27 | epoch 2 | iter 621 / 1327 | time 651[s] | perplexity 183.50 | epoch 2 | iter 641 / 1327 | time 658[s] | perplexity 167.28 | epoch 2 | iter 661 / 1327 | time 664[s] | perplexity 154.80 | epoch 2 | iter 681 / 1327 | time 671[s] | perplexity 131.82 | epoch 2 | iter 701 / 1327 | time 678[s] | perplexity 152.01 | epoch 2 | iter 721 / 1327 | time 685[s] | perplexity 161.30 | epoch 2 | iter 741 / 1327 | time 692[s] | perplexity 136.03 | epoch 2 | iter 761 / 1327 | time 698[s] | perplexity 131.99 | epoch 2 | iter 781 / 1327 | time 705[s] | perplexity 137.36 | epoch 2 | iter 801 / 1327 | time 712[s] | perplexity 148.77 | epoch 2 | iter 821 / 1327 | time 718[s] | perplexity 146.65 | epoch 2 | iter 841 / 1327 | time 725[s] | perplexity 145.19 | epoch 2 | iter 861 / 1327 | time 732[s] | perplexity 147.84 | epoch 2 | iter 881 / 1327 | time 739[s] | perplexity 131.78 | epoch 2 | iter 901 / 1327 | time 745[s] | perplexity 167.73 | epoch 2 | iter 921 / 1327 | time 752[s] | perplexity 148.08 | epoch 2 | iter 941 / 1327 | time 759[s] | perplexity 156.33 | epoch 2 | iter 961 / 1327 | time 765[s] | perplexity 166.75 | epoch 2 | iter 981 / 1327 | time 772[s] | perplexity 155.06 | epoch 2 | iter 1001 / 1327 | time 779[s] | perplexity 132.72 | epoch 2 | iter 1021 / 1327 | time 785[s] | perplexity 156.88 | epoch 2 | iter 1041 / 1327 | time 792[s] | perplexity 143.86 | epoch 2 | iter 1061 / 1327 | time 799[s] | perplexity 131.46 | epoch 2 | iter 1081 / 1327 | time 805[s] | perplexity 112.57 | epoch 2 | iter 1101 / 1327 | time 812[s] | perplexity 122.22 | epoch 2 | iter 1121 / 1327 | time 819[s] | perplexity 156.04 | epoch 2 | iter 1141 / 1327 | time 826[s] | perplexity 143.46 | epoch 2 | iter 1161 / 1327 | time 833[s] | perplexity 135.22 | epoch 2 | iter 1181 / 1327 | time 840[s] | perplexity 136.76 | epoch 2 | iter 1201 / 1327 | time 846[s] | 
perplexity 114.39 | epoch 2 | iter 1221 / 1327 | time 853[s] | perplexity 110.84 | epoch 2 | iter 1241 / 1327 | time 860[s] | perplexity 131.14 | epoch 2 | iter 1261 / 1327 | time 867[s] | perplexity 124.96 | epoch 2 | iter 1281 / 1327 | time 874[s] | perplexity 124.53 | epoch 2 | iter 1301 / 1327 | time 881[s] | perplexity 159.41 | epoch 2 | iter 1321 / 1327 | time 887[s] | perplexity 155.05 | epoch 3 | iter 1 / 1327 | time 890[s] | perplexity 162.31 | epoch 3 | iter 21 / 1327 | time 897[s] | perplexity 145.83 | epoch 3 | iter 41 / 1327 | time 904[s] | perplexity 137.93 | epoch 3 | iter 61 / 1327 | time 910[s] | perplexity 129.50 | epoch 3 | iter 81 / 1327 | time 917[s] | perplexity 118.11 | epoch 3 | iter 101 / 1327 | time 924[s] | perplexity 106.62 | epoch 3 | iter 121 / 1327 | time 931[s] | perplexity 116.58 | epoch 3 | iter 141 / 1327 | time 937[s] | perplexity 128.34 | epoch 3 | iter 161 / 1327 | time 944[s] | perplexity 145.42 | epoch 3 | iter 181 / 1327 | time 951[s] | perplexity 152.55 | epoch 3 | iter 201 / 1327 | time 958[s] | perplexity 143.41 | epoch 3 | iter 221 / 1327 | time 964[s] | perplexity 142.29 | epoch 3 | iter 241 / 1327 | time 971[s] | perplexity 136.47 | epoch 3 | iter 261 / 1327 | time 978[s] | perplexity 141.66 | epoch 3 | iter 281 / 1327 | time 984[s] | perplexity 142.91 | epoch 3 | iter 301 / 1327 | time 991[s] | perplexity 125.87 | epoch 3 | iter 321 / 1327 | time 998[s] | perplexity 104.19 | epoch 3 | iter 341 / 1327 | time 1004[s] | perplexity 128.50 | epoch 3 | iter 361 / 1327 | time 1011[s] | perplexity 154.88 | epoch 3 | iter 381 / 1327 | time 1018[s] | perplexity 117.23 | epoch 3 | iter 401 / 1327 | time 1025[s] | perplexity 130.37 | epoch 3 | iter 421 / 1327 | time 1031[s] | perplexity 116.28 | epoch 3 | iter 441 / 1327 | time 1038[s] | perplexity 126.29 | epoch 3 | iter 461 / 1327 | time 1045[s] | perplexity 120.17 | epoch 3 | iter 481 / 1327 | time 1051[s] | perplexity 121.45 | epoch 3 | iter 501 / 1327 | time 1058[s] | 
perplexity 132.70 | epoch 3 | iter 521 / 1327 | time 1065[s] | perplexity 138.81 | epoch 3 | iter 541 / 1327 | time 1072[s] | perplexity 139.54 | epoch 3 | iter 561 / 1327 | time 1078[s] | perplexity 120.78 | epoch 3 | iter 581 / 1327 | time 1085[s] | perplexity 106.74 | epoch 3 | iter 601 / 1327 | time 1092[s] | perplexity 151.46 | epoch 3 | iter 621 / 1327 | time 1099[s] | perplexity 146.41 | epoch 3 | iter 641 / 1327 | time 1105[s] | perplexity 132.30 | epoch 3 | iter 661 / 1327 | time 1112[s] | perplexity 120.47 | epoch 3 | iter 681 / 1327 | time 1119[s] | perplexity 101.10 | epoch 3 | iter 701 / 1327 | time 1125[s] | perplexity 120.23 | epoch 3 | iter 721 / 1327 | time 1132[s] | perplexity 127.58 | epoch 3 | iter 741 / 1327 | time 1139[s] | perplexity 110.48 | epoch 3 | iter 761 / 1327 | time 1146[s] | perplexity 104.48 | epoch 3 | iter 781 / 1327 | time 1152[s] | perplexity 106.76 | epoch 3 | iter 801 / 1327 | time 1159[s] | perplexity 116.66 | epoch 3 | iter 821 / 1327 | time 1166[s] | perplexity 120.04 | epoch 3 | iter 841 / 1327 | time 1173[s] | perplexity 115.68 | epoch 3 | iter 861 / 1327 | time 1179[s] | perplexity 122.05 | epoch 3 | iter 881 / 1327 | time 1186[s] | perplexity 108.33 | epoch 3 | iter 901 / 1327 | time 1193[s] | perplexity 134.20 | epoch 3 | iter 921 / 1327 | time 1200[s] | perplexity 120.01 | epoch 3 | iter 941 / 1327 | time 1206[s] | perplexity 130.07 | epoch 3 | iter 961 / 1327 | time 1213[s] | perplexity 133.97 | epoch 3 | iter 981 / 1327 | time 1220[s] | perplexity 125.72 | epoch 3 | iter 1001 / 1327 | time 1227[s] | perplexity 110.70 | epoch 3 | iter 1021 / 1327 | time 1233[s] | perplexity 128.98 | epoch 3 | iter 1041 / 1327 | time 1240[s] | perplexity 121.18 | epoch 3 | iter 1061 / 1327 | time 1247[s] | perplexity 104.08 | epoch 3 | iter 1081 / 1327 | time 1254[s] | perplexity 90.11 | epoch 3 | iter 1101 / 1327 | time 1260[s] | perplexity 96.71 | epoch 3 | iter 1121 / 1327 | time 1267[s] | perplexity 123.43 | epoch 3 | iter 1141 / 
1327 | time 1274[s] | perplexity 116.54 | epoch 3 | iter 1161 / 1327 | time 1281[s] | perplexity 107.55 | epoch 3 | iter 1181 / 1327 | time 1287[s] | perplexity 113.06 | epoch 3 | iter 1201 / 1327 | time 1294[s] | perplexity 95.29 | epoch 3 | iter 1221 / 1327 | time 1301[s] | perplexity 89.38 | epoch 3 | iter 1241 / 1327 | time 1308[s] | perplexity 106.44 | epoch 3 | iter 1261 / 1327 | time 1314[s] | perplexity 106.48 | epoch 3 | iter 1281 / 1327 | time 1321[s] | perplexity 101.53 | epoch 3 | iter 1301 / 1327 | time 1328[s] | perplexity 130.48 | epoch 3 | iter 1321 / 1327 | time 1335[s] | perplexity 129.24 | epoch 4 | iter 1 / 1327 | time 1337[s] | perplexity 137.08 | epoch 4 | iter 21 / 1327 | time 1344[s] | perplexity 123.39 | epoch 4 | iter 41 / 1327 | time 1350[s] | perplexity 109.46 | epoch 4 | iter 61 / 1327 | time 1357[s] | perplexity 109.92 | epoch 4 | iter 81 / 1327 | time 1364[s] | perplexity 98.42 | epoch 4 | iter 101 / 1327 | time 1371[s] | perplexity 87.49 | epoch 4 | iter 121 / 1327 | time 1377[s] | perplexity 96.47 | epoch 4 | iter 141 / 1327 | time 1384[s] | perplexity 104.20 | epoch 4 | iter 161 / 1327 | time 1391[s] | perplexity 120.35 | epoch 4 | iter 181 / 1327 | time 1397[s] | perplexity 131.39 | epoch 4 | iter 201 / 1327 | time 1404[s] | perplexity 122.30 | epoch 4 | iter 221 / 1327 | time 1411[s] | perplexity 124.78 | epoch 4 | iter 241 / 1327 | time 1418[s] | perplexity 117.55 | epoch 4 | iter 261 / 1327 | time 1424[s] | perplexity 116.76 | epoch 4 | iter 281 / 1327 | time 1431[s] | perplexity 122.54 | epoch 4 | iter 301 / 1327 | time 1438[s] | perplexity 105.78 | epoch 4 | iter 321 / 1327 | time 1444[s] | perplexity 85.61 | epoch 4 | iter 341 / 1327 | time 1451[s] | perplexity 103.02 | epoch 4 | iter 361 / 1327 | time 1458[s] | perplexity 128.83 | epoch 4 | iter 381 / 1327 | time 1464[s] | perplexity 99.55 | epoch 4 | iter 401 / 1327 | time 1471[s] | perplexity 111.82 | epoch 4 | iter 421 / 1327 | time 1478[s] | perplexity 95.68 | epoch 4 | 
iter 441 / 1327 | time 1485[s] | perplexity 104.78 | epoch 4 | iter 461 / 1327 | time 1492[s] | perplexity 102.15 | epoch 4 | iter 481 / 1327 | time 1499[s] | perplexity 104.22 | epoch 4 | iter 501 / 1327 | time 1506[s] | perplexity 111.26 | epoch 4 | iter 521 / 1327 | time 1512[s] | perplexity 119.01 | epoch 4 | iter 541 / 1327 | time 1519[s] | perplexity 115.68 | epoch 4 | iter 561 / 1327 | time 1526[s] | perplexity 105.81 | epoch 4 | iter 581 / 1327 | time 1533[s] | perplexity 89.73 | epoch 4 | iter 601 / 1327 | time 1540[s] | perplexity 127.60 | epoch 4 | iter 621 / 1327 | time 1547[s] | perplexity 124.38 | epoch 4 | iter 641 / 1327 | time 1554[s] | perplexity 112.94 | epoch 4 | iter 661 / 1327 | time 1560[s] | perplexity 104.02 | epoch 4 | iter 681 / 1327 | time 1567[s] | perplexity 85.23 | epoch 4 | iter 701 / 1327 | time 1574[s] | perplexity 104.08 | epoch 4 | iter 721 / 1327 | time 1581[s] | perplexity 108.60 | epoch 4 | iter 741 / 1327 | time 1587[s] | perplexity 98.33 | epoch 4 | iter 761 / 1327 | time 1594[s] | perplexity 89.41 | epoch 4 | iter 781 / 1327 | time 1601[s] | perplexity 88.64 | epoch 4 | iter 801 / 1327 | time 1608[s] | perplexity 99.77 | epoch 4 | iter 821 / 1327 | time 1614[s] | perplexity 105.46 | epoch 4 | iter 841 / 1327 | time 1621[s] | perplexity 100.29 | epoch 4 | iter 861 / 1327 | time 1628[s] | perplexity 106.53 | epoch 4 | iter 881 / 1327 | time 1634[s] | perplexity 92.96 | epoch 4 | iter 901 / 1327 | time 1641[s] | perplexity 118.05 | epoch 4 | iter 921 / 1327 | time 1648[s] | perplexity 103.90 | epoch 4 | iter 941 / 1327 | time 1654[s] | perplexity 113.97 | epoch 4 | iter 961 / 1327 | time 1661[s] | perplexity 114.04 | epoch 4 | iter 981 / 1327 | time 1668[s] | perplexity 108.00 | epoch 4 | iter 1001 / 1327 | time 1675[s] | perplexity 98.89 | epoch 4 | iter 1021 / 1327 | time 1681[s] | perplexity 113.60 | epoch 4 | iter 1041 / 1327 | time 1688[s] | perplexity 106.50 | epoch 4 | iter 1061 / 1327 | time 1695[s] | perplexity 89.44 
| epoch 4 | iter 1081 / 1327 | time 1702[s] | perplexity 79.99 | epoch 4 | iter 1101 / 1327 | time 1708[s] | perplexity 81.84 | epoch 4 | iter 1121 / 1327 | time 1715[s] | perplexity 104.83 | epoch 4 | iter 1141 / 1327 | time 1722[s] | perplexity 102.74 | epoch 4 | iter 1161 / 1327 | time 1728[s] | perplexity 93.97 | epoch 4 | iter 1181 / 1327 | time 1735[s] | perplexity 98.14 | epoch 4 | iter 1201 / 1327 | time 1742[s] | perplexity 84.93 | epoch 4 | iter 1221 / 1327 | time 1749[s] | perplexity 76.64 | epoch 4 | iter 1241 / 1327 | time 1755[s] | perplexity 93.11 | epoch 4 | iter 1261 / 1327 | time 1762[s] | perplexity 95.51 | epoch 4 | iter 1281 / 1327 | time 1769[s] | perplexity 90.25 | epoch 4 | iter 1301 / 1327 | time 1776[s] | perplexity 112.19 | epoch 4 | iter 1321 / 1327 | time 1782[s] | perplexity 111.55 evaluating perplexity ... 0 / 235 1 / 235 2 / 235 3 / 235 4 / 235 5 / 235 6 / 235 7 / 235 8 / 235 9 / 235 10 / 235 11 / 235 12 / 235 13 / 235 14 / 235 15 / 235 16 / 235 17 / 235 18 / 235 19 / 235 20 / 235 21 / 235 22 / 235 23 / 235 24 / 235 25 / 235 26 / 235 27 / 235 28 / 235 29 / 235 30 / 235 31 / 235 32 / 235 33 / 235 34 / 235 35 / 235 36 / 235 37 / 235 38 / 235 39 / 235 40 / 235 41 / 235 42 / 235 43 / 235 44 / 235 45 / 235 46 / 235 47 / 235 48 / 235 49 / 235 50 / 235 51 / 235 52 / 235 53 / 235 54 / 235 55 / 235 56 / 235 57 / 235 58 / 235 59 / 235 60 / 235 61 / 235 62 / 235 63 / 235 64 / 235 65 / 235 66 / 235 67 / 235 68 / 235 69 / 235 70 / 235 71 / 235 72 / 235 73 / 235 74 / 235 75 / 235 76 / 235 77 / 235 78 / 235 79 / 235 80 / 235 81 / 235 82 / 235 83 / 235 84 / 235 85 / 235 86 / 235 87 / 235 88 / 235 89 / 235 90 / 235 91 / 235 92 / 235 93 / 235 94 / 235 95 / 235 96 / 235 97 / 235 98 / 235 99 / 235 100 / 235 101 / 235 102 / 235 103 / 235 104 / 235 105 / 235 106 / 235 107 / 235 108 / 235 109 / 235 110 / 235 111 / 235 112 / 235 113 / 235 114 / 235 115 / 235 116 / 235 117 / 235 118 / 235 119 / 235 120 / 235 121 / 235 122 / 235 123 / 235 124 / 235 125 / 235 
126 / 235 127 / 235 128 / 235 129 / 235 130 / 235 131 / 235 132 / 235 133 / 235 134 / 235 135 / 235 136 / 235 137 / 235 138 / 235 139 / 235 140 / 235 141 / 235 142 / 235 143 / 235 144 / 235 145 / 235 146 / 235 147 / 235 148 / 235 149 / 235 150 / 235 151 / 235 152 / 235 153 / 235 154 / 235 155 / 235 156 / 235 157 / 235 158 / 235 159 / 235 160 / 235 161 / 235 162 / 235 163 / 235 164 / 235 165 / 235 166 / 235 167 / 235 168 / 235 169 / 235 170 / 235 171 / 235 172 / 235 173 / 235 174 / 235 175 / 235 176 / 235 177 / 235 178 / 235 179 / 235 180 / 235 181 / 235 182 / 235 183 / 235 184 / 235 185 / 235 186 / 235 187 / 235 188 / 235 189 / 235 190 / 235 191 / 235 192 / 235 193 / 235 194 / 235 195 / 235 196 / 235 197 / 235 198 / 235 199 / 235 200 / 235 201 / 235 202 / 235 203 / 235 204 / 235 205 / 235 206 / 235 207 / 235 208 / 235 209 / 235 210 / 235 211 / 235 212 / 235 213 / 235 214 / 235 215 / 235 216 / 235 217 / 235 218 / 235 219 / 235 220 / 235 221 / 235 222 / 235 223 / 235 224 / 235 225 / 235 226 / 235 227 / 235 228 / 235 229 / 235 230 / 235 231 / 235 232 / 235 233 / 235 234 / 235 test perplexity: 137.9052746711987
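As a sanity check on the very first log line: perplexity is the exponential of the average cross-entropy loss, so an untrained model that spreads probability roughly uniformly over the 10,000-word PTB vocabulary should start near 10000, which matches the `perplexity 10000.10` at iteration 1. A quick numerical check:

```python
import numpy as np

# Perplexity = exp(average cross-entropy). A uniform distribution over a
# 10,000-word vocabulary assigns probability 1/10000 to the correct word,
# so the expected starting perplexity is exp(-log(1/10000)) = 10000.
vocab_size = 10000
loss = -np.log(1.0 / vocab_size)   # cross-entropy of the uniform model
print(np.exp(loss))                # ≈ 10000.0
```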
%sh top -b -n1
top - 11:21:03 up 14 min,  1 user,  load average: 7.33, 3.84, 1.59
Tasks: 147 total,   2 running,  92 sleeping,   0 stopped,   0 zombie
Cpu(s): 11.4%us, 11.9%sy,  0.0%ni, 76.1%id,  0.1%wa,  0.0%hi,  0.0%si,  0.5%st
Mem:  15394144k total,  3463012k used, 11931132k free,    34120k buffers
Swap:        0k total,        0k used,        0k free,  2005276k cached

  PID USER      PR  NI  VIRT  RES  SHR S  %CPU %MEM    TIME+  COMMAND
 3676 ec2-user  20   0 1058m 408m  18m R 655.4  2.7  25:58.79 python3
 2604 root      20   0  6520   96    0 S   2.0  0.0   0:00.22 rngd
 3403 ec2-user  20   0 6695m 557m  25m S   2.0  3.7   0:31.42 java
 3483 ec2-user  20   0 4869m  91m  18m S   2.0  0.6   0:02.16 java
 3634 ec2-user  20   0 5195m 115m  18m S   2.0  0.8   0:02.97 java
 3782 ec2-user  20   0 15360 2292 1960 R   2.0  0.0   0:00.01 top
... (remaining idle kernel threads and system daemons omitted)
6.5.4 Implementing a Better RNNLM
- Stack multiple LSTM layers
- Apply Dropout
- Tie weights (share the embedding matrix with the output layer)
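The third improvement, weight tying, reuses the embedding matrix (transposed) as the weight of the output Affine layer, which is what `TimeAffine(embed_W.T, affine_b)` does in the model. A minimal sketch of the idea with toy sizes:

```python
import numpy as np

V, D = 8, 4  # toy vocabulary and embedding sizes
embed_W = np.random.randn(V, D).astype('f')

# Input side: embedding is a row lookup in embed_W.
word_vec = embed_W[3]                    # shape (D,)

# Output side: project the hidden state back onto the vocabulary with the
# SAME matrix, transposed -- no separate output weight is allocated.
h = np.random.randn(1, D).astype('f')    # hidden state from the LSTM stack
score = h @ embed_W.T                    # shape (1, V)

# .T is a view, so both "layers" really share one parameter array.
print(score.shape, np.shares_memory(embed_W, embed_W.T))  # (1, 8) True
```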
%python3
import sys
sys.path.append('/tmp/deep-learning-from-scratch-2')
from common.time_layers import TimeEmbedding, TimeLSTM, TimeDropout, TimeAffine, TimeSoftmaxWithLoss
from common.np import *
from common.base_model import BaseModel


class BetterRnnlm(BaseModel):
    def __init__(self, vocab_size=10000, wordvec_size=650,
                 hidden_size=650, dropout_ratio=0.5):
        V, D, H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn

        embed_W = (rn(V, D) / 100).astype('f')
        lstm_Wx1 = (rn(D, 4*H) / np.sqrt(D)).astype('f')
        lstm_Wh1 = (rn(H, 4*H) / np.sqrt(H)).astype('f')
        lstm_b1 = np.zeros(4*H).astype('f')
        lstm_Wx2 = (rn(D, 4*H) / np.sqrt(D)).astype('f')
        lstm_Wh2 = (rn(H, 4*H) / np.sqrt(H)).astype('f')
        lstm_b2 = np.zeros(4*H).astype('f')
        affine_b = np.zeros(V).astype('f')

        self.layers = [
            TimeEmbedding(embed_W),
            TimeDropout(dropout_ratio),
            TimeLSTM(lstm_Wx1, lstm_Wh1, lstm_b1, stateful=True),
            TimeDropout(dropout_ratio),
            TimeLSTM(lstm_Wx2, lstm_Wh2, lstm_b2, stateful=True),
            TimeDropout(dropout_ratio),
            TimeAffine(embed_W.T, affine_b)  # weight tying
        ]
        self.loss_layer = TimeSoftmaxWithLoss()
        self.lstm_layers = [self.layers[2], self.layers[4]]
        self.drop_layers = [self.layers[1], self.layers[3], self.layers[5]]

        self.params, self.grads = [], []
        for layer in self.layers:
            self.params += layer.params
            self.grads += layer.grads

    def predict(self, xs, train_flg=False):
        for layer in self.drop_layers:
            layer.train_flg = train_flg
        for layer in self.layers:
            xs = layer.forward(xs)
        return xs

    def forward(self, xs, ts, train_flg=True):
        score = self.predict(xs, train_flg)
        loss = self.loss_layer.forward(score, ts)
        return loss

    def backward(self, dout=1):
        dout = self.loss_layer.backward(dout)
        for layer in reversed(self.layers):
            dout = layer.backward(dout)
        return dout

    def reset_state(self):
        for layer in self.lstm_layers:
            layer.reset_state()
%python3
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from dataset import ptb

batch_size = 20
wordvec_size = 650
hidden_size = 650
time_size = 35
lr = 20.0
max_epoch = 40
max_grad = 0.25
dropout = 0.5

corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_val, _, _ = ptb.load_data('val')
corpus_test, _, _ = ptb.load_data('test')

vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]

model = BetterRnnlm(vocab_size, wordvec_size, hidden_size, dropout)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

best_ppl = float('inf')
for epoch in range(max_epoch):
    trainer.fit(xs, ts, max_epoch=1, batch_size=batch_size,
                time_size=time_size, max_grad=max_grad)
    model.reset_state()
    ppl = eval_perplexity(model, corpus_val)
    print('valid perplexity: ', ppl)

    if best_ppl > ppl:
        best_ppl = ppl
        model.save_params()
    else:
        lr /= 4.0
        optimizer.lr = lr

    model.reset_state()
    print('-' * 50)
| epoch 1 | iter 1 / 1327 | time 1[s] | perplexity 10000.27
| epoch 1 | iter 21 / 1327 | time 30[s] | perplexity 3663.54
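The training loop above anneals the learning rate: while validation perplexity keeps improving, the lr stays at 20.0 (and the parameters are saved); each time it fails to improve, the lr is divided by 4. A toy trace of that schedule with made-up perplexity values:

```python
# Toy trace of the lr schedule used in the training loop above
# (the perplexity sequence here is hypothetical, for illustration only).
lr = 20.0
best_ppl = float('inf')
for ppl in [300.0, 200.0, 150.0, 160.0, 140.0, 145.0]:
    if best_ppl > ppl:
        best_ppl = ppl      # improved: keep the current lr (and save params)
    else:
        lr /= 4.0           # no improvement: anneal
print(lr, best_ppl)  # 1.25 140.0
```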
With more layers, would a GPU be the better choice?
CUDA: checking the versions
%sh nvidia-smi
nvcc --version
Sun Nov 18 12:35:56 2018
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:00:1B.0 Off |                    0 |
| N/A   34C    P0    25W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:00:1C.0 Off |                    0 |
| N/A   33C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000000:00:1D.0 Off |                    0 |
| N/A   32C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
| N/A   33C    P0    27W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2017 NVIDIA Corporation
Built on Fri_Sep__1_21:08:03_CDT_2017
Cuda compilation tools, release 9.0, V9.0.176
Installing cupy
%sh sudo pip-3.6 install cupy-cuda90
Collecting cupy-cuda90 Downloading https://files.pythonhosted.org/packages/f7/46/0910fb6901fec52d4a77ff36378c82103dabc676b5f50d334e3784fd321f/cupy_cuda90-5.0.0-cp36-cp36m-manylinux1_x86_64.whl (262.7MB) Collecting fastrlock>=0.3 (from cupy-cuda90) Downloading https://files.pythonhosted.org/packages/b5/93/a7efbd39eac46c137500b37570c31dedc2d31a8ff4949fcb90bda5bc5f16/fastrlock-0.4-cp36-cp36m-manylinux1_x86_64.whl Requirement already satisfied: numpy>=1.9.0 in /usr/lib64/python3.6/dist-packages (from cupy-cuda90) Requirement already satisfied: six>=1.9.0 in /usr/lib/python3.6/dist-packages (from cupy-cuda90) Installing collected packages: fastrlock, cupy-cuda90 Successfully installed cupy-cuda90-5.0.0 fastrlock-0.4 You are using pip version 9.0.3, however version 18.1 is available. You should consider upgrading via the 'pip install --upgrade pip' command.
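With cupy installed, the reference implementation can run on the GPU: its `common/np.py` imports either numpy or cupy as `np` depending on a flag in `common/config.py`. A self-contained sketch of that backend-switch pattern (standalone here, not the repo's actual file; the flag name mirrors the repo):

```python
# Sketch of the numpy/cupy backend switch used by the reference repo's
# common/np.py (reproduced here as a standalone snippet). Flip GPU to True
# on a machine with CUDA and cupy available; the book's code only relies on
# operations that cupy implements with the same API as numpy.
GPU = False

if GPU:
    import cupy as np
else:
    import numpy as np

x = np.arange(6, dtype='f').reshape(2, 3)
print(float(x.sum()))  # 15.0 with either backend
```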
%python3
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from dataset import ptb

batch_size = 20
wordvec_size = 650
hidden_size = 650
time_size = 35
lr = 20.0
max_epoch = 40
max_grad = 0.25
dropout = 0.5

corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_val, _, _ = ptb.load_data('val')
corpus_test, _, _ = ptb.load_data('test')

vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]

model = BetterRnnlm(vocab_size, wordvec_size, hidden_size, dropout)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

best_ppl = float('inf')
for epoch in range(max_epoch):
    trainer.fit(xs, ts, max_epoch=1, batch_size=batch_size,
                time_size=time_size, max_grad=max_grad)
    model.reset_state()
    ppl = eval_perplexity(model, corpus_val)
    print('valid perplexity: ', ppl)

    if best_ppl > ppl:
        best_ppl = ppl
        model.save_params()
    else:
        lr /= 4.0
        optimizer.lr = lr

    model.reset_state()
    print('-' * 50)
%sh nvidia-smi
Sun Nov 18 12:48:46 2018
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:00:1B.0 Off |                    0 |
| N/A   37C    P0    58W / 300W |   2821MiB / 16160MiB |     10%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:00:1C.0 Off |                    0 |
| N/A   33C    P0    36W / 300W |     11MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000000:00:1D.0 Off |                    0 |
| N/A   31C    P0    38W / 300W |     11MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
| N/A   32C    P0    42W / 300W |     11MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      5761      C   python3                                    2810MiB  |
+-----------------------------------------------------------------------------+
- Got outbid on the spot instance, so the run was killed partway through 🙃
- Multiple GPUs don't seem necessary for this workload
- There is also plenty of GPU memory to spare