LSTM / ゼロから作る Deep Learning 2

    Posted on 2018/11/18

    Chapter 6: Gated RNNs

    These are my reading notes for ゼロから作るDeep Learning (2). In chapter 5 we implemented a language model with a simple RNN. In chapter 6 we add a mechanism called gates to the RNN, giving us the LSTM, so that the network can learn long-term dependencies.

    Reference implementation

    %sh
    rm -rf /tmp/deep-learning-from-scratch-2
    git clone https://github.com/oreilly-japan/deep-learning-from-scratch-2 /tmp/deep-learning-from-scratch-2
    Cloning into '/tmp/deep-learning-from-scratch-2'...
    

    Installing numpy and matplotlib

    %sh
    pip3 install numpy matplotlib
    Collecting numpy
    Collecting matplotlib
    Successfully installed cycler-0.10.0 kiwisolver-1.0.1 matplotlib-3.0.2 numpy-1.15.4 pyparsing-2.3.0 python-dateutil-2.7.5 six-1.11.0
    

    %sh
    pip3 install seaborn
    Collecting seaborn
    Collecting scipy>=0.14.0 (from seaborn)
    Collecting pandas>=0.15.2 (from seaborn)
    Successfully installed pandas-0.23.4 pytz-2018.7 scipy-1.1.0 seaborn-0.9.0
    

    6.1.3: Causes of vanishing and exploding gradients

    %python3
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    def plot_tanh(x):
        y = np.tanh(x)
        plt.plot(x, y, label='tanh')
    
    def plot_dydx(x):
        y = 1 / (np.cosh(x)**2)
        plt.plot(x, y, '--', label='dy/dx')
    
    sns.set()
    x = np.arange(-5, 5, 0.2)
    plt.xlabel('x')
    plt.ylabel('y')
    plot_tanh(x)
    plot_dydx(x)
    
    plt.legend(loc='lower right', fontsize=18, borderaxespad=1)
    <matplotlib.legend.Legend object at 0x7f8c13198ba8>
    

    • The derivative of \(y=\tanh(x)\) shrinks as \(x\) moves away from 0, so each time the backward pass goes through a tanh node the propagated gradient gets smaller
    • Using ReLU instead is said to mitigate vanishing gradients
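A quick numeric check of the first bullet (my own illustration, not from the book): the tanh derivative at a fixed input, raised to the number of backward steps, collapses fast.

```python
import numpy as np

# Gradient scale after backpropagating through T tanh nodes, under the
# simplifying (hypothetical) assumption that every step sees the same
# pre-activation x.
def tanh_grad_product(x, T):
    d = 1.0 / np.cosh(x) ** 2  # d/dx tanh(x)
    return d ** T

print(tanh_grad_product(1.0, 1))   # the derivative itself, about 0.42
print(tanh_grad_product(1.0, 20))  # after 20 steps: vanishingly small
```

At \(x=0\) the derivative is exactly 1, so the gradient survives; anywhere else it decays geometrically with depth.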

    %python3
    N = 2
    H = 3
    T = 20
    
    dh = np.ones((N, H))
    np.random.seed(3)
    Wh = np.random.randn(H, H)
    
    norm_list = []
    for t in range(T):
        dh = np.dot(dh, Wh.T)
        norm = np.sqrt(np.sum(dh**2)) / N
        norm_list.append(norm)
    
    sns.set()
    plt.xlabel('time step')
    plt.ylabel('norm')
    plt.plot(norm_list)
    [<matplotlib.lines.Line2D object at 0x7f8bfca6c5f8>]
    

    %python3
    N = 2
    H = 3
    T = 20
    
    dh = np.ones((N, H))
    np.random.seed(3)
    Wh = np.random.randn(H, H) * 0.5
    
    norm_list = []
    for t in range(T):
        dh = np.dot(dh, Wh.T)
        norm = np.sqrt(np.sum(dh**2)) / N
        norm_list.append(norm)
    
    sns.set()
    plt.xlabel('time step')
    plt.ylabel('norm')
    plt.plot(norm_list)
    [<matplotlib.lines.Line2D object at 0x7f8c11092e80>]
    

    • Repeating the MatMul computation likewise leads to exploding or vanishing gradients
    • Countermeasure against exploding gradients
      • Gradient clipping
      • Rescale the gradients when their L2 norm exceeds a threshold
    • Countermeasure against vanishing gradients
      • Introduce gates
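Gradient clipping by global L2 norm can be sketched as follows. This is my own minimal version; the reference repository ships a similar helper, and the small epsilon guarding against division by zero is my addition.

```python
import numpy as np

def clip_grads(grads, max_norm):
    """Rescale all gradients in place when their global L2 norm exceeds max_norm."""
    total_norm = np.sqrt(sum((g ** 2).sum() for g in grads))
    rate = max_norm / (total_norm + 1e-6)  # epsilon avoids division by zero
    if rate < 1:  # only shrink; never amplify small gradients
        for g in grads:
            g *= rate
    return grads
```

For example, a single 2x2 gradient of all tens has norm 20; clipping at 5 scales every element down by a factor of 4.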

    Gates

    • A gate controls how much of each element should be carried into the next state
    • The output gate:

    $$
    o = \sigma(x_t W_{x}^{(o)} + h_{t-1} W_{h}^{(o)} + b^{(o)})
    $$

    • The output is the element-wise (Hadamard) product of the output gate and the tanh node
    • Each gate has its own weights
    • The computations of all the gates can be combined into a single matrix operation
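The last bullet can be sketched as follows (my own example, with made-up sizes): the pre-activations of all four gates come out of one affine transform with 4H-wide weights, which is then sliced into H-wide chunks.

```python
import numpy as np

# One (D, 4H) weight matrix replaces four separate (D, H) matrices:
# a single matmul computes the pre-activations of all four gates.
N, D, H = 2, 3, 4  # batch, input dim, hidden dim (hypothetical sizes)
rng = np.random.default_rng(0)
x = rng.standard_normal((N, D))
h_prev = rng.standard_normal((N, H))
Wx = rng.standard_normal((D, 4 * H))
Wh = rng.standard_normal((H, 4 * H))
b = np.zeros(4 * H)

A = x @ Wx + h_prev @ Wh + b  # shape (N, 4H): one matmul for all gates
f, g, i, o = (A[:, k * H:(k + 1) * H] for k in range(4))  # H-wide slices
```

Stacking the four slices back together recovers A exactly, which is why the LSTM class below can slice A rather than run four matmuls.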

    LSTM layer: implementation

    %python3
    import sys
    sys.path.append('/tmp/deep-learning-from-scratch-2')
    
    from common import config
    config.GPU = True
    from common.np import *
    from common.functions import sigmoid
    
    
    class LSTM:
        def __init__(self, Wx, Wh, b):
            self.params = [Wx, Wh, b]
            self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
            self.cache = None
    
        def forward(self, x, h_prev, c_prev):
            Wx, Wh, b = self.params
            N, H = h_prev.shape
    
            # one affine transform yields the pre-activations of all four gates
            A = np.dot(x, Wx) + np.dot(h_prev, Wh) + b

            # slice A into forget (f), new memory (g), input (i), output (o)
            f = A[:, :H]
            g = A[:, H:2*H]
            i = A[:, 2*H:3*H]
            o = A[:, 3*H:]

            f = sigmoid(f)
            g = np.tanh(g)
            i = sigmoid(i)
            o = sigmoid(o)

            c_next = f * c_prev + g * i   # forget part of the old memory, add the new
            h_next = o * np.tanh(c_next)  # the output gate filters what the state exposes
    
            self.cache = (x, h_prev, c_prev, i, f, g, o, c_next)
            return h_next, c_next
    
        def backward(self, dh_next, dc_next):
            Wx, Wh, b = self.params
            x, h_prev, c_prev, i, f, g, o, c_next = self.cache
    
            tanh_c_next = np.tanh(c_next)
    
            ds = dc_next + (dh_next * o) * (1 - tanh_c_next ** 2)
    
            dc_prev = ds * f
    
            di = ds * g
            df = ds * c_prev
            do = dh_next * tanh_c_next
            dg = ds * i
    
            di *= i * (1 - i)
            df *= f * (1 - f)
            do *= o * (1 - o)
            dg *= (1 - g ** 2)
    
            dA = np.hstack((df, dg, di, do))
    
            dWh = np.dot(h_prev.T, dA)
            dWx = np.dot(x.T, dA)
            db = dA.sum(axis=0)
    
            self.grads[0][...] = dWx
            self.grads[1][...] = dWh
            self.grads[2][...] = db
    
            dx = np.dot(dA, Wx.T)
            dh_prev = np.dot(dA, Wh.T)
    
            return dx, dh_prev, dc_prev
    ------------------------------------------------------------
                           GPU Mode (cupy)
    ------------------------------------------------------------
    
    

    TimeLSTM layer: implementation

    %python3
    class TimeLSTM:
        def __init__(self, Wx, Wh, b, stateful=False):
            self.params = [Wx, Wh, b]
            self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
            self.layers = None
    
            self.h, self.c = None, None
            self.dh = None
            self.stateful = stateful
    
        def forward(self, xs):
            Wx, Wh, b = self.params
            N, T, D = xs.shape
            H = Wh.shape[0]
    
            self.layers = []
            hs = np.empty((N, T, H), dtype='f')
    
            if not self.stateful or self.h is None:
                self.h = np.zeros((N, H), dtype='f')
            if not self.stateful or self.c is None:
                self.c = np.zeros((N, H), dtype='f')
    
            for t in range(T):
                layer = LSTM(*self.params)
                self.h, self.c = layer.forward(xs[:, t, :], self.h, self.c)
                hs[:, t, :] = self.h
    
                self.layers.append(layer)
    
            return hs
    
        def backward(self, dhs):
            Wx, Wh, b = self.params
            N, T, H = dhs.shape
            D = Wx.shape[0]
    
            dxs = np.empty((N, T, D), dtype='f')
            dh, dc = 0, 0
    
            grads = [0, 0, 0]
            for t in reversed(range(T)):
                layer = self.layers[t]
                dx, dh, dc = layer.backward(dhs[:, t, :] + dh, dc)
                dxs[:, t, :] = dx
                for i, grad in enumerate(layer.grads):
                    grads[i] += grad
    
            for i, grad in enumerate(grads):
                self.grads[i][...] = grad
            self.dh = dh
            return dxs
    
        def set_state(self, h, c=None):
            self.h, self.c = h, c
    
        def reset_state(self):
            self.h, self.c = None, None
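The effect of the `stateful` flag can be shown with a toy recurrence (`ToyStateful` is my own stand-in, not part of the book's code): when stateful, the last hidden state of one forward call seeds the next call, which is what lets truncated BPTT carry context across batch boundaries.

```python
class ToyStateful:
    """Minimal scalar 'recurrent layer' mimicking TimeLSTM's state handling."""
    def __init__(self, stateful=False):
        self.h = None
        self.stateful = stateful

    def forward(self, xs):
        # Reset the state unless we are stateful and already have one,
        # mirroring the check at the top of TimeLSTM.forward.
        if not self.stateful or self.h is None:
            self.h = 0.0
        for x in xs:
            self.h = 0.5 * self.h + x  # stand-in for one LSTM step
        return self.h
```

With `stateful=True`, feeding `[1.0]` then `[0.0]` returns 0.5 on the second call (the state survived); with `stateful=False` the second call returns 0.0.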

    RNNLM: implementation

    %python3
    import pickle
    from common.time_layers import TimeSoftmaxWithLoss, TimeEmbedding, TimeAffine
    
    class Rnnlm:
        def __init__(self, vocab_size=10000, wordvec_size=100, hidden_size=100):
            V, D, H = vocab_size, wordvec_size, hidden_size
            rn = np.random.randn
            
            embed_W = (rn(V, D) / 100).astype('f')
            lstm_Wx = (rn(D, 4 * H) / np.sqrt(D)).astype('f')
            lstm_Wh = (rn(H, 4 * H) / np.sqrt(H)).astype('f')
            lstm_b = np.zeros(4 * H).astype('f')
            affine_W = (rn(H, V) / np.sqrt(H)).astype('f')
            affine_b = np.zeros(V).astype('f')
            
            self.layers = [
                TimeEmbedding(embed_W),
                TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful=True),
                TimeAffine(affine_W, affine_b)
            ]
            self.loss_layer = TimeSoftmaxWithLoss()
            self.lstm_layer = self.layers[1]
            
            self.params, self.grads = [], []
            for layer in self.layers:
                self.params += layer.params
                self.grads += layer.grads
    
        def predict(self, xs):
            for layer in self.layers:
                xs = layer.forward(xs)
            return xs
        
        def forward(self, xs, ts):
            score = self.predict(xs)
            loss = self.loss_layer.forward(score, ts)
            return loss
        
        def backward(self, dout=1):
            dout = self.loss_layer.backward(dout)
            for layer in reversed(self.layers):
                dout = layer.backward(dout)
            return dout
        
        def reset_state(self):
            self.lstm_layer.reset_state()
        
        def save_params(self, file_name='Rnnlm.pkl'):
            with open(file_name, 'wb') as f:
                pickle.dump(self.params, f)
        
        def load_params(self, file_name='Rnnlm.pkl'):
            with open(file_name, 'rb') as f:
                self.params = pickle.load(f)
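The `1 / np.sqrt(D)` and `1 / np.sqrt(H)` factors in `__init__` are a Xavier-style 1/sqrt(fan-in) initialization. A quick empirical check (my own, with arbitrary sizes) that this keeps the pre-activation scale near 1 regardless of fan-in:

```python
import numpy as np

# With W ~ N(0, 1/D), the pre-activation x @ W has roughly unit
# variance whatever the fan-in D, matching the scaling used for
# lstm_Wx, lstm_Wh, and affine_W above.
rng = np.random.default_rng(0)
D, H, n = 100, 100, 10000  # hypothetical sizes
x = rng.standard_normal((n, D))
W = rng.standard_normal((D, H)) / np.sqrt(D)
a = x @ W
```

Without the scaling, the pre-activation standard deviation would grow like sqrt(D) and push the sigmoid/tanh gates into their flat saturated regions from the start.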

    Using CUDA

    %sh
    nvidia-smi
    Thu Nov 15 13:21:21 2018       
    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |===============================+======================+======================|
    |   0  Tesla K80           On   | 00000000:00:1E.0 Off |                    0 |
    | N/A   56C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
    +-------------------------------+----------------------+----------------------+
                                                                                   
    +-----------------------------------------------------------------------------+
    | Processes:                                                       GPU Memory |
    |  GPU       PID   Type   Process name                             Usage      |
    |=============================================================================|
    |  No running processes found                                                 |
    +-----------------------------------------------------------------------------+
    

    %sh
    nvcc --version
    nvcc: NVIDIA (R) Cuda compiler driver
    Copyright (c) 2005-2017 NVIDIA Corporation
    Built on Fri_Sep__1_21:08:03_CDT_2017
    Cuda compilation tools, release 9.0, V9.0.176
    

    %sh
    sudo pip-3.6 install cupy-cuda90
    Collecting cupy-cuda90
    Successfully installed cupy-cuda90-5.0.0
    

    Running the training

    %python3
    
    from common.optimizer import SGD
    from common.trainer import RnnlmTrainer
    from common.util import eval_perplexity
    from dataset import ptb
    
    # hyper parameters
    batch_size = 20
    wordvec_size = 100
    hidden_size = 100
    time_size = 35
    lr = 20.0
    max_epoch = 4
    max_grad = 0.25
    
    # load dataset
    corpus, word_to_id, id_to_word = ptb.load_data('train')
    corpus_test, _, _ = ptb.load_data('test')
    vocab_size = len(word_to_id)
    xs = corpus[:-1]
    ts = corpus[1:]
    
    # generate model
    model = Rnnlm(vocab_size, wordvec_size, hidden_size)
    optimizer = SGD(lr)
    trainer = RnnlmTrainer(model, optimizer)
    
    trainer.fit(xs, ts, max_epoch, batch_size, time_size, max_grad, eval_interval=20)
    trainer.plot(ylim=(0, 500))
    
    model.reset_state()
    ppl_test = eval_perplexity(model, corpus_test)
    print('test perplexity: ', ppl_test)
    
    model.save_params()
    | epoch 1 |  iter 1 / 1327 | time 12[s] | perplexity 10000.52
    | epoch 1 |  iter 21 / 1327 | time 18[s] | perplexity 3043.08
    ...
    | epoch 1 |  iter 1321 / 1327 | time 404[s] | perplexity 211.44
    | epoch 2 |  iter 1 / 1327 | time 406[s] | perplexity 225.14
    ...
    | epoch 2 |  iter 1321 / 1327 | time 797[s] | perplexity 152.86
    | epoch 3 |  iter 1 / 1327 | time 799[s] | perplexity 161.38
    ...
    | epoch 3 |  iter 1321 / 1327 | time 1189[s] | perplexity 126.84
    | epoch 4 |  iter 1 / 1327 | time 1191[s] | perplexity 134.18
    ...
    | epoch 4 |  iter 1161 / 1327 | time 1535[s] | perplexity 92.35
    | epoch 4 |  iter 1181 / 1327 | time 1541[s] | perplexity 96.91
    | epoch 4 |  iter 1201 / 1327 | time 1547[s] | perplexity 84.74
    | epoch 4 |  iter 1221 / 1327 | time 1553[s] | perplexity 75.15
    | epoch 4 |  iter 1241 / 1327 | time 1558[s] | perplexity 91.74
    | epoch 4 |  iter 1261 / 1327 | time 1564[s] | perplexity 94.50
    | epoch 4 |  iter 1281 / 1327 | time 1570[s] | perplexity 89.46
    | epoch 4 |  iter 1301 / 1327 | time 1576[s] | perplexity 111.59
    | epoch 4 |  iter 1321 / 1327 | time 1582[s] | perplexity 110.69
    
    evaluating perplexity ... (per-batch progress, 235 batches, omitted)
    test perplexity: 136.581

    Now try running it on a machine with about 8 cores.

    %python3
    
    from common.optimizer import SGD
    from common.trainer import RnnlmTrainer
    from common.util import eval_perplexity
    from dataset import ptb
    # the Rnnlm class is assumed to be defined in an earlier cell of this notebook
    
    # hyperparameters
    batch_size = 20
    wordvec_size = 100
    hidden_size = 100
    time_size = 35
    lr = 20.0
    max_epoch = 4
    max_grad = 0.25
    
    # load dataset
    corpus, word_to_id, id_to_word = ptb.load_data('train')
    corpus_test, _, _ = ptb.load_data('test')
    vocab_size = len(word_to_id)
    xs = corpus[:-1]
    ts = corpus[1:]
    
    # generate model
    model = Rnnlm(vocab_size, wordvec_size, hidden_size)
    optimizer = SGD(lr)
    trainer = RnnlmTrainer(model, optimizer)
    
    trainer.fit(xs, ts, max_epoch, batch_size, time_size, max_grad, eval_interval=20)
    trainer.plot(ylim=(0, 500))
    
    model.reset_state()
    ppl_test = eval_perplexity(model, corpus_test)
    print('test perplexity: ', ppl_test)
    
    model.save_params()
    Downloading ptb.train.txt ... 
    Done
    Downloading ptb.test.txt ... 
    Done
    | epoch 1 |  iter 1 / 1327 | time 0[s] | perplexity 10000.10
    | epoch 1 |  iter 21 / 1327 | time 7[s] | perplexity 2852.60
    | epoch 1 |  iter 41 / 1327 | time 13[s] | perplexity 1197.79
    | epoch 1 |  iter 61 / 1327 | time 20[s] | perplexity 972.65
    | epoch 1 |  iter 81 / 1327 | time 26[s] | perplexity 801.79
    | epoch 1 |  iter 101 / 1327 | time 33[s] | perplexity 654.38
    | epoch 1 |  iter 121 / 1327 | time 40[s] | perplexity 650.02
    | epoch 1 |  iter 141 / 1327 | time 46[s] | perplexity 586.48
    | epoch 1 |  iter 161 / 1327 | time 53[s] | perplexity 575.72
    | epoch 1 |  iter 181 / 1327 | time 60[s] | perplexity 575.51
    | epoch 1 |  iter 201 / 1327 | time 66[s] | perplexity 500.62
    | epoch 1 |  iter 221 / 1327 | time 73[s] | perplexity 489.72
    | epoch 1 |  iter 241 / 1327 | time 79[s] | perplexity 440.72
    | epoch 1 |  iter 261 / 1327 | time 86[s] | perplexity 464.75
    | epoch 1 |  iter 281 / 1327 | time 93[s] | perplexity 458.81
    | epoch 1 |  iter 301 / 1327 | time 99[s] | perplexity 391.53
    | epoch 1 |  iter 321 / 1327 | time 106[s] | perplexity 352.07
    | epoch 1 |  iter 341 / 1327 | time 112[s] | perplexity 402.92
    | epoch 1 |  iter 361 / 1327 | time 119[s] | perplexity 408.88
    | epoch 1 |  iter 381 / 1327 | time 125[s] | perplexity 334.41
    | epoch 1 |  iter 401 / 1327 | time 132[s] | perplexity 356.17
    | epoch 1 |  iter 421 / 1327 | time 139[s] | perplexity 346.59
    | epoch 1 |  iter 441 / 1327 | time 145[s] | perplexity 325.49
    | epoch 1 |  iter 461 / 1327 | time 152[s] | perplexity 329.53
    | epoch 1 |  iter 481 / 1327 | time 159[s] | perplexity 312.91
    | epoch 1 |  iter 501 / 1327 | time 166[s] | perplexity 318.56
    | epoch 1 |  iter 521 / 1327 | time 172[s] | perplexity 304.97
    | epoch 1 |  iter 541 / 1327 | time 179[s] | perplexity 315.58
    | epoch 1 |  iter 561 / 1327 | time 186[s] | perplexity 285.87
    | epoch 1 |  iter 581 / 1327 | time 192[s] | perplexity 262.11
    | epoch 1 |  iter 601 / 1327 | time 199[s] | perplexity 336.64
    | epoch 1 |  iter 621 / 1327 | time 206[s] | perplexity 312.01
    | epoch 1 |  iter 641 / 1327 | time 212[s] | perplexity 286.62
    | epoch 1 |  iter 661 / 1327 | time 219[s] | perplexity 273.06
    | epoch 1 |  iter 681 / 1327 | time 226[s] | perplexity 227.11
    | epoch 1 |  iter 701 / 1327 | time 232[s] | perplexity 250.40
    | epoch 1 |  iter 721 / 1327 | time 239[s] | perplexity 262.76
    | epoch 1 |  iter 741 / 1327 | time 246[s] | perplexity 223.73
    | epoch 1 |  iter 761 / 1327 | time 252[s] | perplexity 232.59
    | epoch 1 |  iter 781 / 1327 | time 259[s] | perplexity 220.59
    | epoch 1 |  iter 801 / 1327 | time 266[s] | perplexity 244.55
    | epoch 1 |  iter 821 / 1327 | time 272[s] | perplexity 226.35
    | epoch 1 |  iter 841 / 1327 | time 279[s] | perplexity 229.13
    | epoch 1 |  iter 861 / 1327 | time 286[s] | perplexity 221.18
    | epoch 1 |  iter 881 / 1327 | time 293[s] | perplexity 208.49
    | epoch 1 |  iter 901 / 1327 | time 299[s] | perplexity 253.19
    | epoch 1 |  iter 921 / 1327 | time 306[s] | perplexity 230.20
    | epoch 1 |  iter 941 / 1327 | time 313[s] | perplexity 231.52
    | epoch 1 |  iter 961 / 1327 | time 319[s] | perplexity 244.30
    | epoch 1 |  iter 981 / 1327 | time 326[s] | perplexity 230.86
    | epoch 1 |  iter 1001 / 1327 | time 332[s] | perplexity 194.41
    | epoch 1 |  iter 1021 / 1327 | time 339[s] | perplexity 226.75
    | epoch 1 |  iter 1041 / 1327 | time 346[s] | perplexity 208.89
    | epoch 1 |  iter 1061 / 1327 | time 352[s] | perplexity 197.86
    | epoch 1 |  iter 1081 / 1327 | time 359[s] | perplexity 170.05
    | epoch 1 |  iter 1101 / 1327 | time 366[s] | perplexity 194.50
    | epoch 1 |  iter 1121 / 1327 | time 372[s] | perplexity 228.69
    | epoch 1 |  iter 1141 / 1327 | time 379[s] | perplexity 209.67
    | epoch 1 |  iter 1161 / 1327 | time 386[s] | perplexity 200.76
    | epoch 1 |  iter 1181 / 1327 | time 392[s] | perplexity 192.34
    | epoch 1 |  iter 1201 / 1327 | time 399[s] | perplexity 165.02
    | epoch 1 |  iter 1221 / 1327 | time 406[s] | perplexity 159.78
    | epoch 1 |  iter 1241 / 1327 | time 412[s] | perplexity 189.42
    | epoch 1 |  iter 1261 / 1327 | time 419[s] | perplexity 172.00
    | epoch 1 |  iter 1281 / 1327 | time 426[s] | perplexity 180.13
    | epoch 1 |  iter 1301 / 1327 | time 433[s] | perplexity 224.04
    | epoch 1 |  iter 1321 / 1327 | time 439[s] | perplexity 212.56
    | epoch 2 |  iter 1 / 1327 | time 442[s] | perplexity 227.27
    | epoch 2 |  iter 21 / 1327 | time 448[s] | perplexity 205.48
    | epoch 2 |  iter 41 / 1327 | time 455[s] | perplexity 192.19
    | epoch 2 |  iter 61 / 1327 | time 462[s] | perplexity 178.33
    | epoch 2 |  iter 81 / 1327 | time 469[s] | perplexity 160.51
    | epoch 2 |  iter 101 / 1327 | time 475[s] | perplexity 153.43
    | epoch 2 |  iter 121 / 1327 | time 482[s] | perplexity 161.21
    | epoch 2 |  iter 141 / 1327 | time 489[s] | perplexity 179.45
    | epoch 2 |  iter 161 / 1327 | time 496[s] | perplexity 194.39
    | epoch 2 |  iter 181 / 1327 | time 502[s] | perplexity 203.20
    | epoch 2 |  iter 201 / 1327 | time 509[s] | perplexity 187.49
    | epoch 2 |  iter 221 / 1327 | time 516[s] | perplexity 184.57
    | epoch 2 |  iter 241 / 1327 | time 522[s] | perplexity 179.09
    | epoch 2 |  iter 261 / 1327 | time 529[s] | perplexity 187.89
    | epoch 2 |  iter 281 / 1327 | time 536[s] | perplexity 186.25
    | epoch 2 |  iter 301 / 1327 | time 543[s] | perplexity 167.91
    | epoch 2 |  iter 321 / 1327 | time 549[s] | perplexity 138.67
    | epoch 2 |  iter 341 / 1327 | time 556[s] | perplexity 172.03
    | epoch 2 |  iter 361 / 1327 | time 563[s] | perplexity 198.36
    | epoch 2 |  iter 381 / 1327 | time 570[s] | perplexity 155.07
    | epoch 2 |  iter 401 / 1327 | time 576[s] | perplexity 169.08
    | epoch 2 |  iter 421 / 1327 | time 583[s] | perplexity 158.41
    | epoch 2 |  iter 441 / 1327 | time 590[s] | perplexity 165.13
    | epoch 2 |  iter 461 / 1327 | time 596[s] | perplexity 159.82
    | epoch 2 |  iter 481 / 1327 | time 603[s] | perplexity 158.04
    | epoch 2 |  iter 501 / 1327 | time 610[s] | perplexity 171.63
    | epoch 2 |  iter 521 / 1327 | time 616[s] | perplexity 173.55
    | epoch 2 |  iter 541 / 1327 | time 623[s] | perplexity 176.24
    | epoch 2 |  iter 561 / 1327 | time 630[s] | perplexity 156.97
    | epoch 2 |  iter 581 / 1327 | time 637[s] | perplexity 139.85
    | epoch 2 |  iter 601 / 1327 | time 643[s] | perplexity 192.27
    | epoch 2 |  iter 621 / 1327 | time 651[s] | perplexity 183.50
    | epoch 2 |  iter 641 / 1327 | time 658[s] | perplexity 167.28
    | epoch 2 |  iter 661 / 1327 | time 664[s] | perplexity 154.80
    | epoch 2 |  iter 681 / 1327 | time 671[s] | perplexity 131.82
    | epoch 2 |  iter 701 / 1327 | time 678[s] | perplexity 152.01
    | epoch 2 |  iter 721 / 1327 | time 685[s] | perplexity 161.30
    | epoch 2 |  iter 741 / 1327 | time 692[s] | perplexity 136.03
    | epoch 2 |  iter 761 / 1327 | time 698[s] | perplexity 131.99
    | epoch 2 |  iter 781 / 1327 | time 705[s] | perplexity 137.36
    | epoch 2 |  iter 801 / 1327 | time 712[s] | perplexity 148.77
    | epoch 2 |  iter 821 / 1327 | time 718[s] | perplexity 146.65
    | epoch 2 |  iter 841 / 1327 | time 725[s] | perplexity 145.19
    | epoch 2 |  iter 861 / 1327 | time 732[s] | perplexity 147.84
    | epoch 2 |  iter 881 / 1327 | time 739[s] | perplexity 131.78
    | epoch 2 |  iter 901 / 1327 | time 745[s] | perplexity 167.73
    | epoch 2 |  iter 921 / 1327 | time 752[s] | perplexity 148.08
    | epoch 2 |  iter 941 / 1327 | time 759[s] | perplexity 156.33
    | epoch 2 |  iter 961 / 1327 | time 765[s] | perplexity 166.75
    | epoch 2 |  iter 981 / 1327 | time 772[s] | perplexity 155.06
    | epoch 2 |  iter 1001 / 1327 | time 779[s] | perplexity 132.72
    | epoch 2 |  iter 1021 / 1327 | time 785[s] | perplexity 156.88
    | epoch 2 |  iter 1041 / 1327 | time 792[s] | perplexity 143.86
    | epoch 2 |  iter 1061 / 1327 | time 799[s] | perplexity 131.46
    | epoch 2 |  iter 1081 / 1327 | time 805[s] | perplexity 112.57
    | epoch 2 |  iter 1101 / 1327 | time 812[s] | perplexity 122.22
    | epoch 2 |  iter 1121 / 1327 | time 819[s] | perplexity 156.04
    | epoch 2 |  iter 1141 / 1327 | time 826[s] | perplexity 143.46
    | epoch 2 |  iter 1161 / 1327 | time 833[s] | perplexity 135.22
    | epoch 2 |  iter 1181 / 1327 | time 840[s] | perplexity 136.76
    | epoch 2 |  iter 1201 / 1327 | time 846[s] | perplexity 114.39
    | epoch 2 |  iter 1221 / 1327 | time 853[s] | perplexity 110.84
    | epoch 2 |  iter 1241 / 1327 | time 860[s] | perplexity 131.14
    | epoch 2 |  iter 1261 / 1327 | time 867[s] | perplexity 124.96
    | epoch 2 |  iter 1281 / 1327 | time 874[s] | perplexity 124.53
    | epoch 2 |  iter 1301 / 1327 | time 881[s] | perplexity 159.41
    | epoch 2 |  iter 1321 / 1327 | time 887[s] | perplexity 155.05
    | epoch 3 |  iter 1 / 1327 | time 890[s] | perplexity 162.31
    | epoch 3 |  iter 21 / 1327 | time 897[s] | perplexity 145.83
    | epoch 3 |  iter 41 / 1327 | time 904[s] | perplexity 137.93
    | epoch 3 |  iter 61 / 1327 | time 910[s] | perplexity 129.50
    | epoch 3 |  iter 81 / 1327 | time 917[s] | perplexity 118.11
    | epoch 3 |  iter 101 / 1327 | time 924[s] | perplexity 106.62
    | epoch 3 |  iter 121 / 1327 | time 931[s] | perplexity 116.58
    | epoch 3 |  iter 141 / 1327 | time 937[s] | perplexity 128.34
    | epoch 3 |  iter 161 / 1327 | time 944[s] | perplexity 145.42
    | epoch 3 |  iter 181 / 1327 | time 951[s] | perplexity 152.55
    | epoch 3 |  iter 201 / 1327 | time 958[s] | perplexity 143.41
    | epoch 3 |  iter 221 / 1327 | time 964[s] | perplexity 142.29
    | epoch 3 |  iter 241 / 1327 | time 971[s] | perplexity 136.47
    | epoch 3 |  iter 261 / 1327 | time 978[s] | perplexity 141.66
    | epoch 3 |  iter 281 / 1327 | time 984[s] | perplexity 142.91
    | epoch 3 |  iter 301 / 1327 | time 991[s] | perplexity 125.87
    | epoch 3 |  iter 321 / 1327 | time 998[s] | perplexity 104.19
    | epoch 3 |  iter 341 / 1327 | time 1004[s] | perplexity 128.50
    | epoch 3 |  iter 361 / 1327 | time 1011[s] | perplexity 154.88
    | epoch 3 |  iter 381 / 1327 | time 1018[s] | perplexity 117.23
    | epoch 3 |  iter 401 / 1327 | time 1025[s] | perplexity 130.37
    | epoch 3 |  iter 421 / 1327 | time 1031[s] | perplexity 116.28
    | epoch 3 |  iter 441 / 1327 | time 1038[s] | perplexity 126.29
    | epoch 3 |  iter 461 / 1327 | time 1045[s] | perplexity 120.17
    | epoch 3 |  iter 481 / 1327 | time 1051[s] | perplexity 121.45
    | epoch 3 |  iter 501 / 1327 | time 1058[s] | perplexity 132.70
    | epoch 3 |  iter 521 / 1327 | time 1065[s] | perplexity 138.81
    | epoch 3 |  iter 541 / 1327 | time 1072[s] | perplexity 139.54
    | epoch 3 |  iter 561 / 1327 | time 1078[s] | perplexity 120.78
    | epoch 3 |  iter 581 / 1327 | time 1085[s] | perplexity 106.74
    | epoch 3 |  iter 601 / 1327 | time 1092[s] | perplexity 151.46
    | epoch 3 |  iter 621 / 1327 | time 1099[s] | perplexity 146.41
    | epoch 3 |  iter 641 / 1327 | time 1105[s] | perplexity 132.30
    | epoch 3 |  iter 661 / 1327 | time 1112[s] | perplexity 120.47
    | epoch 3 |  iter 681 / 1327 | time 1119[s] | perplexity 101.10
    | epoch 3 |  iter 701 / 1327 | time 1125[s] | perplexity 120.23
    | epoch 3 |  iter 721 / 1327 | time 1132[s] | perplexity 127.58
    | epoch 3 |  iter 741 / 1327 | time 1139[s] | perplexity 110.48
    | epoch 3 |  iter 761 / 1327 | time 1146[s] | perplexity 104.48
    | epoch 3 |  iter 781 / 1327 | time 1152[s] | perplexity 106.76
    | epoch 3 |  iter 801 / 1327 | time 1159[s] | perplexity 116.66
    | epoch 3 |  iter 821 / 1327 | time 1166[s] | perplexity 120.04
    | epoch 3 |  iter 841 / 1327 | time 1173[s] | perplexity 115.68
    | epoch 3 |  iter 861 / 1327 | time 1179[s] | perplexity 122.05
    | epoch 3 |  iter 881 / 1327 | time 1186[s] | perplexity 108.33
    | epoch 3 |  iter 901 / 1327 | time 1193[s] | perplexity 134.20
    | epoch 3 |  iter 921 / 1327 | time 1200[s] | perplexity 120.01
    | epoch 3 |  iter 941 / 1327 | time 1206[s] | perplexity 130.07
    | epoch 3 |  iter 961 / 1327 | time 1213[s] | perplexity 133.97
    | epoch 3 |  iter 981 / 1327 | time 1220[s] | perplexity 125.72
    | epoch 3 |  iter 1001 / 1327 | time 1227[s] | perplexity 110.70
    | epoch 3 |  iter 1021 / 1327 | time 1233[s] | perplexity 128.98
    | epoch 3 |  iter 1041 / 1327 | time 1240[s] | perplexity 121.18
    | epoch 3 |  iter 1061 / 1327 | time 1247[s] | perplexity 104.08
    | epoch 3 |  iter 1081 / 1327 | time 1254[s] | perplexity 90.11
    | epoch 3 |  iter 1101 / 1327 | time 1260[s] | perplexity 96.71
    | epoch 3 |  iter 1121 / 1327 | time 1267[s] | perplexity 123.43
    | epoch 3 |  iter 1141 / 1327 | time 1274[s] | perplexity 116.54
    | epoch 3 |  iter 1161 / 1327 | time 1281[s] | perplexity 107.55
    | epoch 3 |  iter 1181 / 1327 | time 1287[s] | perplexity 113.06
    | epoch 3 |  iter 1201 / 1327 | time 1294[s] | perplexity 95.29
    | epoch 3 |  iter 1221 / 1327 | time 1301[s] | perplexity 89.38
    | epoch 3 |  iter 1241 / 1327 | time 1308[s] | perplexity 106.44
    | epoch 3 |  iter 1261 / 1327 | time 1314[s] | perplexity 106.48
    | epoch 3 |  iter 1281 / 1327 | time 1321[s] | perplexity 101.53
    | epoch 3 |  iter 1301 / 1327 | time 1328[s] | perplexity 130.48
    | epoch 3 |  iter 1321 / 1327 | time 1335[s] | perplexity 129.24
    | epoch 4 |  iter 1 / 1327 | time 1337[s] | perplexity 137.08
    | epoch 4 |  iter 21 / 1327 | time 1344[s] | perplexity 123.39
    | epoch 4 |  iter 41 / 1327 | time 1350[s] | perplexity 109.46
    | epoch 4 |  iter 61 / 1327 | time 1357[s] | perplexity 109.92
    | epoch 4 |  iter 81 / 1327 | time 1364[s] | perplexity 98.42
    | epoch 4 |  iter 101 / 1327 | time 1371[s] | perplexity 87.49
    | epoch 4 |  iter 121 / 1327 | time 1377[s] | perplexity 96.47
    | epoch 4 |  iter 141 / 1327 | time 1384[s] | perplexity 104.20
    | epoch 4 |  iter 161 / 1327 | time 1391[s] | perplexity 120.35
    | epoch 4 |  iter 181 / 1327 | time 1397[s] | perplexity 131.39
    | epoch 4 |  iter 201 / 1327 | time 1404[s] | perplexity 122.30
    | epoch 4 |  iter 221 / 1327 | time 1411[s] | perplexity 124.78
    | epoch 4 |  iter 241 / 1327 | time 1418[s] | perplexity 117.55
    | epoch 4 |  iter 261 / 1327 | time 1424[s] | perplexity 116.76
    | epoch 4 |  iter 281 / 1327 | time 1431[s] | perplexity 122.54
    | epoch 4 |  iter 301 / 1327 | time 1438[s] | perplexity 105.78
    | epoch 4 |  iter 321 / 1327 | time 1444[s] | perplexity 85.61
    | epoch 4 |  iter 341 / 1327 | time 1451[s] | perplexity 103.02
    | epoch 4 |  iter 361 / 1327 | time 1458[s] | perplexity 128.83
    | epoch 4 |  iter 381 / 1327 | time 1464[s] | perplexity 99.55
    | epoch 4 |  iter 401 / 1327 | time 1471[s] | perplexity 111.82
    | epoch 4 |  iter 421 / 1327 | time 1478[s] | perplexity 95.68
    | epoch 4 |  iter 441 / 1327 | time 1485[s] | perplexity 104.78
    | epoch 4 |  iter 461 / 1327 | time 1492[s] | perplexity 102.15
    | epoch 4 |  iter 481 / 1327 | time 1499[s] | perplexity 104.22
    | epoch 4 |  iter 501 / 1327 | time 1506[s] | perplexity 111.26
    | epoch 4 |  iter 521 / 1327 | time 1512[s] | perplexity 119.01
    | epoch 4 |  iter 541 / 1327 | time 1519[s] | perplexity 115.68
    | epoch 4 |  iter 561 / 1327 | time 1526[s] | perplexity 105.81
    | epoch 4 |  iter 581 / 1327 | time 1533[s] | perplexity 89.73
    | epoch 4 |  iter 601 / 1327 | time 1540[s] | perplexity 127.60
    | epoch 4 |  iter 621 / 1327 | time 1547[s] | perplexity 124.38
    | epoch 4 |  iter 641 / 1327 | time 1554[s] | perplexity 112.94
    | epoch 4 |  iter 661 / 1327 | time 1560[s] | perplexity 104.02
    | epoch 4 |  iter 681 / 1327 | time 1567[s] | perplexity 85.23
    | epoch 4 |  iter 701 / 1327 | time 1574[s] | perplexity 104.08
    | epoch 4 |  iter 721 / 1327 | time 1581[s] | perplexity 108.60
    | epoch 4 |  iter 741 / 1327 | time 1587[s] | perplexity 98.33
    | epoch 4 |  iter 761 / 1327 | time 1594[s] | perplexity 89.41
    | epoch 4 |  iter 781 / 1327 | time 1601[s] | perplexity 88.64
    | epoch 4 |  iter 801 / 1327 | time 1608[s] | perplexity 99.77
    | epoch 4 |  iter 821 / 1327 | time 1614[s] | perplexity 105.46
    | epoch 4 |  iter 841 / 1327 | time 1621[s] | perplexity 100.29
    | epoch 4 |  iter 861 / 1327 | time 1628[s] | perplexity 106.53
    | epoch 4 |  iter 881 / 1327 | time 1634[s] | perplexity 92.96
    | epoch 4 |  iter 901 / 1327 | time 1641[s] | perplexity 118.05
    | epoch 4 |  iter 921 / 1327 | time 1648[s] | perplexity 103.90
    | epoch 4 |  iter 941 / 1327 | time 1654[s] | perplexity 113.97
    | epoch 4 |  iter 961 / 1327 | time 1661[s] | perplexity 114.04
    | epoch 4 |  iter 981 / 1327 | time 1668[s] | perplexity 108.00
    | epoch 4 |  iter 1001 / 1327 | time 1675[s] | perplexity 98.89
    | epoch 4 |  iter 1021 / 1327 | time 1681[s] | perplexity 113.60
    | epoch 4 |  iter 1041 / 1327 | time 1688[s] | perplexity 106.50
    | epoch 4 |  iter 1061 / 1327 | time 1695[s] | perplexity 89.44
    | epoch 4 |  iter 1081 / 1327 | time 1702[s] | perplexity 79.99
    | epoch 4 |  iter 1101 / 1327 | time 1708[s] | perplexity 81.84
    | epoch 4 |  iter 1121 / 1327 | time 1715[s] | perplexity 104.83
    | epoch 4 |  iter 1141 / 1327 | time 1722[s] | perplexity 102.74
    | epoch 4 |  iter 1161 / 1327 | time 1728[s] | perplexity 93.97
    | epoch 4 |  iter 1181 / 1327 | time 1735[s] | perplexity 98.14
    | epoch 4 |  iter 1201 / 1327 | time 1742[s] | perplexity 84.93
    | epoch 4 |  iter 1221 / 1327 | time 1749[s] | perplexity 76.64
    | epoch 4 |  iter 1241 / 1327 | time 1755[s] | perplexity 93.11
    | epoch 4 |  iter 1261 / 1327 | time 1762[s] | perplexity 95.51
    | epoch 4 |  iter 1281 / 1327 | time 1769[s] | perplexity 90.25
    | epoch 4 |  iter 1301 / 1327 | time 1776[s] | perplexity 112.19
    | epoch 4 |  iter 1321 / 1327 | time 1782[s] | perplexity 111.55
    evaluating perplexity ... (per-batch progress, 235 batches, omitted)
    test perplexity:  137.9052746711987
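The perplexity values in these logs are simply the exponential of the average cross-entropy loss over each reporting interval, which is also why the very first logged value (≈10000) matches the PTB vocabulary size of 10,000: an untrained model is close to a uniform distribution over the vocabulary. A minimal sketch of the relationship (the function name is mine, not from the book's code):

```python
import numpy as np

def perplexity(avg_cross_entropy_loss):
    # Perplexity is the exponential of the average cross-entropy loss.
    return float(np.exp(avg_cross_entropy_loss))

# A uniform model over a vocabulary of size V has loss ln(V),
# so its perplexity equals V -- matching the first logged value above.
V = 10000
print(perplexity(np.log(V)))
```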
    

    %sh
    top -b -n1
    top - 11:21:03 up 14 min,  1 user,  load average: 7.33, 3.84, 1.59
    Tasks: 147 total,   2 running,  92 sleeping,   0 stopped,   0 zombie
    Cpu(s): 11.4%us, 11.9%sy,  0.0%ni, 76.1%id,  0.1%wa,  0.0%hi,  0.0%si,  0.5%st
    Mem:  15394144k total,  3463012k used, 11931132k free,    34120k buffers
    Swap:        0k total,        0k used,        0k free,  2005276k cached
    
      PID USER      PR  NI  VIRT  RES  SHR S %CPU %MEM    TIME+  COMMAND            
     3676 ec2-user  20   0 1058m 408m  18m R 655.4  2.7  25:58.79 python3           
     2604 root      20   0  6520   96    0 S  2.0  0.0   0:00.22 rngd               
     3403 ec2-user  20   0 6695m 557m  25m S  2.0  3.7   0:31.42 java               
     3483 ec2-user  20   0 4869m  91m  18m S  2.0  0.6   0:02.16 java               
     3634 ec2-user  20   0 5195m 115m  18m S  2.0  0.8   0:02.97 java               
     3782 ec2-user  20   0 15360 2292 1960 R  2.0  0.0   0:00.01 top                
      ... (remaining idle kernel threads omitted)
      223 root      25   5     0    0    0 S  0.0  0.0   0:00.00 ksmd               
      224 root      39  19     0    0    0 S  0.0  0.0   0:00.00 khugepaged         
      225 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 crypto             
      226 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kintegrityd        
      228 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kblockd            
      238 root      20   0     0    0    0 I  0.0  0.0   0:00.02 kworker/6:1        
      581 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 md                 
      584 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 edac-poller        
      590 root      20   0     0    0    0 I  0.0  0.0   0:00.00 kworker/4:1        
      728 root      20   0     0    0    0 S  0.0  0.0   0:00.02 kauditd            
      734 root      20   0     0    0    0 S  0.0  0.0   0:00.00 kswapd0            
      830 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kthrotld           
      890 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kstrp              
     1680 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 ata_sff            
     1693 root      20   0     0    0    0 S  0.0  0.0   0:00.00 scsi_eh_0          
     1694 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 scsi_tmf_0         
     1697 root      20   0     0    0    0 S  0.0  0.0   0:00.00 scsi_eh_1          
     1698 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 scsi_tmf_1         
     1749 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kworker/u31:0      
     1770 root      20   0     0    0    0 S  0.0  0.0   0:00.04 jbd2/xvda1-8       
     1771 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 ext4-rsv-conver    
     1799 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kworker/0:1H       
     1817 root      20   0 11476 2652 1752 S  0.0  0.0   0:00.07 udevd              
     1962 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 ixgbevf            
     2076 root      20   0     0    0    0 I  0.0  0.0   0:00.01 kworker/1:2        
     2158 root      20   0  106m  620  388 S  0.0  0.0   0:00.00 lvmetad            
     2167 root      20   0 27196  200    4 S  0.0  0.0   0:00.00 lvmpolld           
     2219 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 ipv6_addrconf      
     2356 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kworker/5:1H       
     2378 root      20   0  9412 1304  816 S  0.0  0.0   0:00.00 dhclient           
     2493 root      20   0  9412 1920 1444 S  0.0  0.0   0:00.00 dhclient           
     2535 root      20   0  615m  15m  11m S  0.0  0.1   0:00.09 amazon-ssm-agen    
     2547 root      16  -4 53004 2408 2024 S  0.0  0.0   0:00.00 auditd             
     2569 root      20   0  241m 2828 2444 S  0.0  0.0   0:00.01 rsyslogd           
     2590 root      20   0 92900 2576 2312 S  0.0  0.0   0:00.06 irqbalance         
     2622 rpc       20   0 35364 2264 1868 S  0.0  0.0   0:00.01 rpcbind            
     2643 rpcuser   20   0 39932 3300 2500 S  0.0  0.0   0:00.00 rpc.statd          
     2674 dbus      20   0 21844  232    0 S  0.0  0.0   0:00.00 dbus-daemon        
     2709 root      20   0  4396 1400 1260 S  0.0  0.0   0:00.00 acpid              
     2805 root      20   0 80588 2664 1840 S  0.0  0.0   0:00.00 sshd               
     2816 ntp       20   0  113m 5264 4456 S  0.0  0.0   0:00.06 ntpd               
     2837 root      20   0 89628 3888 2168 S  0.0  0.0   0:00.01 sendmail           
     2846 smmsp     20   0 81088 3904 2424 S  0.0  0.0   0:00.00 sendmail           
     2858 root      20   0  118m 2460 1864 S  0.0  0.0   0:00.00 crond              
     2872 root      20   0 19188  168    0 S  0.0  0.0   0:00.00 atd                
     2895 root      20   0  6508 1768 1644 S  0.0  0.0   0:00.02 agetty             
     2896 root      20   0  4360 1504 1408 S  0.0  0.0   0:00.02 mingetty           
     2900 root      20   0  4360 1504 1408 S  0.0  0.0   0:00.00 mingetty           
     2903 root      20   0  4360 1464 1372 S  0.0  0.0   0:00.00 mingetty           
     2905 root      20   0  4360 1452 1360 S  0.0  0.0   0:00.00 mingetty           
     2907 root      20   0  4360 1456 1360 S  0.0  0.0   0:00.00 mingetty           
     2909 root      20   0  4360 1492 1400 S  0.0  0.0   0:00.00 mingetty           
     2919 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kworker/2:1H       
     2999 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kworker/7:1H       
     3077 root      20   0  975m  72m  44m S  0.0  0.5   0:01.51 dockerd            
     3091 root      20   0  629m  23m  18m S  0.0  0.2   0:02.61 docker-containe    
     3108 root      20   0 11472 2252 1344 S  0.0  0.0   0:00.00 udevd              
     3109 root      20   0 11472 2228 1328 S  0.0  0.0   0:00.00 udevd              
     3240 root      20   0  117m 7184 6096 S  0.0  0.0   0:00.00 sshd               
     3242 ec2-user  20   0  117m 4004 2920 S  0.0  0.0   0:00.16 sshd               
     3243 ec2-user  20   0  112m 3504 3088 S  0.0  0.0   0:00.03 bash               
     3321 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kworker/6:1H       
     3378 ec2-user  20   0 19996 2528 2260 S  0.0  0.0   0:00.00 tmux               
     3380 ec2-user  20   0 22444 3096 2516 S  0.0  0.0   0:00.20 tmux               
     3381 ec2-user  20   0  112m 3440 3048 S  0.0  0.0   0:00.00 bash               
     3402 ec2-user  20   0  105m 2280 2048 S  0.0  0.0   0:00.00 make               
     3470 ec2-user  20   0  110m 3060 2808 S  0.0  0.0   0:00.00 interpreter.sh     
     3482 ec2-user  20   0  110m 2268 2016 S  0.0  0.0   0:00.00 interpreter.sh     
     3572 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kworker/1:1H       
     3621 ec2-user  20   0  110m 3020 2768 S  0.0  0.0   0:00.00 interpreter.sh     
     3633 ec2-user  20   0  110m 2200 1948 S  0.0  0.0   0:00.00 interpreter.sh     
     3756 root       0 -20     0    0    0 I  0.0  0.0   0:00.00 kworker/4:1H       
     3758 root      20   0     0    0    0 I  0.0  0.0   0:00.00 kworker/u30:0      
    
    

    6.5.4 Implementing a better RNNLM

    • Stacking multiple LSTM layers
    • Applying Dropout
    • Weight tying (sharing the Embedding weights with the output layer)
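    Weight tying here means the output Affine layer reuses the transposed Embedding matrix (`TimeAffine(embed_W.T, affine_b)` in the code below), which requires H == D. A rough parameter count with the book's sizes (V = 10000, D = H = 650) shows how much this saves:

```python
V, D = 10000, 650  # vocab size and embedding size; tying requires H == D

# Untied: Embedding (V x D) plus an independent Affine weight (D x V)
untied = V * D + D * V
# Tied: the Affine layer reuses embed_W.T, so the V x D matrix is stored once
tied = V * D

print(untied, tied)  # 13000000 6500000
```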

    %python3
    import sys
    sys.path.append('/tmp/deep-learning-from-scratch-2')
    from common.time_layers import TimeEmbedding, TimeLSTM, TimeDropout, TimeAffine, TimeSoftmaxWithLoss
    from common.np import *
    from common.base_model import BaseModel
    
    
    class BetterRnnlm(BaseModel):
        def __init__(self, vocab_size=10000, wordvec_size=650, hidden_size=650, dropout_ratio=0.5):
            V, D, H = vocab_size, wordvec_size, hidden_size
            rn = np.random.randn
            
            embed_W = (rn(V, D) / 100).astype('f')
            lstm_Wx1 = (rn(D, 4*H) / np.sqrt(D)).astype('f')
            lstm_Wh1 = (rn(H, 4*H) / np.sqrt(H)).astype('f')
            lstm_b1 = np.zeros(4*H).astype('f')
            lstm_Wx2 = (rn(D, 4*H) / np.sqrt(D)).astype('f')
            lstm_Wh2 = (rn(H, 4*H) / np.sqrt(H)).astype('f')
            lstm_b2 = np.zeros(4*H).astype('f')
            affine_b = np.zeros(V).astype('f')
            
            self.layers = [
                TimeEmbedding(embed_W),
                TimeDropout(dropout_ratio),
                TimeLSTM(lstm_Wx1, lstm_Wh1, lstm_b1, stateful=True),
                TimeDropout(dropout_ratio),
                TimeLSTM(lstm_Wx2, lstm_Wh2, lstm_b2, stateful=True),
                TimeDropout(dropout_ratio),
                TimeAffine(embed_W.T, affine_b)
            ]
            self.loss_layer = TimeSoftmaxWithLoss()
            self.lstm_layers = [self.layers[2], self.layers[4]]
            self.drop_layers = [self.layers[1], self.layers[3], self.layers[5]]
            
            self.params, self.grads = [], []
            for layer in self.layers:
                self.params += layer.params
                self.grads += layer.grads
        
        def predict(self, xs, train_flg=False):
            for layer in self.drop_layers:
                layer.train_flg = train_flg
            for layer in self.layers:
                xs = layer.forward(xs)
            return xs
        
        def forward(self, xs, ts, train_flg=True):
            score = self.predict(xs, train_flg)
            loss = self.loss_layer.forward(score, ts)
            return loss
        
        def backward(self, dout=1):
            dout = self.loss_layer.backward(dout)
            for layer in reversed(self.layers):
                dout = layer.backward(dout)
            return dout
        
        def reset_state(self):
            for layer in self.lstm_layers:
                layer.reset_state()
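    The `4*H` in the weight shapes above packs the four LSTM gates into a single matmul. A minimal sketch of one time step, slicing the packed result in the f, g, i, o order used by the book's `TimeLSTM` (a simplified re-derivation for illustration, not the repo's actual code):

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, Wx, Wh, b):
    H = h_prev.shape[1]
    A = x @ Wx + h_prev @ Wh + b   # one matmul computes all gates: (N, 4H)
    f = sigmoid(A[:, :H])          # forget gate
    g = np.tanh(A[:, H:2*H])       # new cell-state candidate
    i = sigmoid(A[:, 2*H:3*H])     # input gate
    o = sigmoid(A[:, 3*H:])        # output gate
    c = f * c_prev + g * i
    h = o * np.tanh(c)
    return h, c

# Shapes mirror the (D, 4H) / (H, 4H) weights initialized above
N, D, H = 2, 100, 100
h, c = lstm_step(np.random.randn(N, D), np.zeros((N, H)), np.zeros((N, H)),
                 np.random.randn(D, 4*H) / np.sqrt(D),
                 np.random.randn(H, 4*H) / np.sqrt(H),
                 np.zeros(4*H))
print(h.shape, c.shape)  # (2, 100) (2, 100)
```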

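    `predict` toggles `train_flg` on the Dropout layers. A minimal sketch of the inverted-dropout behavior this assumes (scale at train time so inference is a no-op; see `common/time_layers.py` for the repo's actual `TimeDropout`):

```python
import numpy as np

def time_dropout(xs, dropout_ratio=0.5, train_flg=True):
    # Inverted dropout: zero out units and rescale survivors at train time,
    # so inference can simply pass the input through unchanged.
    if not train_flg:
        return xs
    mask = (np.random.rand(*xs.shape) > dropout_ratio) / (1.0 - dropout_ratio)
    return xs * mask

xs = np.ones((2, 3, 4), dtype='f')
out = time_dropout(xs, train_flg=False)
print(np.allclose(out, xs))  # True: inference leaves the input unchanged
```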
    With more layers, would a GPU be a better fit?
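    For a rough sense of the per-step cost on CPU, the dominant operation in each LSTM layer is the packed-gate matmul. A small benchmark sketch (sizes follow the model above; `N` is a hypothetical batch size):

```python
import time
import numpy as np

# The better RNNLM uses D = H = 650; N is a hypothetical batch size.
N, D, H = 20, 650, 650
x = np.random.randn(N, D).astype('f')
Wx = np.random.randn(D, 4 * H).astype('f')

t0 = time.time()
for _ in range(100):
    out = x @ Wx               # (N, 4H): the packed-gate matmul
elapsed = time.time() - t0
print(f'{elapsed * 10:.2f} ms per matmul')  # this is the work a GPU parallelizes well
```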

    CUDA: checking the versions

    %sh
    nvidia-smi
    nvcc --version
    Sun Nov 18 12:35:56 2018       
    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |===============================+======================+======================|
    |   0  Tesla V100-SXM2...  On   | 00000000:00:1B.0 Off |                    0 |
    | N/A   34C    P0    25W / 300W |      0MiB / 16160MiB |      0%      Default |
    +-------------------------------+----------------------+----------------------+
    |   1  Tesla V100-SXM2...  On   | 00000000:00:1C.0 Off |                    0 |
    | N/A   33C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
    +-------------------------------+----------------------+----------------------+
    |   2  Tesla V100-SXM2...  On   | 00000000:00:1D.0 Off |                    0 |
    | N/A   32C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
    +-------------------------------+----------------------+----------------------+
    |   3  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
    | N/A   33C    P0    27W / 300W |      0MiB / 16160MiB |      0%      Default |
    +-------------------------------+----------------------+----------------------+
                                                                                   
    +-----------------------------------------------------------------------------+
    | Processes:                                                       GPU Memory |
    |  GPU       PID   Type   Process name                             Usage      |
    |=============================================================================|
    |  No running processes found                                                 |
    +-----------------------------------------------------------------------------+
    nvcc: NVIDIA (R) Cuda compiler driver
    Copyright (c) 2005-2017 NVIDIA Corporation
    Built on Fri_Sep__1_21:08:03_CDT_2017
    Cuda compilation tools, release 9.0, V9.0.176
    

    Installing CuPy

    %sh
    sudo pip-3.6 install cupy-cuda90
    Collecting cupy-cuda90
      Downloading https://files.pythonhosted.org/packages/f7/46/0910fb6901fec52d4a77ff36378c82103dabc676b5f50d334e3784fd321f/cupy_cuda90-5.0.0-cp36-cp36m-manylinux1_x86_64.whl (262.7MB)
    Collecting fastrlock>=0.3 (from cupy-cuda90)
      Downloading https://files.pythonhosted.org/packages/b5/93/a7efbd39eac46c137500b37570c31dedc2d31a8ff4949fcb90bda5bc5f16/fastrlock-0.4-cp36-cp36m-manylinux1_x86_64.whl
    Requirement already satisfied: numpy>=1.9.0 in /usr/lib64/python3.6/dist-packages (from cupy-cuda90)
    Requirement already satisfied: six>=1.9.0 in /usr/lib/python3.6/dist-packages (from cupy-cuda90)
    Installing collected packages: fastrlock, cupy-cuda90
    Successfully installed cupy-cuda90-5.0.0 fastrlock-0.4
    You are using pip version 9.0.3, however version 18.1 is available.
    You should consider upgrading via the 'pip install --upgrade pip' command.
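    With `cupy-cuda90` installed, the repo can run the same code on GPU: as I understand the repo's setup, `common/np.py` imports CuPy instead of NumPy based on a flag in `common/config.py`. A minimal sketch of that switching pattern:

```python
GPU = False  # set True once cupy is installed and a CUDA GPU is available

if GPU:
    import cupy as np  # drop-in replacement for most of the NumPy API used here
else:
    import numpy as np

x = np.arange(6, dtype='f').reshape(2, 3)
print(float(x.sum()))  # 15.0 with either backend
```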
    

    %sh
    nvidia-smi
    Sun Nov 18 12:48:46 2018       
    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |===============================+======================+======================|
    |   0  Tesla V100-SXM2...  On   | 00000000:00:1B.0 Off |                    0 |
    | N/A   37C    P0    58W / 300W |   2821MiB / 16160MiB |     10%      Default |
    +-------------------------------+----------------------+----------------------+
    |   1  Tesla V100-SXM2...  On   | 00000000:00:1C.0 Off |                    0 |
    | N/A   33C    P0    36W / 300W |     11MiB / 16160MiB |      0%      Default |
    +-------------------------------+----------------------+----------------------+
    |   2  Tesla V100-SXM2...  On   | 00000000:00:1D.0 Off |                    0 |
    | N/A   31C    P0    38W / 300W |     11MiB / 16160MiB |      0%      Default |
    +-------------------------------+----------------------+----------------------+
    |   3  Tesla V100-SXM2...  On   | 00000000:00:1E.0 Off |                    0 |
    | N/A   32C    P0    42W / 300W |     11MiB / 16160MiB |      0%      Default |
    +-------------------------------+----------------------+----------------------+
                                                                                   
    +-----------------------------------------------------------------------------+
    | Processes:                                                       GPU Memory |
    |  GPU       PID   Type   Process name                             Usage      |
    |=============================================================================|
    |    0      5761      C   python3                                     2810MiB |
    +-----------------------------------------------------------------------------+
    

    • Got outbid on the instance and the run died partway through 🙃
    • Multiple GPUs don't seem necessary for this case
    • There is also plenty of GPU memory to spare