Deep Learning from Scratch 2 / LSTM


Chapter 6: Gated RNNs

These are my reading notes for Deep Learning from Scratch (2). In Chapter 5 we implemented a language model with a simple RNN. In Chapter 6 we add a mechanism called a gate to the RNN — the LSTM — and build a network that can learn long-range dependencies.

Reference implementation

%sh
rm -rf /tmp/deep-learning-from-scratch-2
git clone https://github.com/oreilly-japan/deep-learning-from-scratch-2 /tmp/deep-learning-from-scratch-2
Cloning into '/tmp/deep-learning-from-scratch-2'...

Installing numpy and matplotlib

%sh
pip3 install numpy matplotlib
Collecting numpy
  Downloading https://files.pythonhosted.org/packages/86/04/bd774106ae0ae1ada68c67efe89f1a16b2aa373cc2db15d974002a9f136d/numpy-1.15.4-cp35-cp35m-manylinux1_x86_64.whl (13.8MB)
Collecting matplotlib
  Downloading https://files.pythonhosted.org/packages/ad/4c/0415f15f96864c3a2242b1c74041a806c100c1b21741206c5d87684437c6/matplotlib-3.0.2-cp35-cp35m-manylinux1_x86_64.whl (12.9MB)
Collecting pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 (from matplotlib)
  Downloading https://files.pythonhosted.org/packages/71/e8/6777f6624681c8b9701a8a0a5654f3eb56919a01a78e12bf3c73f5a3c714/pyparsing-2.3.0-py2.py3-none-any.whl (59kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached https://files.pythonhosted.org/packages/f7/d2/e07d3ebb2bd7af696440ce7e754c59dd546ffe1bbe732c8ab68b9c834e61/cycler-0.10.0-py2.py3-none-any.whl
Collecting python-dateutil>=2.1 (from matplotlib)
  Downloading https://files.pythonhosted.org/packages/74/68/d87d9b36af36f44254a8d512cbfc48369103a3b9e474be9bdfe536abfc45/python_dateutil-2.7.5-py2.py3-none-any.whl (225kB)
Collecting kiwisolver>=1.0.1 (from matplotlib)
  Downloading https://files.pythonhosted.org/packages/7e/31/d6fedd4fb2c94755cd101191e581af30e1650ccce7a35bddb7930fed6574/kiwisolver-1.0.1-cp35-cp35m-manylinux1_x86_64.whl (949kB)
Collecting six (from cycler>=0.10->matplotlib)
  Downloading https://files.pythonhosted.org/packages/67/4b/141a581104b1f6397bfa78ac9d43d8ad29a7ca43ea90a2d863fe3056e86a/six-1.11.0-py2.py3-none-any.whl
Requirement already satisfied (use --upgrade to upgrade): setuptools in /usr/lib/python3/dist-packages (from kiwisolver>=1.0.1->matplotlib)
Installing collected packages: numpy, pyparsing, six, cycler, python-dateutil, kiwisolver, matplotlib
Successfully installed cycler-0.10.0 kiwisolver-1.0.1 matplotlib-3.0.2 numpy-1.15.4 pyparsing-2.3.0 python-dateutil-2.7.5 six-1.11.0
You are using pip version 8.1.1, however version 18.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.

%sh
pip3 install seaborn
Collecting seaborn
  Downloading https://files.pythonhosted.org/packages/a8/76/220ba4420459d9c4c9c9587c6ce607bf56c25b3d3d2de62056efe482dadc/seaborn-0.9.0-py3-none-any.whl (208kB)
Collecting scipy>=0.14.0 (from seaborn)
  Downloading https://files.pythonhosted.org/packages/cd/32/5196b64476bd41d596a8aba43506e2403e019c90e1a3dfc21d51b83db5a6/scipy-1.1.0-cp35-cp35m-manylinux1_x86_64.whl (33.1MB)
Collecting pandas>=0.15.2 (from seaborn)
  Downloading https://files.pythonhosted.org/packages/5d/d4/6e9c56a561f1d27407bf29318ca43f36ccaa289271b805a30034eb3a8ec4/pandas-0.23.4-cp35-cp35m-manylinux1_x86_64.whl (8.7MB)
Requirement already satisfied (use --upgrade to upgrade): matplotlib>=1.4.3 in /usr/local/lib/python3.5/dist-packages (from seaborn)
Requirement already satisfied (use --upgrade to upgrade): numpy>=1.9.3 in /usr/local/lib/python3.5/dist-packages (from seaborn)
Collecting pytz>=2011k (from pandas>=0.15.2->seaborn)
  Downloading https://files.pythonhosted.org/packages/f8/0e/2365ddc010afb3d79147f1dd544e5ee24bf4ece58ab99b16fbb465ce6dc0/pytz-2018.7-py2.py3-none-any.whl (506kB)
Requirement already satisfied (use --upgrade to upgrade): python-dateutil>=2.5.0 in /usr/local/lib/python3.5/dist-packages (from pandas>=0.15.2->seaborn)
Requirement already satisfied (use --upgrade to upgrade): pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.5/dist-packages (from matplotlib>=1.4.3->seaborn)
Requirement already satisfied (use --upgrade to upgrade): kiwisolver>=1.0.1 in /usr/local/lib/python3.5/dist-packages (from matplotlib>=1.4.3->seaborn)
Requirement already satisfied (use --upgrade to upgrade): cycler>=0.10 in /usr/local/lib/python3.5/dist-packages (from matplotlib>=1.4.3->seaborn)
Requirement already satisfied (use --upgrade to upgrade): six>=1.5 in /usr/local/lib/python3.5/dist-packages (from python-dateutil>=2.5.0->pandas>=0.15.2->seaborn)
Requirement already satisfied (use --upgrade to upgrade): setuptools in /usr/lib/python3/dist-packages (from kiwisolver>=1.0.1->matplotlib>=1.4.3->seaborn)
Installing collected packages: scipy, pytz, pandas, seaborn
Successfully installed pandas-0.23.4 pytz-2018.7 scipy-1.1.0 seaborn-0.9.0
You are using pip version 8.1.1, however version 18.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.

6.1.3: Causes of vanishing and exploding gradients

%python3
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def plot_tanh(x):
    y = np.tanh(x)
    plt.plot(x, y, label='tanh')

def plot_dydx(x):
    y = 1 / (np.cosh(x)**2)
    plt.plot(x, y, '--', label='dy/dx')

sns.set()
x = np.arange(-5, 5, 0.2)
plt.xlabel('x')
plt.ylabel('y')
plot_tanh(x)
plot_dydx(x)

plt.legend(loc='lower right', fontsize=18, borderaxespad=1)
<matplotlib.legend.Legend object at 0x7f8c13198ba8>

  • The derivative of \(y=\tanh(x)\) shrinks as \(x\) moves away from 0, so each time backpropagation passes through a tanh node the emitted gradient gets smaller
  • Using ReLU instead is said to mitigate vanishing gradients
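As a quick sanity check (my own snippet, not from the book): \(\tanh'(x) = 1/\cosh^2(x) = 1-\tanh^2(x)\) peaks at 1 and decays away from 0, while ReLU's derivative is exactly 1 for every positive input, so it never shrinks the gradient it passes through:

```python
import numpy as np

x = np.linspace(-5, 5, 101)

# tanh'(x) = 1 / cosh(x)^2 = 1 - tanh(x)^2; its maximum is 1, at x = 0
dtanh = 1 / np.cosh(x) ** 2
assert np.allclose(dtanh, 1 - np.tanh(x) ** 2)
print(dtanh.max())  # 1.0, and strictly below 1 everywhere else

# ReLU'(x) is exactly 1 for positive inputs, so repeated backprop
# through ReLU nodes does not shrink the gradient
drelu = (x > 0).astype(float)
print(drelu[x > 0].min())  # 1.0
```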

%python3
N = 2
H = 3
T = 20

dh = np.ones((N, H))
np.random.seed(3)
Wh = np.random.randn(H, H)

norm_list = []
for t in range(T):
    dh = np.dot(dh, Wh.T)
    norm = np.sqrt(np.sum(dh**2)) / N
    norm_list.append(norm)

sns.set()
plt.xlabel('time step')
plt.ylabel('norm')
plt.plot(norm_list)
[<matplotlib.lines.Line2D object at 0x7f8bfca6c5f8>]

%python3
N = 2
H = 3
T = 20

dh = np.ones((N, H))
np.random.seed(3)
Wh = np.random.randn(H, H) * 0.5

norm_list = []
for t in range(T):
    dh = np.dot(dh, Wh.T)
    norm = np.sqrt(np.sum(dh**2)) / N
    norm_list.append(norm)

sns.set()
plt.xlabel('time step')
plt.ylabel('norm')
plt.plot(norm_list)
[<matplotlib.lines.Line2D object at 0x7f8c11092e80>]

  • Repeating the computation through a MatMul node in the same way likewise causes exploding or vanishing gradients
  • Countermeasure against exploding gradients
    • Gradient clipping
    • Rescale the gradients whenever their L2 norm exceeds a threshold
  • Countermeasure against vanishing gradients
    • Introduce gates
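The clipping step can be sketched as follows (my own sketch, modeled on the book's `common/util.py` `clip_grads`; the threshold value here is arbitrary):

```python
import numpy as np

def clip_grads(grads, max_norm):
    # global L2 norm across all gradient arrays
    total_norm = np.sqrt(sum((g ** 2).sum() for g in grads))
    # if it exceeds the threshold, rescale every gradient by the same factor
    rate = max_norm / (total_norm + 1e-6)
    if rate < 1:
        for g in grads:
            g *= rate

grads = [np.full((2, 2), 10.0)]         # L2 norm = 20
clip_grads(grads, max_norm=5.0)
print(np.sqrt((grads[0] ** 2).sum()))   # ~5.0
```

Because all gradients are scaled by one shared factor, clipping preserves the direction of the overall update and only caps its magnitude.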

Gates

  • A gate adjusts, per element, how much that element should matter for the next state
  • The output gate

$$
o = \sigma(x_t W_{x}^{(o)} + h_{t-1} W_{h}^{(o)} + b^{(o)})
$$

  • The output is the elementwise (Hadamard) product of the output gate and the tanh node
  • Each gate has its own weights
  • When computing with matrices, the per-gate operations can be combined into one
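The last point can be made concrete: all four gates read the same \(x_t\) and \(h_{t-1}\), so stacking their weight matrices side by side lets a single matmul produce all four pre-activations at once (the shapes below are illustrative):

```python
import numpy as np

np.random.seed(0)
N, D, H = 2, 3, 4  # batch size, input size, hidden size (illustrative)
x = np.random.randn(N, D)

# one weight matrix per gate ...
Wf, Wg, Wi, Wo = (np.random.randn(D, H) for _ in range(4))
# ... concatenated into a single (D, 4H) matrix
Wx = np.hstack([Wf, Wg, Wi, Wo])

A = np.dot(x, Wx)  # one matmul yields all four gates' pre-activations
# slicing recovers exactly what four separate matmuls would give
assert np.allclose(A[:, :H], np.dot(x, Wf))
assert np.allclose(A[:, 3*H:], np.dot(x, Wo))
print(A.shape)  # (2, 16)
```

This is exactly the layout the LSTM implementation below assumes when it slices `A[:, :H]`, `A[:, H:2*H]`, and so on.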

LSTM layer: implementation

%python3
import sys
sys.path.append('/tmp/deep-learning-from-scratch-2')

from common import config
config.GPU = True
from common.np import *
from common.functions import sigmoid


class LSTM:
    def __init__(self, Wx, Wh, b):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.cache = None

    def forward(self, x, h_prev, c_prev):
        Wx, Wh, b = self.params
        N, H = h_prev.shape

        A = np.dot(x, Wx) + np.dot(h_prev, Wh) + b

        f = A[:, :H]
        g = A[:, H:2*H]
        i = A[:, 2*H:3*H]
        o = A[:, 3*H:]

        f = sigmoid(f)
        g = np.tanh(g)
        i = sigmoid(i)
        o = sigmoid(o)

        c_next = f * c_prev + g * i
        h_next = o * np.tanh(c_next)

        self.cache = (x, h_prev, c_prev, i, f, g, o, c_next)
        return h_next, c_next

    def backward(self, dh_next, dc_next):
        Wx, Wh, b = self.params
        x, h_prev, c_prev, i, f, g, o, c_next = self.cache

        tanh_c_next = np.tanh(c_next)

        ds = dc_next + (dh_next * o) * (1 - tanh_c_next ** 2)

        dc_prev = ds * f

        di = ds * g
        df = ds * c_prev
        do = dh_next * tanh_c_next
        dg = ds * i

        di *= i * (1 - i)
        df *= f * (1 - f)
        do *= o * (1 - o)
        dg *= (1 - g ** 2)

        dA = np.hstack((df, dg, di, do))

        dWh = np.dot(h_prev.T, dA)
        dWx = np.dot(x.T, dA)
        db = dA.sum(axis=0)

        self.grads[0][...] = dWx
        self.grads[1][...] = dWh
        self.grads[2][...] = db

        dx = np.dot(dA, Wx.T)
        dh_prev = np.dot(dA, Wh.T)

        return dx, dh_prev, dc_prev
------------------------------------------------------------
                       GPU Mode (cupy)
------------------------------------------------------------
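Note why this helps with vanishing gradients: unlike the `Wh` matmul in the experiment above, the backward path along the cell state involves no matrix product at all. From `c_next = f * c_prev + g * i` we get `dc_prev = ds * f`, so the gradient is only scaled elementwise by the forget gate, a sigmoid output in (0, 1) that the network learns — there is no fixed `Wh` being applied over and over. A small demo of my own, with arbitrary random gate values:

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

np.random.seed(0)
N, H, T = 2, 3, 20
dc = np.ones((N, H))

norms = []
for t in range(T):
    f = sigmoid(np.random.randn(N, H))  # forget gate output, each entry in (0, 1)
    dc = dc * f                         # elementwise -- no matrix product here
    norms.append(np.sqrt((dc ** 2).sum()) / N)

# scaled by values in (0, 1), the norm can only shrink; it never explodes,
# and elements whose forget gate stays near 1 keep their gradient intact
print(norms[0], norms[-1])
```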

TimeLSTM layer: implementation

%python3
class TimeLSTM:
    def __init__(self, Wx, Wh, b, stateful=False):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.layers = None

        self.h, self.c = None, None
        self.dh = None
        self.stateful = stateful

    def forward(self, xs):
        Wx, Wh, b = self.params
        N, T, D = xs.shape
        H = Wh.shape[0]

        self.layers = []
        hs = np.empty((N, T, H), dtype='f')

        if not self.stateful or self.h is None:
            self.h = np.zeros((N, H), dtype='f')
        if not self.stateful or self.c is None:
            self.c = np.zeros((N, H), dtype='f')

        for t in range(T):
            layer = LSTM(*self.params)
            self.h, self.c = layer.forward(xs[:, t, :], self.h, self.c)
            hs[:, t, :] = self.h

            self.layers.append(layer)

        return hs

    def backward(self, dhs):
        Wx, Wh, b = self.params
        N, T, H = dhs.shape
        D = Wx.shape[0]

        dxs = np.empty((N, T, D), dtype='f')
        dh, dc = 0, 0

        grads = [0, 0, 0]
        for t in reversed(range(T)):
            layer = self.layers[t]
            dx, dh, dc = layer.backward(dhs[:, t, :] + dh, dc)
            dxs[:, t, :] = dx
            for i, grad in enumerate(layer.grads):
                grads[i] += grad

        for i, grad in enumerate(grads):
            self.grads[i][...] = grad
        self.dh = dh
        return dxs

    def set_state(self, h, c=None):
        self.h, self.c = h, c

    def reset_state(self):
        self.h, self.c = None, None

RNNLM: implementation

%python3
import pickle
from common.time_layers import TimeSoftmaxWithLoss, TimeEmbedding, TimeAffine

class Rnnlm:
    def __init__(self, vocab_size=10000, wordvec_size=100, hidden_size=100):
        V, D, H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn
        
        embed_W = (rn(V, D) / 100).astype('f')
        lstm_Wx = (rn(D, 4 * H) / np.sqrt(D)).astype('f')
        lstm_Wh = (rn(H, 4 * H) / np.sqrt(H)).astype('f')
        lstm_b = np.zeros(4 * H).astype('f')
        affine_W = (rn(H, V) / np.sqrt(H)).astype('f')
        affine_b = np.zeros(V).astype('f')
        
        self.layers = [
            TimeEmbedding(embed_W),
            TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful=True),
            TimeAffine(affine_W, affine_b)
        ]
        self.loss_layer = TimeSoftmaxWithLoss()
        self.lstm_layer = self.layers[1]
        
        self.params, self.grads = [], []
        for layer in self.layers:
            self.params += layer.params
            self.grads += layer.grads

    def predict(self, xs):
        for layer in self.layers:
            xs = layer.forward(xs)
        return xs
    
    def forward(self, xs, ts):
        score = self.predict(xs)
        loss = self.loss_layer.forward(score, ts)
        return loss
    
    def backward(self, dout=1):
        dout = self.loss_layer.backward(dout)
        for layer in reversed(self.layers):
            dout = layer.backward(dout)
        return dout
    
    def reset_state(self):
        self.lstm_layer.reset_state()
    
    def save_params(self, file_name='Rnnlm.pkl'):
        with open(file_name, 'wb') as f:
            pickle.dump(self.params, f)
    
    def load_params(self, file_name='Rnnlm.pkl'):
        with open(file_name, 'rb') as f:
            params = pickle.load(f)
        # copy in place: the layers hold references to these arrays,
        # so rebinding self.params alone would not update their weights
        for i, param in enumerate(params):
            self.params[i][...] = param

Using CUDA

%sh
nvidia-smi
Thu Nov 15 13:21:21 2018       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla K80           On   | 00000000:00:1E.0 Off |                    0 |
| N/A   56C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

%sh
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2017 NVIDIA Corporation
Built on Fri_Sep__1_21:08:03_CDT_2017
Cuda compilation tools, release 9.0, V9.0.176

%sh
sudo pip-3.6 install cupy-cuda90
Collecting cupy-cuda90
  Downloading https://files.pythonhosted.org/packages/f7/46/0910fb6901fec52d4a77ff36378c82103dabc676b5f50d334e3784fd321f/cupy_cuda90-5.0.0-cp36-cp36m-manylinux1_x86_64.whl (262.7MB)
Requirement already satisfied: numpy>=1.9.0 in /usr/lib64/python3.6/dist-packages (from cupy-cuda90)
Requirement already satisfied: fastrlock>=0.3 in /usr/local/lib64/python3.6/site-packages (from cupy-cuda90)
Requirement already satisfied: six>=1.9.0 in /usr/lib/python3.6/dist-packages (from cupy-cuda90)
Installing collected packages: cupy-cuda90
Successfully installed cupy-cuda90-5.0.0
You are using pip version 9.0.3, however version 18.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.

Running the training

%python3

from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from dataset import ptb

# hyper parameters
batch_size = 20
wordvec_size = 100
hidden_size = 100
time_size = 35
lr = 20.0
max_epoch = 4
max_grad = 0.25

# load dataset
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_test, _, _ = ptb.load_data('test')
vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]

# generate model
model = Rnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

trainer.fit(xs, ts, max_epoch, batch_size, time_size, max_grad, eval_interval=20)
trainer.plot(ylim=(0, 500))

model.reset_state()
ppl_test = eval_perplexity(model, corpus_test)
print('test perplexity: ', ppl_test)

model.save_params()
| epoch 1 |  iter 1 / 1327 | time 12[s] | perplexity 10000.52
| epoch 1 |  iter 21 / 1327 | time 18[s] | perplexity 3043.08
| epoch 1 |  iter 41 / 1327 | time 24[s] | perplexity 1223.89
| epoch 1 |  iter 61 / 1327 | time 30[s] | perplexity 988.30
| epoch 1 |  iter 81 / 1327 | time 36[s] | perplexity 812.58
| epoch 1 |  iter 101 / 1327 | time 42[s] | perplexity 685.93
| epoch 1 |  iter 121 / 1327 | time 48[s] | perplexity 628.84
| epoch 1 |  iter 141 / 1327 | time 54[s] | perplexity 607.25
| epoch 1 |  iter 161 / 1327 | time 60[s] | perplexity 593.01
| epoch 1 |  iter 181 / 1327 | time 66[s] | perplexity 572.24
| epoch 1 |  iter 201 / 1327 | time 72[s] | perplexity 516.94
| epoch 1 |  iter 221 / 1327 | time 78[s] | perplexity 500.29
| epoch 1 |  iter 241 / 1327 | time 84[s] | perplexity 441.31
| epoch 1 |  iter 261 / 1327 | time 90[s] | perplexity 466.61
| epoch 1 |  iter 281 / 1327 | time 95[s] | perplexity 457.95
| epoch 1 |  iter 301 / 1327 | time 101[s] | perplexity 400.09
| epoch 1 |  iter 321 / 1327 | time 107[s] | perplexity 354.01
| epoch 1 |  iter 341 / 1327 | time 113[s] | perplexity 414.76
| epoch 1 |  iter 361 / 1327 | time 119[s] | perplexity 413.01
| epoch 1 |  iter 381 / 1327 | time 125[s] | perplexity 338.24
| epoch 1 |  iter 401 / 1327 | time 131[s] | perplexity 358.44
| epoch 1 |  iter 421 / 1327 | time 137[s] | perplexity 351.17
| epoch 1 |  iter 441 / 1327 | time 143[s] | perplexity 332.46
| epoch 1 |  iter 461 / 1327 | time 149[s] | perplexity 332.15
| epoch 1 |  iter 481 / 1327 | time 155[s] | perplexity 308.20
| epoch 1 |  iter 501 / 1327 | time 161[s] | perplexity 315.23
| epoch 1 |  iter 521 / 1327 | time 167[s] | perplexity 307.01
| epoch 1 |  iter 541 / 1327 | time 173[s] | perplexity 321.42
| epoch 1 |  iter 561 / 1327 | time 179[s] | perplexity 292.32
| epoch 1 |  iter 581 / 1327 | time 184[s] | perplexity 264.95
| epoch 1 |  iter 601 / 1327 | time 190[s] | perplexity 342.06
| epoch 1 |  iter 621 / 1327 | time 196[s] | perplexity 317.83
| epoch 1 |  iter 641 / 1327 | time 202[s] | perplexity 288.22
| epoch 1 |  iter 661 / 1327 | time 208[s] | perplexity 273.88
| epoch 1 |  iter 681 / 1327 | time 214[s] | perplexity 230.15
| epoch 1 |  iter 701 / 1327 | time 220[s] | perplexity 253.12
| epoch 1 |  iter 721 / 1327 | time 226[s] | perplexity 265.56
| epoch 1 |  iter 741 / 1327 | time 232[s] | perplexity 224.49
| epoch 1 |  iter 761 / 1327 | time 238[s] | perplexity 236.62
| epoch 1 |  iter 781 / 1327 | time 244[s] | perplexity 224.24
| epoch 1 |  iter 801 / 1327 | time 250[s] | perplexity 243.74
| epoch 1 |  iter 821 / 1327 | time 256[s] | perplexity 227.72
| epoch 1 |  iter 841 / 1327 | time 262[s] | perplexity 231.35
| epoch 1 |  iter 861 / 1327 | time 267[s] | perplexity 223.62
| epoch 1 |  iter 881 / 1327 | time 273[s] | perplexity 206.62
| epoch 1 |  iter 901 / 1327 | time 279[s] | perplexity 256.72
| epoch 1 |  iter 921 / 1327 | time 285[s] | perplexity 227.62
| epoch 1 |  iter 941 / 1327 | time 291[s] | perplexity 232.06
| epoch 1 |  iter 961 / 1327 | time 297[s] | perplexity 244.22
| epoch 1 |  iter 981 / 1327 | time 303[s] | perplexity 230.51
| epoch 1 |  iter 1001 / 1327 | time 309[s] | perplexity 194.96
| epoch 1 |  iter 1021 / 1327 | time 315[s] | perplexity 226.91
| epoch 1 |  iter 1041 / 1327 | time 321[s] | perplexity 209.66
| epoch 1 |  iter 1061 / 1327 | time 327[s] | perplexity 199.37
| epoch 1 |  iter 1081 / 1327 | time 333[s] | perplexity 170.13
| epoch 1 |  iter 1101 / 1327 | time 339[s] | perplexity 191.02
| epoch 1 |  iter 1121 / 1327 | time 345[s] | perplexity 231.76
| epoch 1 |  iter 1141 / 1327 | time 351[s] | perplexity 209.41
| epoch 1 |  iter 1161 / 1327 | time 356[s] | perplexity 200.44
| epoch 1 |  iter 1181 / 1327 | time 362[s] | perplexity 192.77
| epoch 1 |  iter 1201 / 1327 | time 368[s] | perplexity 165.20
| epoch 1 |  iter 1221 / 1327 | time 374[s] | perplexity 160.68
| epoch 1 |  iter 1241 / 1327 | time 380[s] | perplexity 187.82
| epoch 1 |  iter 1261 / 1327 | time 386[s] | perplexity 173.45
| epoch 1 |  iter 1281 / 1327 | time 392[s] | perplexity 179.21
| epoch 1 |  iter 1301 / 1327 | time 398[s] | perplexity 224.94
| epoch 1 |  iter 1321 / 1327 | time 404[s] | perplexity 211.44
| epoch 2 |  iter 1 / 1327 | time 406[s] | perplexity 225.14
| epoch 2 |  iter 21 / 1327 | time 412[s] | perplexity 204.59
| epoch 2 |  iter 41 / 1327 | time 418[s] | perplexity 191.89
| epoch 2 |  iter 61 / 1327 | time 424[s] | perplexity 178.07
| epoch 2 |  iter 81 / 1327 | time 430[s] | perplexity 160.75
| epoch 2 |  iter 101 / 1327 | time 436[s] | perplexity 153.10
| epoch 2 |  iter 121 / 1327 | time 441[s] | perplexity 162.32
| epoch 2 |  iter 141 / 1327 | time 447[s] | perplexity 179.64
| epoch 2 |  iter 161 / 1327 | time 453[s] | perplexity 194.35
| epoch 2 |  iter 181 / 1327 | time 459[s] | perplexity 202.43
| epoch 2 |  iter 201 / 1327 | time 465[s] | perplexity 188.74
| epoch 2 |  iter 221 / 1327 | time 471[s] | perplexity 185.50
| epoch 2 |  iter 241 / 1327 | time 477[s] | perplexity 178.58
| epoch 2 |  iter 261 / 1327 | time 483[s] | perplexity 186.30
| epoch 2 |  iter 281 / 1327 | time 489[s] | perplexity 187.32
| epoch 2 |  iter 301 / 1327 | time 495[s] | perplexity 167.80
| epoch 2 |  iter 321 / 1327 | time 501[s] | perplexity 139.48
| epoch 2 |  iter 341 / 1327 | time 507[s] | perplexity 171.67
| epoch 2 |  iter 361 / 1327 | time 512[s] | perplexity 197.60
| epoch 2 |  iter 381 / 1327 | time 518[s] | perplexity 154.82
| epoch 2 |  iter 401 / 1327 | time 524[s] | perplexity 169.48
| epoch 2 |  iter 421 / 1327 | time 530[s] | perplexity 155.43
| epoch 2 |  iter 441 / 1327 | time 536[s] | perplexity 164.11
| epoch 2 |  iter 461 / 1327 | time 542[s] | perplexity 158.15
| epoch 2 |  iter 481 / 1327 | time 548[s] | perplexity 158.01
| epoch 2 |  iter 501 / 1327 | time 554[s] | perplexity 169.49
| epoch 2 |  iter 521 / 1327 | time 560[s] | perplexity 172.82
| epoch 2 |  iter 541 / 1327 | time 566[s] | perplexity 175.41
| epoch 2 |  iter 561 / 1327 | time 572[s] | perplexity 157.42
| epoch 2 |  iter 581 / 1327 | time 578[s] | perplexity 140.03
| epoch 2 |  iter 601 / 1327 | time 584[s] | perplexity 192.35
| epoch 2 |  iter 621 / 1327 | time 589[s] | perplexity 182.63
| epoch 2 |  iter 641 / 1327 | time 595[s] | perplexity 165.24
| epoch 2 |  iter 661 / 1327 | time 601[s] | perplexity 154.49
| epoch 2 |  iter 681 / 1327 | time 607[s] | perplexity 129.85
| epoch 2 |  iter 701 / 1327 | time 613[s] | perplexity 153.00
| epoch 2 |  iter 721 / 1327 | time 619[s] | perplexity 160.11
| epoch 2 |  iter 741 / 1327 | time 625[s] | perplexity 134.15
| epoch 2 |  iter 761 / 1327 | time 631[s] | perplexity 132.44
| epoch 2 |  iter 781 / 1327 | time 637[s] | perplexity 135.70
| epoch 2 |  iter 801 / 1327 | time 643[s] | perplexity 147.10
| epoch 2 |  iter 821 / 1327 | time 649[s] | perplexity 145.80
| epoch 2 |  iter 841 / 1327 | time 655[s] | perplexity 144.89
| epoch 2 |  iter 861 / 1327 | time 660[s] | perplexity 146.55
| epoch 2 |  iter 881 / 1327 | time 666[s] | perplexity 129.88
| epoch 2 |  iter 901 / 1327 | time 672[s] | perplexity 167.04
| epoch 2 |  iter 921 / 1327 | time 678[s] | perplexity 146.77
| epoch 2 |  iter 941 / 1327 | time 684[s] | perplexity 152.44
| epoch 2 |  iter 961 / 1327 | time 690[s] | perplexity 164.83
| epoch 2 |  iter 981 / 1327 | time 696[s] | perplexity 153.86
| epoch 2 |  iter 1001 / 1327 | time 702[s] | perplexity 132.51
| epoch 2 |  iter 1021 / 1327 | time 708[s] | perplexity 155.63
| epoch 2 |  iter 1041 / 1327 | time 714[s] | perplexity 143.63
| epoch 2 |  iter 1061 / 1327 | time 720[s] | perplexity 130.54
| epoch 2 |  iter 1081 / 1327 | time 726[s] | perplexity 111.60
| epoch 2 |  iter 1101 / 1327 | time 732[s] | perplexity 120.41
| epoch 2 |  iter 1121 / 1327 | time 738[s] | perplexity 153.38
| epoch 2 |  iter 1141 / 1327 | time 743[s] | perplexity 143.03
| epoch 2 |  iter 1161 / 1327 | time 749[s] | perplexity 132.85
| epoch 2 |  iter 1181 / 1327 | time 755[s] | perplexity 135.65
| epoch 2 |  iter 1201 / 1327 | time 761[s] | perplexity 113.97
| epoch 2 |  iter 1221 / 1327 | time 767[s] | perplexity 108.89
| epoch 2 |  iter 1241 / 1327 | time 773[s] | perplexity 130.19
| epoch 2 |  iter 1261 / 1327 | time 779[s] | perplexity 125.34
| epoch 2 |  iter 1281 / 1327 | time 785[s] | perplexity 122.60
| epoch 2 |  iter 1301 / 1327 | time 791[s] | perplexity 159.44
| epoch 2 |  iter 1321 / 1327 | time 797[s] | perplexity 152.86
| epoch 3 |  iter 1 / 1327 | time 799[s] | perplexity 161.38
| epoch 3 |  iter 21 / 1327 | time 805[s] | perplexity 145.38
| epoch 3 |  iter 41 / 1327 | time 811[s] | perplexity 137.33
| epoch 3 |  iter 61 / 1327 | time 817[s] | perplexity 128.55
| epoch 3 |  iter 81 / 1327 | time 822[s] | perplexity 117.85
| epoch 3 |  iter 101 / 1327 | time 828[s] | perplexity 105.79
| epoch 3 |  iter 121 / 1327 | time 834[s] | perplexity 115.92
| epoch 3 |  iter 141 / 1327 | time 840[s] | perplexity 127.47
| epoch 3 |  iter 161 / 1327 | time 846[s] | perplexity 143.02
| epoch 3 |  iter 181 / 1327 | time 852[s] | perplexity 152.12
| epoch 3 |  iter 201 / 1327 | time 858[s] | perplexity 141.67
| epoch 3 |  iter 221 / 1327 | time 864[s] | perplexity 141.38
| epoch 3 |  iter 241 / 1327 | time 870[s] | perplexity 135.04
| epoch 3 |  iter 261 / 1327 | time 876[s] | perplexity 139.20
| epoch 3 |  iter 281 / 1327 | time 881[s] | perplexity 142.71
| epoch 3 |  iter 301 / 1327 | time 887[s] | perplexity 125.64
| epoch 3 |  iter 321 / 1327 | time 893[s] | perplexity 102.47
| epoch 3 |  iter 341 / 1327 | time 899[s] | perplexity 125.40
| epoch 3 |  iter 361 / 1327 | time 905[s] | perplexity 154.14
| epoch 3 |  iter 381 / 1327 | time 911[s] | perplexity 115.68
| epoch 3 |  iter 401 / 1327 | time 917[s] | perplexity 130.53
| epoch 3 |  iter 421 / 1327 | time 923[s] | perplexity 113.71
| epoch 3 |  iter 441 / 1327 | time 929[s] | perplexity 123.27
| epoch 3 |  iter 461 / 1327 | time 935[s] | perplexity 117.73
| epoch 3 |  iter 481 / 1327 | time 941[s] | perplexity 119.94
| epoch 3 |  iter 501 / 1327 | time 946[s] | perplexity 128.34
| epoch 3 |  iter 521 / 1327 | time 952[s] | perplexity 137.33
| epoch 3 |  iter 541 / 1327 | time 958[s] | perplexity 136.05
| epoch 3 |  iter 561 / 1327 | time 964[s] | perplexity 119.85
| epoch 3 |  iter 581 / 1327 | time 970[s] | perplexity 106.57
| epoch 3 |  iter 601 / 1327 | time 976[s] | perplexity 152.18
| epoch 3 |  iter 621 / 1327 | time 982[s] | perplexity 143.81
| epoch 3 |  iter 641 / 1327 | time 988[s] | perplexity 128.85
| epoch 3 |  iter 661 / 1327 | time 994[s] | perplexity 121.50
| epoch 3 |  iter 681 / 1327 | time 1000[s] | perplexity 99.55
| epoch 3 |  iter 701 / 1327 | time 1006[s] | perplexity 120.48
| epoch 3 |  iter 721 / 1327 | time 1011[s] | perplexity 126.11
| epoch 3 |  iter 741 / 1327 | time 1017[s] | perplexity 107.88
| epoch 3 |  iter 761 / 1327 | time 1023[s] | perplexity 103.60
| epoch 3 |  iter 781 / 1327 | time 1029[s] | perplexity 103.77
| epoch 3 |  iter 801 / 1327 | time 1035[s] | perplexity 114.90
| epoch 3 |  iter 821 / 1327 | time 1041[s] | perplexity 116.49
| epoch 3 |  iter 841 / 1327 | time 1047[s] | perplexity 114.93
| epoch 3 |  iter 861 / 1327 | time 1053[s] | perplexity 119.74
| epoch 3 |  iter 881 / 1327 | time 1059[s] | perplexity 106.93
| epoch 3 |  iter 901 / 1327 | time 1065[s] | perplexity 133.25
| epoch 3 |  iter 921 / 1327 | time 1071[s] | perplexity 118.78
| epoch 3 |  iter 941 / 1327 | time 1076[s] | perplexity 126.31
| epoch 3 |  iter 961 / 1327 | time 1082[s] | perplexity 133.07
| epoch 3 |  iter 981 / 1327 | time 1088[s] | perplexity 123.01
| epoch 3 |  iter 1001 / 1327 | time 1094[s] | perplexity 110.23
| epoch 3 |  iter 1021 / 1327 | time 1100[s] | perplexity 129.15
| epoch 3 |  iter 1041 / 1327 | time 1106[s] | perplexity 120.18
| epoch 3 |  iter 1061 / 1327 | time 1112[s] | perplexity 104.30
| epoch 3 |  iter 1081 / 1327 | time 1118[s] | perplexity 88.97
| epoch 3 |  iter 1101 / 1327 | time 1124[s] | perplexity 95.64
| epoch 3 |  iter 1121 / 1327 | time 1130[s] | perplexity 120.60
| epoch 3 |  iter 1141 / 1327 | time 1136[s] | perplexity 115.53
| epoch 3 |  iter 1161 / 1327 | time 1142[s] | perplexity 106.47
| epoch 3 |  iter 1181 / 1327 | time 1148[s] | perplexity 112.18
| epoch 3 |  iter 1201 / 1327 | time 1154[s] | perplexity 94.87
| epoch 3 |  iter 1221 / 1327 | time 1159[s] | perplexity 87.82
| epoch 3 |  iter 1241 / 1327 | time 1165[s] | perplexity 105.57
| epoch 3 |  iter 1261 / 1327 | time 1171[s] | perplexity 105.54
| epoch 3 |  iter 1281 / 1327 | time 1177[s] | perplexity 100.72
| epoch 3 |  iter 1301 / 1327 | time 1183[s] | perplexity 131.29
| epoch 3 |  iter 1321 / 1327 | time 1189[s] | perplexity 126.84
| epoch 4 |  iter 1 / 1327 | time 1191[s] | perplexity 134.18
| epoch 4 |  iter 21 / 1327 | time 1197[s] | perplexity 122.94
| epoch 4 |  iter 41 / 1327 | time 1203[s] | perplexity 107.72
| epoch 4 |  iter 61 / 1327 | time 1209[s] | perplexity 109.87
| epoch 4 |  iter 81 / 1327 | time 1215[s] | perplexity 97.44
| epoch 4 |  iter 101 / 1327 | time 1221[s] | perplexity 86.77
| epoch 4 |  iter 121 / 1327 | time 1227[s] | perplexity 95.40
| epoch 4 |  iter 141 / 1327 | time 1233[s] | perplexity 104.29
| epoch 4 |  iter 161 / 1327 | time 1239[s] | perplexity 119.42
| epoch 4 |  iter 181 / 1327 | time 1244[s] | perplexity 129.87
| epoch 4 |  iter 201 / 1327 | time 1250[s] | perplexity 120.89
| epoch 4 |  iter 221 / 1327 | time 1256[s] | perplexity 121.37
| epoch 4 |  iter 241 / 1327 | time 1262[s] | perplexity 114.78
| epoch 4 |  iter 261 / 1327 | time 1268[s] | perplexity 114.97
| epoch 4 |  iter 281 / 1327 | time 1274[s] | perplexity 122.07
| epoch 4 |  iter 301 / 1327 | time 1280[s] | perplexity 105.88
| epoch 4 |  iter 321 / 1327 | time 1286[s] | perplexity 84.09
| epoch 4 |  iter 341 / 1327 | time 1292[s] | perplexity 101.45
| epoch 4 |  iter 361 / 1327 | time 1298[s] | perplexity 131.19
| epoch 4 |  iter 381 / 1327 | time 1304[s] | perplexity 98.35
| epoch 4 |  iter 401 / 1327 | time 1310[s] | perplexity 111.70
| epoch 4 |  iter 421 / 1327 | time 1316[s] | perplexity 94.27
| epoch 4 |  iter 441 / 1327 | time 1322[s] | perplexity 103.22
| epoch 4 |  iter 461 / 1327 | time 1328[s] | perplexity 100.04
| epoch 4 |  iter 481 / 1327 | time 1333[s] | perplexity 103.83
| epoch 4 |  iter 501 / 1327 | time 1339[s] | perplexity 107.87
| epoch 4 |  iter 521 / 1327 | time 1345[s] | perplexity 117.03
| epoch 4 |  iter 541 / 1327 | time 1351[s] | perplexity 114.25
| epoch 4 |  iter 561 / 1327 | time 1357[s] | perplexity 103.34
| epoch 4 |  iter 581 / 1327 | time 1363[s] | perplexity 89.92
| epoch 4 |  iter 601 / 1327 | time 1369[s] | perplexity 129.68
| epoch 4 |  iter 621 / 1327 | time 1375[s] | perplexity 123.08
| epoch 4 |  iter 641 / 1327 | time 1381[s] | perplexity 111.31
| epoch 4 |  iter 661 / 1327 | time 1387[s] | perplexity 104.01
| epoch 4 |  iter 681 / 1327 | time 1393[s] | perplexity 84.54
| epoch 4 |  iter 701 / 1327 | time 1399[s] | perplexity 103.82
| epoch 4 |  iter 721 / 1327 | time 1404[s] | perplexity 108.22
| epoch 4 |  iter 741 / 1327 | time 1410[s] | perplexity 96.01
| epoch 4 |  iter 761 / 1327 | time 1416[s] | perplexity 88.78
| epoch 4 |  iter 781 / 1327 | time 1422[s] | perplexity 88.47
| epoch 4 |  iter 801 / 1327 | time 1428[s] | perplexity 97.88
| epoch 4 |  iter 821 / 1327 | time 1434[s] | perplexity 103.32
| epoch 4 |  iter 841 / 1327 | time 1440[s] | perplexity 99.11
| epoch 4 |  iter 861 / 1327 | time 1446[s] | perplexity 104.41
| epoch 4 |  iter 881 / 1327 | time 1452[s] | perplexity 91.42
| epoch 4 |  iter 901 / 1327 | time 1458[s] | perplexity 115.78
| epoch 4 |  iter 921 / 1327 | time 1464[s] | perplexity 103.37
| epoch 4 |  iter 941 / 1327 | time 1470[s] | perplexity 111.49
| epoch 4 |  iter 961 / 1327 | time 1476[s] | perplexity 112.74
| epoch 4 |  iter 981 / 1327 | time 1481[s] | perplexity 106.57
| epoch 4 |  iter 1001 / 1327 | time 1487[s] | perplexity 98.34
| epoch 4 |  iter 1021 / 1327 | time 1493[s] | perplexity 114.28
| epoch 4 |  iter 1041 / 1327 | time 1499[s] | perplexity 104.62
| epoch 4 |  iter 1061 / 1327 | time 1505[s] | perplexity 90.57
| epoch 4 |  iter 1081 / 1327 | time 1511[s] | perplexity 78.93
| epoch 4 |  iter 1101 / 1327 | time 1517[s] | perplexity 79.35
| epoch 4 |  iter 1121 / 1327 | time 1523[s] | perplexity 103.15
| epoch 4 |  iter 1141 / 1327 | time 1529[s] | perplexity 100.01
| epoch 4 |  iter 1161 / 1327 | time 1535[s] | perplexity 92.35
| epoch 4 |  iter 1181 / 1327 | time 1541[s] | perplexity 96.91
| epoch 4 |  iter 1201 / 1327 | time 1547[s] | perplexity 84.74
| epoch 4 |  iter 1221 / 1327 | time 1553[s] | perplexity 75.15
| epoch 4 |  iter 1241 / 1327 | time 1558[s] | perplexity 91.74
| epoch 4 |  iter 1261 / 1327 | time 1564[s] | perplexity 94.50
| epoch 4 |  iter 1281 / 1327 | time 1570[s] | perplexity 89.46
| epoch 4 |  iter 1301 / 1327 | time 1576[s] | perplexity 111.59
| epoch 4 |  iter 1321 / 1327 | time 1582[s] | perplexity 110.69
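A side note on the numbers above (my own snippet, not from the book's code): perplexity is the exponential of the average cross-entropy loss, which is why an untrained model over the 10,000-word PTB vocabulary starts near 10,000 — it is as unsure as a uniform guess over the whole vocabulary:

```python
import numpy as np

vocab_size = 10000
# an untrained model predicts roughly uniformly: each word gets probability 1/V
loss = -np.log(1.0 / vocab_size)  # average cross-entropy per word
print(np.exp(loss))               # ~10000, matching the first logged value
```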

Trying a run with about 8 cores

%python3

import sys
sys.path.append('/tmp/deep-learning-from-scratch-2')
sys.path.append('/tmp/deep-learning-from-scratch-2/ch06')  # for the book's Rnnlm implementation
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from dataset import ptb
from rnnlm import Rnnlm

# hyperparameters
batch_size = 20
wordvec_size = 100
hidden_size = 100
time_size = 35
lr = 20.0
max_epoch = 4
max_grad = 0.25

# load dataset
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_test, _, _ = ptb.load_data('test')
vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]

# generate model
model = Rnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

trainer.fit(xs, ts, max_epoch, batch_size, time_size, max_grad, eval_interval=20)
trainer.plot(ylim=(0, 500))

model.reset_state()
ppl_test = eval_perplexity(model, corpus_test)
print('test perplexity: ', ppl_test)

model.save_params()
Downloading ptb.train.txt ... 
Done
Downloading ptb.test.txt ... 
Done
| epoch 1 |  iter 1 / 1327 | time 0[s] | perplexity 10000.10
| epoch 1 |  iter 21 / 1327 | time 7[s] | perplexity 2852.60
| epoch 1 |  iter 41 / 1327 | time 13[s] | perplexity 1197.79
| epoch 1 |  iter 61 / 1327 | time 20[s] | perplexity 972.65
| epoch 1 |  iter 81 / 1327 | time 26[s] | perplexity 801.79
| epoch 1 |  iter 101 / 1327 | time 33[s] | perplexity 654.38
| epoch 1 |  iter 121 / 1327 | time 40[s] | perplexity 650.02
| epoch 1 |  iter 141 / 1327 | time 46[s] | perplexity 586.48
| epoch 1 |  iter 161 / 1327 | time 53[s] | perplexity 575.72
| epoch 1 |  iter 181 / 1327 | time 60[s] | perplexity 575.51
| epoch 1 |  iter 201 / 1327 | time 66[s] | perplexity 500.62
| epoch 1 |  iter 221 / 1327 | time 73[s] | perplexity 489.72
| epoch 1 |  iter 241 / 1327 | time 79[s] | perplexity 440.72
| epoch 1 |  iter 261 / 1327 | time 86[s] | perplexity 464.75
| epoch 1 |  iter 281 / 1327 | time 93[s] | perplexity 458.81
| epoch 1 |  iter 301 / 1327 | time 99[s] | perplexity 391.53
| epoch 1 |  iter 321 / 1327 | time 106[s] | perplexity 352.07
| epoch 1 |  iter 341 / 1327 | time 112[s] | perplexity 402.92
| epoch 1 |  iter 361 / 1327 | time 119[s] | perplexity 408.88
| epoch 1 |  iter 381 / 1327 | time 125[s] | perplexity 334.41
| epoch 1 |  iter 401 / 1327 | time 132[s] | perplexity 356.17
| epoch 1 |  iter 421 / 1327 | time 139[s] | perplexity 346.59
| epoch 1 |  iter 441 / 1327 | time 145[s] | perplexity 325.49
| epoch 1 |  iter 461 / 1327 | time 152[s] | perplexity 329.53
| epoch 1 |  iter 481 / 1327 | time 159[s] | perplexity 312.91
| epoch 1 |  iter 501 / 1327 | time 166[s] | perplexity 318.56
| epoch 1 |  iter 521 / 1327 | time 172[s] | perplexity 304.97
| epoch 1 |  iter 541 / 1327 | time 179[s] | perplexity 315.58
| epoch 1 |  iter 561 / 1327 | time 186[s] | perplexity 285.87
| epoch 1 |  iter 581 / 1327 | time 192[s] | perplexity 262.11
| epoch 1 |  iter 601 / 1327 | time 199[s] | perplexity 336.64
| epoch 1 |  iter 621 / 1327 | time 206[s] | perplexity 312.01
| epoch 1 |  iter 641 / 1327 | time 212[s] | perplexity 286.62
| epoch 1 |  iter 661 / 1327 | time 219[s] | perplexity 273.06
| epoch 1 |  iter 681 / 1327 | time 226[s] | perplexity 227.11
| epoch 1 |  iter 701 / 1327 | time 232[s] | perplexity 250.40
| epoch 1 |  iter 721 / 1327 | time 239[s] | perplexity 262.76
| epoch 1 |  iter 741 / 1327 | time 246[s] | perplexity 223.73
| epoch 1 |  iter 761 / 1327 | time 252[s] | perplexity 232.59
| epoch 1 |  iter 781 / 1327 | time 259[s] | perplexity 220.59
| epoch 1 |  iter 801 / 1327 | time 266[s] | perplexity 244.55
| epoch 1 |  iter 821 / 1327 | time 272[s] | perplexity 226.35
| epoch 1 |  iter 841 / 1327 | time 279[s] | perplexity 229.13
| epoch 1 |  iter 861 / 1327 | time 286[s] | perplexity 221.18
| epoch 1 |  iter 881 / 1327 | time 293[s] | perplexity 208.49
| epoch 1 |  iter 901 / 1327 | time 299[s] | perplexity 253.19
| epoch 1 |  iter 921 / 1327 | time 306[s] | perplexity 230.20
| epoch 1 |  iter 941 / 1327 | time 313[s] | perplexity 231.52
| epoch 1 |  iter 961 / 1327 | time 319[s] | perplexity 244.30
| epoch 1 |  iter 981 / 1327 | time 326[s] | perplexity 230.86
| epoch 1 |  iter 1001 / 1327 | time 332[s] | perplexity 194.41
| epoch 1 |  iter 1021 / 1327 | time 339[s] | perplexity 226.75
| epoch 1 |  iter 1041 / 1327 | time 346[s] | perplexity 208.89
| epoch 1 |  iter 1061 / 1327 | time 352[s] | perplexity 197.86
| epoch 1 |  iter 1081 / 1327 | time 359[s] | perplexity 170.05
| epoch 1 |  iter 1101 / 1327 | time 366[s] | perplexity 194.50
| epoch 1 |  iter 1121 / 1327 | time 372[s] | perplexity 228.69
| epoch 1 |  iter 1141 / 1327 | time 379[s] | perplexity 209.67
| epoch 1 |  iter 1161 / 1327 | time 386[s] | perplexity 200.76
| epoch 1 |  iter 1181 / 1327 | time 392[s] | perplexity 192.34
| epoch 1 |  iter 1201 / 1327 | time 399[s] | perplexity 165.02
| epoch 1 |  iter 1221 / 1327 | time 406[s] | perplexity 159.78
| epoch 1 |  iter 1241 / 1327 | time 412[s] | perplexity 189.42
| epoch 1 |  iter 1261 / 1327 | time 419[s] | perplexity 172.00
| epoch 1 |  iter 1281 / 1327 | time 426[s] | perplexity 180.13
| epoch 1 |  iter 1301 / 1327 | time 433[s] | perplexity 224.04
| epoch 1 |  iter 1321 / 1327 | time 439[s] | perplexity 212.56
| epoch 2 |  iter 1 / 1327 | time 442[s] | perplexity 227.27
| epoch 2 |  iter 21 / 1327 | time 448[s] | perplexity 205.48
| epoch 2 |  iter 41 / 1327 | time 455[s] | perplexity 192.19
| epoch 2 |  iter 61 / 1327 | time 462[s] | perplexity 178.33
| epoch 2 |  iter 81 / 1327 | time 469[s] | perplexity 160.51
| epoch 2 |  iter 101 / 1327 | time 475[s] | perplexity 153.43
| epoch 2 |  iter 121 / 1327 | time 482[s] | perplexity 161.21
| epoch 2 |  iter 141 / 1327 | time 489[s] | perplexity 179.45
| epoch 2 |  iter 161 / 1327 | time 496[s] | perplexity 194.39
| epoch 2 |  iter 181 / 1327 | time 502[s] | perplexity 203.20
| epoch 2 |  iter 201 / 1327 | time 509[s] | perplexity 187.49
| epoch 2 |  iter 221 / 1327 | time 516[s] | perplexity 184.57
| epoch 2 |  iter 241 / 1327 | time 522[s] | perplexity 179.09
| epoch 2 |  iter 261 / 1327 | time 529[s] | perplexity 187.89
| epoch 2 |  iter 281 / 1327 | time 536[s] | perplexity 186.25
| epoch 2 |  iter 301 / 1327 | time 543[s] | perplexity 167.91
| epoch 2 |  iter 321 / 1327 | time 549[s] | perplexity 138.67
| epoch 2 |  iter 341 / 1327 | time 556[s] | perplexity 172.03
| epoch 2 |  iter 361 / 1327 | time 563[s] | perplexity 198.36
| epoch 2 |  iter 381 / 1327 | time 570[s] | perplexity 155.07
| epoch 2 |  iter 401 / 1327 | time 576[s] | perplexity 169.08
| epoch 2 |  iter 421 / 1327 | time 583[s] | perplexity 158.41
| epoch 2 |  iter 441 / 1327 | time 590[s] | perplexity 165.13
| epoch 2 |  iter 461 / 1327 | time 596[s] | perplexity 159.82
| epoch 2 |  iter 481 / 1327 | time 603[s] | perplexity 158.04
| epoch 2 |  iter 501 / 1327 | time 610[s] | perplexity 171.63
| epoch 2 |  iter 521 / 1327 | time 616[s] | perplexity 173.55
| epoch 2 |  iter 541 / 1327 | time 623[s] | perplexity 176.24
| epoch 2 |  iter 561 / 1327 | time 630[s] | perplexity 156.97
| epoch 2 |  iter 581 / 1327 | time 637[s] | perplexity 139.85
| epoch 2 |  iter 601 / 1327 | time 643[s] | perplexity 192.27
| epoch 2 |  iter 621 / 1327 | time 651[s] | perplexity 183.50
| epoch 2 |  iter 641 / 1327 | time 658[s] | perplexity 167.28
| epoch 2 |  iter 661 / 1327 | time 664[s] | perplexity 154.80
| epoch 2 |  iter 681 / 1327 | time 671[s] | perplexity 131.82
| epoch 2 |  iter 701 / 1327 | time 678[s] | perplexity 152.01
| epoch 2 |  iter 721 / 1327 | time 685[s] | perplexity 161.30
| epoch 2 |  iter 741 / 1327 | time 692[s] | perplexity 136.03
| epoch 2 |  iter 761 / 1327 | time 698[s] | perplexity 131.99
| epoch 2 |  iter 781 / 1327 | time 705[s] | perplexity 137.36
| epoch 2 |  iter 801 / 1327 | time 712[s] | perplexity 148.77
| epoch 2 |  iter 821 / 1327 | time 718[s] | perplexity 146.65
| epoch 2 |  iter 841 / 1327 | time 725[s] | perplexity 145.19
| epoch 2 |  iter 861 / 1327 | time 732[s] | perplexity 147.84
| epoch 2 |  iter 881 / 1327 | time 739[s] | perplexity 131.78
| epoch 2 |  iter 901 / 1327 | time 745[s] | perplexity 167.73
| epoch 2 |  iter 921 / 1327 | time 752[s] | perplexity 148.08
| epoch 2 |  iter 941 / 1327 | time 759[s] | perplexity 156.33
| epoch 2 |  iter 961 / 1327 | time 765[s] | perplexity 166.75
| epoch 2 |  iter 981 / 1327 | time 772[s] | perplexity 155.06
| epoch 2 |  iter 1001 / 1327 | time 779[s] | perplexity 132.72
| epoch 2 |  iter 1021 / 1327 | time 785[s] | perplexity 156.88
| epoch 2 |  iter 1041 / 1327 | time 792[s] | perplexity 143.86
| epoch 2 |  iter 1061 / 1327 | time 799[s] | perplexity 131.46
| epoch 2 |  iter 1081 / 1327 | time 805[s] | perplexity 112.57
| epoch 2 |  iter 1101 / 1327 | time 812[s] | perplexity 122.22
| epoch 2 |  iter 1121 / 1327 | time 819[s] | perplexity 156.04
| epoch 2 |  iter 1141 / 1327 | time 826[s] | perplexity 143.46
| epoch 2 |  iter 1161 / 1327 | time 833[s] | perplexity 135.22
| epoch 2 |  iter 1181 / 1327 | time 840[s] | perplexity 136.76
| epoch 2 |  iter 1201 / 1327 | time 846[s] | perplexity 114.39
| epoch 2 |  iter 1221 / 1327 | time 853[s] | perplexity 110.84
| epoch 2 |  iter 1241 / 1327 | time 860[s] | perplexity 131.14
| epoch 2 |  iter 1261 / 1327 | time 867[s] | perplexity 124.96
| epoch 2 |  iter 1281 / 1327 | time 874[s] | perplexity 124.53
| epoch 2 |  iter 1301 / 1327 | time 881[s] | perplexity 159.41
| epoch 2 |  iter 1321 / 1327 | time 887[s] | perplexity 155.05
| epoch 3 |  iter 1 / 1327 | time 890[s] | perplexity 162.31
| epoch 3 |  iter 21 / 1327 | time 897[s] | perplexity 145.83
| epoch 3 |  iter 41 / 1327 | time 904[s] | perplexity 137.93
| epoch 3 |  iter 61 / 1327 | time 910[s] | perplexity 129.50
| epoch 3 |  iter 81 / 1327 | time 917[s] | perplexity 118.11
| epoch 3 |  iter 101 / 1327 | time 924[s] | perplexity 106.62
| epoch 3 |  iter 121 / 1327 | time 931[s] | perplexity 116.58
| epoch 3 |  iter 141 / 1327 | time 937[s] | perplexity 128.34
| epoch 3 |  iter 161 / 1327 | time 944[s] | perplexity 145.42
| epoch 3 |  iter 181 / 1327 | time 951[s] | perplexity 152.55
| epoch 3 |  iter 201 / 1327 | time 958[s] | perplexity 143.41
| epoch 3 |  iter 221 / 1327 | time 964[s] | perplexity 142.29
| epoch 3 |  iter 241 / 1327 | time 971[s] | perplexity 136.47
| epoch 3 |  iter 261 / 1327 | time 978[s] | perplexity 141.66
| epoch 3 |  iter 281 / 1327 | time 984[s] | perplexity 142.91
| epoch 3 |  iter 301 / 1327 | time 991[s] | perplexity 125.87
| epoch 3 |  iter 321 / 1327 | time 998[s] | perplexity 104.19
| epoch 3 |  iter 341 / 1327 | time 1004[s] | perplexity 128.50
| epoch 3 |  iter 361 / 1327 | time 1011[s] | perplexity 154.88
| epoch 3 |  iter 381 / 1327 | time 1018[s] | perplexity 117.23
| epoch 3 |  iter 401 / 1327 | time 1025[s] | perplexity 130.37
| epoch 3 |  iter 421 / 1327 | time 1031[s] | perplexity 116.28
| epoch 3 |  iter 441 / 1327 | time 1038[s] | perplexity 126.29
| epoch 3 |  iter 461 / 1327 | time 1045[s] | perplexity 120.17
| epoch 3 |  iter 481 / 1327 | time 1051[s] | perplexity 121.45
| epoch 3 |  iter 501 / 1327 | time 1058[s] | perplexity 132.70
| epoch 3 |  iter 521 / 1327 | time 1065[s] | perplexity 138.81
| epoch 3 |  iter 541 / 1327 | time 1072[s] | perplexity 139.54
| epoch 3 |  iter 561 / 1327 | time 1078[s] | perplexity 120.78
| epoch 3 |  iter 581 / 1327 | time 1085[s] | perplexity 106.74
| epoch 3 |  iter 601 / 1327 | time 1092[s] | perplexity 151.46
| epoch 3 |  iter 621 / 1327 | time 1099[s] | perplexity 146.41
| epoch 3 |  iter 641 / 1327 | time 1105[s] | perplexity 132.30
| epoch 3 |  iter 661 / 1327 | time 1112[s] | perplexity 120.47
| epoch 3 |  iter 681 / 1327 | time 1119[s] | perplexity 101.10
| epoch 3 |  iter 701 / 1327 | time 1125[s] | perplexity 120.23
| epoch 3 |  iter 721 / 1327 | time 1132[s] | perplexity 127.58
| epoch 3 |  iter 741 / 1327 | time 1139[s] | perplexity 110.48
| epoch 3 |  iter 761 / 1327 | time 1146[s] | perplexity 104.48
| epoch 3 |  iter 781 / 1327 | time 1152[s] | perplexity 106.76
| epoch 3 |  iter 801 / 1327 | time 1159[s] | perplexity 116.66
| epoch 3 |  iter 821 / 1327 | time 1166[s] | perplexity 120.04
| epoch 3 |  iter 841 / 1327 | time 1173[s] | perplexity 115.68
| epoch 3 |  iter 861 / 1327 | time 1179[s] | perplexity 122.05
| epoch 3 |  iter 881 / 1327 | time 1186[s] | perplexity 108.33
| epoch 3 |  iter 901 / 1327 | time 1193[s] | perplexity 134.20
| epoch 3 |  iter 921 / 1327 | time 1200[s] | perplexity 120.01
| epoch 3 |  iter 941 / 1327 | time 1206[s] | perplexity 130.07
| epoch 3 |  iter 961 / 1327 | time 1213[s] | perplexity 133.97
| epoch 3 |  iter 981 / 1327 | time 1220[s] | perplexity 125.72
| epoch 3 |  iter 1001 / 1327 | time 1227[s] | perplexity 110.70
| epoch 3 |  iter 1021 / 1327 | time 1233[s] | perplexity 128.98
| epoch 3 |  iter 1041 / 1327 | time 1240[s] | perplexity 121.18
| epoch 3 |  iter 1061 / 1327 | time 1247[s] | perplexity 104.08
| epoch 3 |  iter 1081 / 1327 | time 1254[s] | perplexity 90.11
| epoch 3 |  iter 1101 / 1327 | time 1260[s] | perplexity 96.71
| epoch 3 |  iter 1121 / 1327 | time 1267[s] | perplexity 123.43
| epoch 3 |  iter 1141 / 1327 | time 1274[s] | perplexity 116.54
| epoch 3 |  iter 1161 / 1327 | time 1281[s] | perplexity 107.55
| epoch 3 |  iter 1181 / 1327 | time 1287[s] | perplexity 113.06
| epoch 3 |  iter 1201 / 1327 | time 1294[s] | perplexity 95.29
| epoch 3 |  iter 1221 / 1327 | time 1301[s] | perplexity 89.38
| epoch 3 |  iter 1241 / 1327 | time 1308[s] | perplexity 106.44
| epoch 3 |  iter 1261 / 1327 | time 1314[s] | perplexity 106.48
| epoch 3 |  iter 1281 / 1327 | time 1321[s] | perplexity 101.53
| epoch 3 |  iter 1301 / 1327 | time 1328[s] | perplexity 130.48
| epoch 3 |  iter 1321 / 1327 | time 1335[s] | perplexity 129.24
| epoch 4 |  iter 1 / 1327 | time 1337[s] | perplexity 137.08
| epoch 4 |  iter 21 / 1327 | time 1344[s] | perplexity 123.39
| epoch 4 |  iter 41 / 1327 | time 1350[s] | perplexity 109.46
| epoch 4 |  iter 61 / 1327 | time 1357[s] | perplexity 109.92
| epoch 4 |  iter 81 / 1327 | time 1364[s] | perplexity 98.42
| epoch 4 |  iter 101 / 1327 | time 1371[s] | perplexity 87.49
| epoch 4 |  iter 121 / 1327 | time 1377[s] | perplexity 96.47
| epoch 4 |  iter 141 / 1327 | time 1384[s] | perplexity 104.20
| epoch 4 |  iter 161 / 1327 | time 1391[s] | perplexity 120.35
| epoch 4 |  iter 181 / 1327 | time 1397[s] | perplexity 131.39
| epoch 4 |  iter 201 / 1327 | time 1404[s] | perplexity 122.30
| epoch 4 |  iter 221 / 1327 | time 1411[s] | perplexity 124.78
| epoch 4 |  iter 241 / 1327 | time 1418[s] | perplexity 117.55
| epoch 4 |  iter 261 / 1327 | time 1424[s] | perplexity 116.76
| epoch 4 |  iter 281 / 1327 | time 1431[s] | perplexity 122.54
| epoch 4 |  iter 301 / 1327 | time 1438[s] | perplexity 105.78
| epoch 4 |  iter 321 / 1327 | time 1444[s] | perplexity 85.61
| epoch 4 |  iter 341 / 1327 | time 1451[s] | perplexity 103.02
| epoch 4 |  iter 361 / 1327 | time 1458[s] | perplexity 128.83
| epoch 4 |  iter 381 / 1327 | time 1464[s] | perplexity 99.55
| epoch 4 |  iter 401 / 1327 | time 1471[s] | perplexity 111.82
| epoch 4 |  iter 421 / 1327 | time 1478[s] | perplexity 95.68
| epoch 4 |  iter 441 / 1327 | time 1485[s] | perplexity 104.78
| epoch 4 |  iter 461 / 1327 | time 1492[s] | perplexity 102.15
| epoch 4 |  iter 481 / 1327 | time 1499[s] | perplexity 104.22
| epoch 4 |  iter 501 / 1327 | time 1506[s] | perplexity 111.26
| epoch 4 |  iter 521 / 1327 | time 1512[s] | perplexity 119.01
| epoch 4 |  iter 541 / 1327 | time 1519[s] | perplexity 115.68
| epoch 4 |  iter 561 / 1327 | time 1526[s] | perplexity 105.81
| epoch 4 |  iter 581 / 1327 | time 1533[s] | perplexity 89.73
| epoch 4 |  iter 601 / 1327 | time 1540[s] | perplexity 127.60
| epoch 4 |  iter 621 / 1327 | time 1547[s] | perplexity 124.38
| epoch 4 |  iter 641 / 1327 | time 1554[s] | perplexity 112.94
| epoch 4 |  iter 661 / 1327 | time 1560[s] | perplexity 104.02
| epoch 4 |  iter 681 / 1327 | time 1567[s] | perplexity 85.23
| epoch 4 |  iter 701 / 1327 | time 1574[s] | perplexity 104.08
| epoch 4 |  iter 721 / 1327 | time 1581[s] | perplexity 108.60
| epoch 4 |  iter 741 / 1327 | time 1587[s] | perplexity 98.33
| epoch 4 |  iter 761 / 1327 | time 1594[s] | perplexity 89.41
| epoch 4 |  iter 781 / 1327 | time 1601[s] | perplexity 88.64
| epoch 4 |  iter 801 / 1327 | time 1608[s] | perplexity 99.77
| epoch 4 |  iter 821 / 1327 | time 1614[s] | perplexity 105.46
| epoch 4 |  iter 841 / 1327 | time 1621[s] | perplexity 100.29
| epoch 4 |  iter 861 / 1327 | time 1628[s] | perplexity 106.53
| epoch 4 |  iter 881 / 1327 | time 1634[s] | perplexity 92.96
| epoch 4 |  iter 901 / 1327 | time 1641[s] | perplexity 118.05
| epoch 4 |  iter 921 / 1327 | time 1648[s] | perplexity 103.90
| epoch 4 |  iter 941 / 1327 | time 1654[s] | perplexity 113.97
| epoch 4 |  iter 961 / 1327 | time 1661[s] | perplexity 114.04
| epoch 4 |  iter 981 / 1327 | time 1668[s] | perplexity 108.00
| epoch 4 |  iter 1001 / 1327 | time 1675[s] | perplexity 98.89
| epoch 4 |  iter 1021 / 1327 | time 1681[s] | perplexity 113.60
| epoch 4 |  iter 1041 / 1327 | time 1688[s] | perplexity 106.50
| epoch 4 |  iter 1061 / 1327 | time 1695[s] | perplexity 89.44
| epoch 4 |  iter 1081 / 1327 | time 1702[s] | perplexity 79.99
| epoch 4 |  iter 1101 / 1327 | time 1708[s] | perplexity 81.84
| epoch 4 |  iter 1121 / 1327 | time 1715[s] | perplexity 104.83
| epoch 4 |  iter 1141 / 1327 | time 1722[s] | perplexity 102.74
| epoch 4 |  iter 1161 / 1327 | time 1728[s] | perplexity 93.97
| epoch 4 |  iter 1181 / 1327 | time 1735[s] | perplexity 98.14
| epoch 4 |  iter 1201 / 1327 | time 1742[s] | perplexity 84.93
| epoch 4 |  iter 1221 / 1327 | time 1749[s] | perplexity 76.64
| epoch 4 |  iter 1241 / 1327 | time 1755[s] | perplexity 93.11
| epoch 4 |  iter 1261 / 1327 | time 1762[s] | perplexity 95.51
| epoch 4 |  iter 1281 / 1327 | time 1769[s] | perplexity 90.25
| epoch 4 |  iter 1301 / 1327 | time 1776[s] | perplexity 112.19
| epoch 4 |  iter 1321 / 1327 | time 1782[s] | perplexity 111.55
evaluating perplexity ...

0 / 235
1 / 235
...
234 / 235
test perplexity:  137.9052746711987
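Perplexity is just the exponential of the model's average cross-entropy loss per word, so a test perplexity of ~138 means the model is, on average, about as uncertain as a uniform choice among 138 candidate words. A minimal sketch of that relationship (the helper name below is illustrative; the repository's `eval_perplexity` computes the same quantity over the test corpus):

```python
import math

def perplexity_from_loss(avg_loss):
    # perplexity = exp(average cross-entropy loss per predicted word)
    return math.exp(avg_loss)

# a perfect model (loss 0) has perplexity 1
print(perplexity_from_loss(0.0))           # 1.0
# an average loss of about 4.927 nats/word corresponds to perplexity ~138,
# close to the test perplexity reported above
print(round(perplexity_from_loss(4.927)))  # 138
```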

%sh
top -b -n1
top - 11:21:03 up 14 min,  1 user,  load average: 7.33, 3.84, 1.59
Tasks: 147 total,   2 running,  92 sleeping,   0 stopped,   0 zombie
Cpu(s): 11.4%us, 11.9%sy,  0.0%ni, 76.1%id,  0.1%wa,  0.0%hi,  0.0%si,  0.5%st
Mem:  15394144k total,  3463012k used, 11931132k free,    34120k buffers
Swap:        0k total,        0k used,        0k free,  2005276k cached

  PID USER      PR  NI  VIRT  RES  SHR S %CPU %MEM    TIME+  COMMAND            
 3676 ec2-user  20   0 1058m 408m  18m R 655.4  2.7  25:58.79 python3           
(remaining process list omitted)

The python3 process is at roughly 655% CPU, which suggests that numpy's BLAS backend is parallelizing the matrix operations across most of the 8 cores.

6.5.4 Implementing a Better RNNLM

  • Stack multiple LSTM layers
  • Apply dropout
  • Weight tying (share the embedding weights with the output layer)
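The dropout used below (the TimeDropout layer in the reference implementation) is the inverted variant: surviving activations are scaled up at training time so that inference needs no rescaling. A minimal sketch of the idea — the function name and signature here are illustrative, not the repo's API:

```python
import numpy as np

def time_dropout(xs, ratio=0.5, train_flg=True, rng=None):
    """Inverted dropout over a time-series batch (sketch)."""
    if not train_flg:
        return xs  # inference: pass through unchanged
    rng = rng or np.random.default_rng(0)
    # Keep each unit with probability (1 - ratio), scale survivors by 1/(1 - ratio)
    mask = (rng.random(xs.shape) > ratio) / (1.0 - ratio)
    return xs * mask
```

Because of the scaling, the expected value of each activation is the same at train and inference time, which is why `predict` below only has to flip a `train_flg`.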

%python3
import sys
sys.path.append('/tmp/deep-learning-from-scratch-2')
from common.time_layers import TimeEmbedding, TimeLSTM, TimeDropout, TimeAffine, TimeSoftmaxWithLoss
from common.np import *
from common.base_model import BaseModel


class BetterRnnlm(BaseModel):
    def __init__(self, vocab_size=10000, wordvec_size=650, hidden_size=650, dropout_ratio=0.5):
        V, D, H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn
        
        embed_W = (rn(V, D) / 100).astype('f')
        lstm_Wx1 = (rn(D, 4*H) / np.sqrt(D)).astype('f')
        lstm_Wh1 = (rn(H, 4*H) / np.sqrt(H)).astype('f')
        lstm_b1 = np.zeros(4*H).astype('f')
        lstm_Wx2 = (rn(D, 4*H) / np.sqrt(D)).astype('f')
        lstm_Wh2 = (rn(H, 4*H) / np.sqrt(H)).astype('f')
        lstm_b2 = np.zeros(4*H).astype('f')
        affine_b = np.zeros(V).astype('f')
        
        self.layers = [
            TimeEmbedding(embed_W),
            TimeDropout(dropout_ratio),
            TimeLSTM(lstm_Wx1, lstm_Wh1, lstm_b1, stateful=True),
            TimeDropout(dropout_ratio),
            TimeLSTM(lstm_Wx2, lstm_Wh2, lstm_b2, stateful=True),
            TimeDropout(dropout_ratio),
            TimeAffine(embed_W.T, affine_b)
        ]
        self.loss_layer = TimeSoftmaxWithLoss()
        self.lstm_layers = [self.layers[2], self.layers[4]]
        self.drop_layers = [self.layers[1], self.layers[3], self.layers[5]]
        
        self.params, self.grads = [], []
        for layer in self.layers:
            self.params += layer.params
            self.grads += layer.grads
    
    def predict(self, xs, train_flg=False):
        for layer in self.drop_layers:
            layer.train_flg = train_flg
        for layer in self.layers:
            xs = layer.forward(xs)
        return xs
    
    def forward(self, xs, ts, train_flg=True):
        score = self.predict(xs, train_flg)
        loss = self.loss_layer.forward(score, ts)
        return loss
    
    def backward(self, dout=1):
        dout = self.loss_layer.backward(dout)
        for layer in reversed(self.layers):
            dout = layer.backward(dout)
        return dout
    
    def reset_state(self):
        for layer in self.lstm_layers:
            layer.reset_state()

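Weight tying appears above as `TimeAffine(embed_W.T, affine_b)`: the (V, D) embedding matrix is reused, transposed, as the (D, V) output projection. Besides acting as a regularizer, this removes an entire V×D matrix from the model. A quick sketch of the saving, using the default sizes above:

```python
V, D = 10000, 650  # vocab_size and wordvec_size defaults above

# Untied: a separate (V, D) embedding and (D, V) output projection
untied = V * D + D * V
# Tied: the output layer reuses embed_W.T, so the matrix is stored once
tied = V * D

saved = untied - tied  # parameters eliminated by tying
```

With these sizes, tying saves 6.5 million parameters, roughly a third of the weight-matrix budget outside the LSTM layers.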
%python3
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from dataset import ptb


batch_size = 20
wordvec_size = 650
hidden_size = 650
time_size = 35
lr = 20.0
max_epoch = 40
max_grad = 0.25
dropout = 0.5

corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_val, _, _ = ptb.load_data('val')
corpus_test, _, _ = ptb.load_data('test')

vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]

model = BetterRnnlm(vocab_size, wordvec_size, hidden_size, dropout)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

best_ppl = float('inf')
for epoch in range(max_epoch):
    trainer.fit(xs, ts, max_epoch=1, batch_size=batch_size, time_size=time_size, max_grad=max_grad)
    
    model.reset_state()
    ppl = eval_perplexity(model, corpus_val)
    print('valid perplexity: ', ppl)
    
    if best_ppl > ppl:
        best_ppl = ppl
        model.save_params()
    else:
        lr /= 4.0
        optimizer.lr = lr
    
    model.reset_state()
    print('-' * 50)
| epoch 1 |  iter 1 / 1327 | time 1[s] | perplexity 10000.27
| epoch 1 |  iter 21 / 1327 | time 30[s] | perplexity 3663.54
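The perplexity in the log is exp of the average cross-entropy loss, so an untrained model over the 10,000-word PTB vocabulary starts near 10,000 — which matches the first log line above. A minimal sketch of the relation:

```python
import numpy as np

def perplexity(mean_loss):
    # perplexity = exp(average cross-entropy per predicted word)
    return float(np.exp(mean_loss))

# A uniform guess over V words has loss log(V), i.e. perplexity V
uniform_ppl = perplexity(np.log(10000.0))
```

The training loop above exploits this interpretation: when validation perplexity stops improving, the learning rate is cut to a quarter.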

With more layers, would a GPU be the better choice?

CUDA: checking the version

%sh
nvidia-smi
nvcc --version
Sun Nov 18 12:35:56 2018       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:00:1B.0 Off |                    0 |
| N/A   34C    P0    25W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:00:1C.0 Off |                    0 |
| N/A   33C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000000:00:1D.0 Off |                    0 |
| N/A   32C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
| N/A   33C    P0    27W / 300W |      0MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2017 NVIDIA Corporation
Built on Fri_Sep__1_21:08:03_CDT_2017
Cuda compilation tools, release 9.0, V9.0.176

Installing CuPy

%sh
sudo pip-3.6 install cupy-cuda90
Collecting cupy-cuda90
  Downloading https://files.pythonhosted.org/packages/f7/46/0910fb6901fec52d4a77ff36378c82103dabc676b5f50d334e3784fd321f/cupy_cuda90-5.0.0-cp36-cp36m-manylinux1_x86_64.whl (262.7MB)
Collecting fastrlock>=0.3 (from cupy-cuda90)
  Downloading https://files.pythonhosted.org/packages/b5/93/a7efbd39eac46c137500b37570c31dedc2d31a8ff4949fcb90bda5bc5f16/fastrlock-0.4-cp36-cp36m-manylinux1_x86_64.whl
Requirement already satisfied: numpy>=1.9.0 in /usr/lib64/python3.6/dist-packages (from cupy-cuda90)
Requirement already satisfied: six>=1.9.0 in /usr/lib/python3.6/dist-packages (from cupy-cuda90)
Installing collected packages: fastrlock, cupy-cuda90
Successfully installed cupy-cuda90-5.0.0 fastrlock-0.4
You are using pip version 9.0.3, however version 18.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.
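The reference implementation swaps NumPy for CuPy behind a single flag (`common/np.py` reads `common.config.GPU`). The same pattern in a self-contained sketch, with the flag left off so it runs without a CUDA device:

```python
GPU = False  # flip to True once cupy-cuda90 is installed and a GPU is visible

if GPU:
    import cupy as np  # drop-in replacement for the NumPy API used here
else:
    import numpy as np

# All downstream code is written against the `np` name, so nothing else changes
x = np.arange(6, dtype='f').reshape(2, 3)
```

The one caveat of this pattern is import order: the flag must be set before any module that does `from common.np import *` is imported.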

%python3
from common import config
config.GPU = True  # repo convention: common/np.py switches to cupy when this is set (must run before importing common.np)
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from dataset import ptb


batch_size = 20
wordvec_size = 650
hidden_size = 650
time_size = 35
lr = 20.0
max_epoch = 40
max_grad = 0.25
dropout = 0.5

corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_val, _, _ = ptb.load_data('val')
corpus_test, _, _ = ptb.load_data('test')

vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]

model = BetterRnnlm(vocab_size, wordvec_size, hidden_size, dropout)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

best_ppl = float('inf')
for epoch in range(max_epoch):
    trainer.fit(xs, ts, max_epoch=1, batch_size=batch_size, time_size=time_size, max_grad=max_grad)
    
    model.reset_state()
    ppl = eval_perplexity(model, corpus_val)
    print('valid perplexity: ', ppl)
    
    if best_ppl > ppl:
        best_ppl = ppl
        model.save_params()
    else:
        lr /= 4.0
        optimizer.lr = lr
    
    model.reset_state()
    print('-' * 50)

%sh
nvidia-smi
Sun Nov 18 12:48:46 2018       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:00:1B.0 Off |                    0 |
| N/A   37C    P0    58W / 300W |   2821MiB / 16160MiB |     10%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:00:1C.0 Off |                    0 |
| N/A   33C    P0    36W / 300W |     11MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000000:00:1D.0 Off |                    0 |
| N/A   31C    P0    38W / 300W |     11MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
| N/A   32C    P0    42W / 300W |     11MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      5761      C   python3                                     2810MiB |
+-----------------------------------------------------------------------------+

  • Got outbid and the run was killed partway through 🙃
  • Multiple GPUs don't seem necessary for this workload
  • There is also plenty of GPU memory to spare

%md