ゼロから作る Deep Learning 2/Attention


Chapter 8: Attention

These are my reading notes on ゼロから作るDeep Learning (2) (Deep Learning from Scratch 2). Chapter 7 used the simple task of predicting additions to introduce seq2seq, a method for converting one time series into another. Chapter 8 introduces the attention mechanism, which makes seq2seq even more powerful.

TL;DR

  • In the seq2seq implemented in the previous chapter, the Encoder's output is a single fixed-length vector
    • It cannot scale with the input length, so the state overflows
  • Introducing attention
    • Pass the Encoder's LSTM hidden states from every time step to the Decoder together
    • Prepare a sequence of weights over those hidden states that expresses a probability-distribution-like importance
    • Compute the context vector as their weighted sum (a minimal sketch follows this list)
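
Before diving into the implementation, here is a minimal NumPy sketch of that core idea. The sizes and variable names are my own illustrative assumptions, not code from the book: given the Encoder's T hidden states of size H and probability-like weights over them, the context vector is simply their weighted sum.

%python3
import numpy as np

# Illustrative sizes (assumed): T encoder time steps, hidden size H
T, H = 5, 4
hs = np.random.randn(T, H)    # encoder hidden states, one row per time step
a = np.random.rand(T)
a /= a.sum()                  # importance weights that sum to 1, like a probability distribution

c = np.sum(hs * a.reshape(T, 1), axis=0)  # context vector: weighted sum over time
print(c.shape)                # (H,)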

Reference implementation

%sh
cd /tmp
git clone https://github.com/oreilly-japan/deep-learning-from-scratch-2
Cloning into 'deep-learning-from-scratch-2'...

Python version

%sh
python3 --version
Python 3.5.2

Install numpy and matplotlib

%sh
pip3 install numpy matplotlib
bash: pip3: command not found

%sh
apt install -y python3-pip
WARNING: apt does not have a stable CLI interface. Use with caution in scripts.

Reading package lists...
Building dependency tree...
Reading state information...
The following additional packages will be installed:
  libpython3-dev libpython3.5 libpython3.5-dev python3-dev
  python3-pkg-resources python3-setuptools python3-wheel python3.5-dev
Suggested packages:
  python-setuptools-doc
The following NEW packages will be installed:
  libpython3-dev libpython3.5 libpython3.5-dev python3-dev python3-pip
  python3-pkg-resources python3-setuptools python3-wheel python3.5-dev
0 upgraded, 9 newly installed, 0 to remove and 0 not upgraded.
Need to get 39.4 MB of archives.
After this operation, 60.3 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu xenial-updates/main amd64 libpython3.5 amd64 3.5.2-2ubuntu0~16.04.5 [1,360 kB]
Get:2 http://archive.ubuntu.com/ubuntu xenial-updates/main amd64 libpython3.5-dev amd64 3.5.2-2ubuntu0~16.04.5 [37.3 MB]
Get:3 http://archive.ubuntu.com/ubuntu xenial/main amd64 libpython3-dev amd64 3.5.1-3 [6,926 B]
Get:4 http://archive.ubuntu.com/ubuntu xenial-updates/main amd64 python3.5-dev amd64 3.5.2-2ubuntu0~16.04.5 [413 kB]
Get:5 http://archive.ubuntu.com/ubuntu xenial/main amd64 python3-dev amd64 3.5.1-3 [1,186 B]
Get:6 http://archive.ubuntu.com/ubuntu xenial-updates/universe amd64 python3-pip all 8.1.1-2ubuntu0.4 [109 kB]
Get:7 http://archive.ubuntu.com/ubuntu xenial/main amd64 python3-pkg-resources all 20.7.0-1 [79.0 kB]
Get:8 http://archive.ubuntu.com/ubuntu xenial/main amd64 python3-setuptools all 20.7.0-1 [88.0 kB]
Get:9 http://archive.ubuntu.com/ubuntu xenial/universe amd64 python3-wheel all 0.29.0-1 [48.1 kB]
debconf: unable to initialize frontend: Dialog
debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 76, <> line 9.)
debconf: falling back to frontend: Readline
Fetched 39.4 MB in 2s (17.1 MB/s)
Selecting previously unselected package libpython3.5:amd64.
(Reading database ... 32914 files and directories currently installed.)
Preparing to unpack .../libpython3.5_3.5.2-2ubuntu0~16.04.5_amd64.deb ...
Unpacking libpython3.5:amd64 (3.5.2-2ubuntu0~16.04.5) ...
Selecting previously unselected package libpython3.5-dev:amd64.
Preparing to unpack .../libpython3.5-dev_3.5.2-2ubuntu0~16.04.5_amd64.deb ...
Unpacking libpython3.5-dev:amd64 (3.5.2-2ubuntu0~16.04.5) ...
Selecting previously unselected package libpython3-dev:amd64.
Preparing to unpack .../libpython3-dev_3.5.1-3_amd64.deb ...
Unpacking libpython3-dev:amd64 (3.5.1-3) ...
Selecting previously unselected package python3.5-dev.
Preparing to unpack .../python3.5-dev_3.5.2-2ubuntu0~16.04.5_amd64.deb ...
Unpacking python3.5-dev (3.5.2-2ubuntu0~16.04.5) ...
Selecting previously unselected package python3-dev.
Preparing to unpack .../python3-dev_3.5.1-3_amd64.deb ...
Unpacking python3-dev (3.5.1-3) ...
Selecting previously unselected package python3-pip.
Preparing to unpack .../python3-pip_8.1.1-2ubuntu0.4_all.deb ...
Unpacking python3-pip (8.1.1-2ubuntu0.4) ...
Selecting previously unselected package python3-pkg-resources.
Preparing to unpack .../python3-pkg-resources_20.7.0-1_all.deb ...
Unpacking python3-pkg-resources (20.7.0-1) ...
Selecting previously unselected package python3-setuptools.
Preparing to unpack .../python3-setuptools_20.7.0-1_all.deb ...
Unpacking python3-setuptools (20.7.0-1) ...
Selecting previously unselected package python3-wheel.
Preparing to unpack .../python3-wheel_0.29.0-1_all.deb ...
Unpacking python3-wheel (0.29.0-1) ...
Processing triggers for libc-bin (2.23-0ubuntu10) ...
Processing triggers for man-db (2.7.5-1) ...
Setting up libpython3.5:amd64 (3.5.2-2ubuntu0~16.04.5) ...
Setting up libpython3.5-dev:amd64 (3.5.2-2ubuntu0~16.04.5) ...
Setting up libpython3-dev:amd64 (3.5.1-3) ...
Setting up python3.5-dev (3.5.2-2ubuntu0~16.04.5) ...
Setting up python3-dev (3.5.1-3) ...
Setting up python3-pip (8.1.1-2ubuntu0.4) ...
Setting up python3-pkg-resources (20.7.0-1) ...
Setting up python3-setuptools (20.7.0-1) ...
Setting up python3-wheel (0.29.0-1) ...
Processing triggers for libc-bin (2.23-0ubuntu10) ...

Retry

%sh
pip3 install numpy matplotlib
Collecting numpy
  Downloading https://files.pythonhosted.org/packages/ad/15/690c13ae714e156491392cdbdbf41b485d23c285aa698239a67f7cfc9e0a/numpy-1.16.1-cp35-cp35m-manylinux1_x86_64.whl (17.2MB)
Collecting matplotlib
  Downloading https://files.pythonhosted.org/packages/ad/4c/0415f15f96864c3a2242b1c74041a806c100c1b21741206c5d87684437c6/matplotlib-3.0.2-cp35-cp35m-manylinux1_x86_64.whl (12.9MB)
Collecting kiwisolver>=1.0.1 (from matplotlib)
  Downloading https://files.pythonhosted.org/packages/7e/31/d6fedd4fb2c94755cd101191e581af30e1650ccce7a35bddb7930fed6574/kiwisolver-1.0.1-cp35-cp35m-manylinux1_x86_64.whl (949kB)
Collecting python-dateutil>=2.1 (from matplotlib)
  Downloading https://files.pythonhosted.org/packages/74/68/d87d9b36af36f44254a8d512cbfc48369103a3b9e474be9bdfe536abfc45/python_dateutil-2.7.5-py2.py3-none-any.whl (225kB)
Collecting cycler>=0.10 (from matplotlib)
  Downloading https://files.pythonhosted.org/packages/f7/d2/e07d3ebb2bd7af696440ce7e754c59dd546ffe1bbe732c8ab68b9c834e61/cycler-0.10.0-py2.py3-none-any.whl
Collecting pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 (from matplotlib)
  Downloading https://files.pythonhosted.org/packages/de/0a/001be530836743d8be6c2d85069f46fecf84ac6c18c7f5fb8125ee11d854/pyparsing-2.3.1-py2.py3-none-any.whl (61kB)
Requirement already satisfied (use --upgrade to upgrade): setuptools in /usr/lib/python3/dist-packages (from kiwisolver>=1.0.1->matplotlib)
Collecting six>=1.5 (from python-dateutil>=2.1->matplotlib)
  Downloading https://files.pythonhosted.org/packages/73/fb/00a976f728d0d1fecfe898238ce23f502a721c0ac0ecfedb80e0d88c64e9/six-1.12.0-py2.py3-none-any.whl
Installing collected packages: numpy, kiwisolver, six, python-dateutil, cycler, pyparsing, matplotlib
Successfully installed cycler-0.10.0 kiwisolver-1.0.1 matplotlib-3.0.2 numpy-1.16.1 pyparsing-2.3.1 python-dateutil-2.7.5 six-1.12.0
You are using pip version 8.1.1, however version 19.0.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.

WeightSum layer

Multiplies the hidden states received from the Encoder by the attention weights and passes on their sum (the context vector). A quick shape check follows the class below.

%python3
import numpy as np


class WeightSum:
    def __init__(self):
        self.params, self.grads = [], []
        self.cache = None
    
    def forward(self, hs, a):
        # hs: encoder hidden states (N, T, H), a: attention weights (N, T)
        N, T, H = hs.shape
        
        ar = a.reshape(N, T, 1).repeat(H, axis=2)  # broadcast the weights along H
        t = hs * ar
        c = np.sum(t, axis=1)                      # weighted sum over time -> (N, H)
        
        self.cache = (hs, ar)
        return c
    
    def backward(self, dc):
        hs, ar = self.cache
        N, T, H = hs.shape
        
        dt = dc.reshape(N, 1, H).repeat(T, axis=1)  # undo the sum over time
        dar = dt * hs
        dhs = dt * ar
        da = np.sum(dar, axis=2)                    # undo the repeat along H
        
        return dhs, da
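
A quick shape check with made-up values, just to see what forward expects and returns (this cell is my own sketch, not from the book):

%python3
# Hypothetical batch: N=2 samples, T=3 time steps, hidden size H=4
hs = np.random.randn(2, 3, 4).astype('f')
a = np.random.rand(2, 3).astype('f')
a /= a.sum(axis=1, keepdims=True)   # each sample's weights sum to 1

c = WeightSum().forward(hs, a)
print(c.shape)                      # (2, 4): one context vector per sample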

AttentionWeight layer

For each hidden state vector received from the Encoder, take the inner product with the Decoder LSTM's hidden state vector as a similarity score, then apply the Softmax function. A quick sanity check follows the class below.

%python3
import sys
sys.path.append('/tmp/deep-learning-from-scratch-2')

%python3
import numpy as np
from common.layers import Softmax


class AttentionWeight:
    def __init__(self):
        self.params, self.grads = [], []
        self.softmax = Softmax()
        self.cache = None
        
    def forward(self, hs, h):
        # hs: encoder hidden states (N, T, H), h: decoder hidden state (N, H)
        N, T, H = hs.shape
        
        hr = h.reshape(N, 1, H).repeat(T, axis=1)
        t = hs * hr
        s = np.sum(t, axis=2)          # inner product with every encoder state -> (N, T)
        a = self.softmax.forward(s)    # normalize the scores into attention weights
        
        self.cache = (hs, hr)
        return a
    
    def backward(self, da):
        hs, hr = self.cache
        N, T, H = hs.shape
        
        ds = self.softmax.backward(da)
        dt = ds.reshape(N, T, 1).repeat(H, axis=2)
        dhs = dt * hr
        dhr = dt * hs
        dh = np.sum(dhr, axis=1)
        
        return dhs, dh
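
Another quick sanity check with hypothetical shapes; because of the Softmax, each row of the returned weights should sum to 1 (again my own sketch):

%python3
hs = np.random.randn(2, 3, 4).astype('f')   # encoder hidden states (N, T, H)
h = np.random.randn(2, 4).astype('f')       # decoder hidden state at one time step (N, H)

a = AttentionWeight().forward(hs, h)
print(a.shape)        # (2, 3): one weight per encoder time step
print(a.sum(axis=1))  # each row sums to 1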

Attention layer

%python3
class Attention:
    def __init__(self):
        self.params, self.grads = [], []
        self.attention_weight_layer = AttentionWeight()
        self.weight_sum_layer = WeightSum()
        self.attention_weight = None
    
    def forward(self, hs, h):
        a = self.attention_weight_layer.forward(hs, h)
        out = self.weight_sum_layer.forward(hs, a)
        self.attention_weight = a
        return out
    
    def backward(self, dout):
        dhs0, da = self.weight_sum_layer.backward(dout)
        dhs1, dh = self.attention_weight_layer.backward(da)
        dhs = dhs0 + dhs1
        return dhs, dh

TimeAttention layer

Applies an Attention layer at every Decoder time step, each one receiving the full set of Encoder hidden states. A shape check follows the class below.

%python3
class TimeAttention:
    def __init__(self):
        self.params, self.grads = [], []
        self.layers = None
        self.attention_weights = None
    
    def forward(self, hs_enc, hs_dec):
        N, T, H = hs_dec.shape
        out = np.empty_like(hs_dec)
        self.layers = []
        self.attention_weights = []
        
        for t in range(T):
            layer = Attention()
            out[:, t, :] = layer.forward(hs_enc, hs_dec[:, t, :])
            self.layers.append(layer)
            self.attention_weights.append(layer.attention_weight)
        
        return out
    
    def backward(self, dout):
        N, T, H = dout.shape
        dhs_enc = 0
        dhs_dec = np.empty_like(dout)
        
        for t in range(T):
            layer = self.layers[t]
            dhs, dh = layer.backward(dout[:, t, :])
            dhs_enc += dhs
            dhs_dec[:, t, :] = dh
        
        return dhs_enc, dhs_dec
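
Putting the pieces together, TimeAttention produces one context vector per Decoder time step. A minimal shape check with made-up sizes (my own sketch; the encoder and decoder lengths need not match):

%python3
hs_enc = np.random.randn(2, 5, 4).astype('f')   # encoder hidden states: N=2, T_enc=5, H=4
hs_dec = np.random.randn(2, 3, 4).astype('f')   # decoder hidden states: N=2, T_dec=3, H=4

layer = TimeAttention()
c = layer.forward(hs_enc, hs_dec)
print(c.shape)                        # (2, 3, 4): one context vector per decoder step
print(len(layer.attention_weights))   # 3 weight arrays, each of shape (2, 5)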

AttentionEncoder

Changes from the previous chapter's Encoder class

  • The forward method, which used to return only the last hidden state vector, now returns all of the hidden states
  • The backward method is adjusted accordingly

%python3
# The previous chapter's Encoder, shown here for reference (it returns only the last hidden state)
class Encoder:
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V, D, H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn
        
        embed_W = (rn(V, D) / 100).astype('f')
        lstm_Wx = (rn(D, 4 * H) / np.sqrt(D)).astype('f')
        lstm_Wh = (rn(H, 4 * H) / np.sqrt(H)).astype('f')
        lstm_b = np.zeros(4 * H).astype('f')
        
        self.embed = TimeEmbedding(embed_W)
        self.lstm = TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful=False)
        
        self.params = self.embed.params + self.lstm.params
        self.grads = self.embed.grads + self.lstm.grads
        self.hs = None
    
    def forward(self, xs):
        xs = self.embed.forward(xs)
        hs = self.lstm.forward(xs)
        self.hs = hs
        return hs[:, -1, :]
    
    def backward(self, dh):
        dhs = np.zeros_like(self.hs)
        dhs[:, -1, :] = dh
        
        dout = self.lstm.backward(dhs)
        dout = self.embed.backward(dout)
        return dout

%python3
from common.time_layers import *

class AttentionEncoder(Encoder):
    def forward(self, xs):
        xs = self.embed.forward(xs)
        hs = self.lstm.forward(xs)
        return hs
    
    def backward(self, dhs):
        dout = self.lstm.backward(dhs)
        dout = self.embed.backward(dout)
        return dout

AttentionDecoder

%python3
class AttentionDecoder:
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V, D, H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn
        
        embed_W = (rn(V, D) / 100).astype('f')
        lstm_Wx = (rn(D, 4 * H) / np.sqrt(D)).astype('f')
        lstm_Wh = (rn(H, 4 * H) / np.sqrt(H)).astype('f')
        lstm_b = np.zeros(4 * H).astype('f')
        affine_W = (rn(2*H, V) / np.sqrt(2*H)).astype('f')
        affine_b = np.zeros(V).astype('f')
        
        self.embed = TimeEmbedding(embed_W)
        self.lstm = TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful=True)
        self.attention = TimeAttention()
        self.affine = TimeAffine(affine_W, affine_b)
        layers = [self.embed, self.lstm, self.attention, self.affine]
        
        self.params, self.grads = [], []
        for layer in layers:
            self.params += layer.params
            self.grads += layer.grads
        
    def forward(self, xs, enc_hs):
        h = enc_hs[:, -1]
        self.lstm.set_state(h)
        
        out = self.embed.forward(xs)
        dec_hs = self.lstm.forward(out)
        c = self.attention.forward(enc_hs, dec_hs)
        out = np.concatenate((c, dec_hs), axis=2)
        score = self.affine.forward(out)
        
        return score
    
    def backward(self, dscore):
        dout = self.affine.backward(dscore)
        N, T, H2 = dout.shape
        H = H2 // 2
        
        dc, ddec_hs0 = dout[:,:,:H], dout[:,:,H:]
        denc_hs, ddec_hs1 = self.attention.backward(dc)
        ddec_hs = ddec_hs0 + ddec_hs1
        dout = self.lstm.backward(ddec_hs)
        dh = self.lstm.dh
        denc_hs[:, -1] += dh
        self.embed.backward(dout)
        
        return denc_hs
        
    def generate(self, enc_hs, start_id, sample_size):
        sampled = []
        sample_id = start_id
        h = enc_hs[:, -1]
        self.lstm.set_state(h)
        
        for _ in range(sample_size):
            x = np.array([sample_id]).reshape((1, 1))
            
            out = self.embed.forward(x)
            dec_hs = self.lstm.forward(out)
            c = self.attention.forward(enc_hs, dec_hs)
            out = np.concatenate((c, dec_hs), axis=2)
            score = self.affine.forward(out)
            
            sample_id = np.argmax(score.flatten())
            sampled.append(sample_id)
        
        return sampled
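
Note that the Affine layer's weight matrix has shape (2H, V) because the context vector and the LSTM hidden state are concatenated along the last axis. A tiny illustration with made-up sizes (my own sketch):

%python3
c = np.zeros((2, 3, 4))         # context vectors (N, T, H)
dec_hs = np.zeros((2, 3, 4))    # decoder hidden states (N, T, H)
out = np.concatenate((c, dec_hs), axis=2)
print(out.shape)                # (2, 3, 8) = (N, T, 2*H), hence affine_W is (2*H, V)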

AttentionSeq2Seq

Swap in the attention-based Encoder and Decoder.

%python3
from common.base_model import BaseModel

# The base Seq2Seq from the previous chapter, reproduced for reference.
# Decoder refers to the previous chapter's class and is not defined in this notebook,
# but AttentionSeq2Seq below overrides __init__, so it is never used here.
class Seq2Seq(BaseModel):
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V, D, H = vocab_size, wordvec_size, hidden_size
        self.encoder = Encoder(V, D, H)
        self.decoder = Decoder(V, D, H)
        self.softmax = TimeSoftmaxWithLoss()
        
        self.params = self.encoder.params + self.decoder.params
        self.grads = self.encoder.grads + self.decoder.grads
    
    def forward(self, xs, ts):
        decoder_xs, decoder_ts = ts[:, :-1], ts[:, 1:]
        
        h = self.encoder.forward(xs)
        score = self.decoder.forward(decoder_xs, h)
        loss = self.softmax.forward(score, decoder_ts)
        return loss
    
    def backward(self, dout=1):
        dout = self.softmax.backward(dout)
        dh = self.decoder.backward(dout)
        dout = self.encoder.backward(dh)
        return dout
    
    def generate(self, xs, start_id, sample_size):
        h = self.encoder.forward(xs)
        sampled = self.decoder.generate(h, start_id, sample_size)
        return sampled

%python3
class AttentionSeq2Seq(Seq2Seq):
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        args = vocab_size, wordvec_size, hidden_size
        self.encoder = AttentionEncoder(*args)
        self.decoder = AttentionDecoder(*args)
        self.softmax = TimeSoftmaxWithLoss()
        
        self.params = self.encoder.params + self.decoder.params
        self.grads = self.encoder.grads + self.decoder.grads

Training the attention-equipped seq2seq

This time, instead of addition, the task is date format conversion.

The dataset used (a small parsing sketch of the line format follows the sample below):

%sh
cat /tmp/deep-learning-from-scratch-2/dataset/date.txt | head
september 27, 1994           _1994-09-27
August 19, 2003              _2003-08-19
2/10/93                      _1993-02-10
10/31/90                     _1990-10-31
TUESDAY, SEPTEMBER 25, 1984  _1984-09-25
JUN 17, 2013                 _2013-06-17
april 3, 1996                _1996-04-03
October 24, 1974             _1974-10-24
AUGUST 11, 1986              _1986-08-11
February 16, 2015            _2015-02-16
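
Each line of date.txt holds the question in a fixed-width field, and the target starts at the '_' character (in the book's setup it also serves as the start symbol for generation). The parsing below is my own illustration of the format based on the sample above, not the book's loader:

%python3
line = 'september 27, 1994           _1994-09-27'
sep = line.find('_')
question, answer = line[:sep], line[sep:]
print(repr(question))   # 'september 27, 1994           ' (padded to a fixed width)
print(repr(answer))     # '_1994-09-27'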

%python3
from dataset import sequence
from common.optimizer import Adam
from common.trainer import Trainer
from common.util import eval_seq2seq

%python3
# Load the data
(x_train, t_train), (x_test, t_test) = sequence.load_data('date.txt')
char_to_id, id_to_char = sequence.get_vocab()

# Reverse the input sequences
x_train, x_test = x_train[:, ::-1], x_test[:, ::-1]

%python3
# Hyperparameters
vocab_size = len(char_to_id)
wordvec_size = 16
hidden_size = 256
batch_size = 128
max_epoch = 10
max_grad = 5.0

%python3
model = AttentionSeq2Seq(vocab_size, wordvec_size, hidden_size)
optimizer = Adam()
trainer = Trainer(model, optimizer)

Run the training

%python3
acc_list = []
for epoch in range(max_epoch):
    trainer.fit(x_train, t_train, max_epoch=1, batch_size=batch_size, max_grad=max_grad)
    
    correct_num = 0
    for i in range(len(x_test)):
        question, correct = x_test[[i]], t_test[[i]]
        verbose = i < 10
        correct_num += eval_seq2seq(model, question, correct, id_to_char, verbose, is_reverse=True)
    
    acc = float(correct_num) / len(x_test)
    acc_list.append(acc)
    print('val acc %.3f%%' % (acc * 100))
    
model.save_params()
| epoch 1 |  iter 1 / 351 | time 0[s] | loss 4.08
| epoch 1 |  iter 21 / 351 | time 12[s] | loss 3.09
| epoch 1 |  iter 41 / 351 | time 24[s] | loss 1.90
| epoch 1 |  iter 61 / 351 | time 37[s] | loss 1.72
| epoch 1 |  iter 81 / 351 | time 49[s] | loss 1.46
| epoch 1 |  iter 101 / 351 | time 62[s] | loss 1.19
| epoch 1 |  iter 121 / 351 | time 74[s] | loss 1.14
| epoch 1 |  iter 141 / 351 | time 87[s] | loss 1.09
| epoch 1 |  iter 161 / 351 | time 99[s] | loss 1.06
| epoch 1 |  iter 181 / 351 | time 111[s] | loss 1.04
| epoch 1 |  iter 201 / 351 | time 124[s] | loss 1.03
| epoch 1 |  iter 221 / 351 | time 136[s] | loss 1.02
| epoch 1 |  iter 241 / 351 | time 149[s] | loss 1.02
| epoch 1 |  iter 261 / 351 | time 161[s] | loss 1.01
| epoch 1 |  iter 281 / 351 | time 174[s] | loss 1.00
| epoch 1 |  iter 301 / 351 | time 186[s] | loss 1.00
| epoch 1 |  iter 321 / 351 | time 199[s] | loss 1.00
| epoch 1 |  iter 341 / 351 | time 211[s] | loss 1.00
Q 10/15/94                     
T 1994-10-15
☒ 1978-08-11
---
Q thursday, november 13, 2008  
T 2008-11-13
☒ 1978-08-11
---
Q Mar 25, 2003                 
T 2003-03-25
☒ 1978-08-11
---
Q Tuesday, November 22, 2016   
T 2016-11-22
☒ 1978-08-11
---
Q Saturday, July 18, 1970      
T 1970-07-18
☒ 1978-08-11
---
Q october 6, 1992              
T 1992-10-06
☒ 1978-08-11
---
Q 8/23/08                      
T 2008-08-23
☒ 1978-08-11
---
Q 8/30/07                      
T 2007-08-30
☒ 1978-08-11
---
Q 10/28/13                     
T 2013-10-28
☒ 1978-08-11
---
Q sunday, november 6, 2016     
T 2016-11-06
☒ 1978-08-11
---
val acc 0.000%
| epoch 2 |  iter 1 / 351 | time 0[s] | loss 1.00
| epoch 2 |  iter 21 / 351 | time 13[s] | loss 1.00
| epoch 2 |  iter 41 / 351 | time 25[s] | loss 0.99
| epoch 2 |  iter 61 / 351 | time 38[s] | loss 0.99
| epoch 2 |  iter 81 / 351 | time 50[s] | loss 0.99
| epoch 2 |  iter 101 / 351 | time 63[s] | loss 0.99
| epoch 2 |  iter 121 / 351 | time 75[s] | loss 0.99
| epoch 2 |  iter 141 / 351 | time 88[s] | loss 0.98
| epoch 2 |  iter 161 / 351 | time 100[s] | loss 0.98
| epoch 2 |  iter 181 / 351 | time 113[s] | loss 0.97
| epoch 2 |  iter 201 / 351 | time 126[s] | loss 0.95
| epoch 2 |  iter 221 / 351 | time 138[s] | loss 0.94
| epoch 2 |  iter 241 / 351 | time 151[s] | loss 0.90
| epoch 2 |  iter 261 / 351 | time 163[s] | loss 0.83
| epoch 2 |  iter 281 / 351 | time 176[s] | loss 0.74
| epoch 2 |  iter 301 / 351 | time 188[s] | loss 0.66
| epoch 2 |  iter 321 / 351 | time 201[s] | loss 0.58
| epoch 2 |  iter 341 / 351 | time 213[s] | loss 0.46
Q 10/15/94                     
T 1994-10-15
☑ 1994-10-15
---
Q thursday, november 13, 2008  
T 2008-11-13
☒ 2006-11-13
---
Q Mar 25, 2003                 
T 2003-03-25
☑ 2003-03-25
---
Q Tuesday, November 22, 2016   
T 2016-11-22
☑ 2016-11-22
---
Q Saturday, July 18, 1970      
T 1970-07-18
☑ 1970-07-18
---
Q october 6, 1992              
T 1992-10-06
☑ 1992-10-06
---
Q 8/23/08                      
T 2008-08-23
☑ 2008-08-23
---
Q 8/30/07                      
T 2007-08-30
☒ 2007-08-09
---
Q 10/28/13                     
T 2013-10-28
☒ 1983-10-28
---
Q sunday, november 6, 2016     
T 2016-11-06
☒ 2016-11-08
---
val acc 51.640%
| epoch 3 |  iter 1 / 351 | time 0[s] | loss 0.35
| epoch 3 |  iter 21 / 351 | time 13[s] | loss 0.30
| epoch 3 |  iter 41 / 351 | time 25[s] | loss 0.21
| epoch 3 |  iter 61 / 351 | time 38[s] | loss 0.14
| epoch 3 |  iter 81 / 351 | time 50[s] | loss 0.09
| epoch 3 |  iter 101 / 351 | time 63[s] | loss 0.07
| epoch 3 |  iter 121 / 351 | time 75[s] | loss 0.05
| epoch 3 |  iter 141 / 351 | time 88[s] | loss 0.04
| epoch 3 |  iter 161 / 351 | time 100[s] | loss 0.03
| epoch 3 |  iter 181 / 351 | time 113[s] | loss 0.03
| epoch 3 |  iter 201 / 351 | time 125[s] | loss 0.02
| epoch 3 |  iter 221 / 351 | time 138[s] | loss 0.02
| epoch 3 |  iter 241 / 351 | time 150[s] | loss 0.02
| epoch 3 |  iter 261 / 351 | time 163[s] | loss 0.01
| epoch 3 |  iter 281 / 351 | time 175[s] | loss 0.01
| epoch 3 |  iter 301 / 351 | time 188[s] | loss 0.01
| epoch 3 |  iter 321 / 351 | time 200[s] | loss 0.01
| epoch 3 |  iter 341 / 351 | time 213[s] | loss 0.01
Q 10/15/94                     
T 1994-10-15
☑ 1994-10-15
---
Q thursday, november 13, 2008  
T 2008-11-13
☑ 2008-11-13
---
Q Mar 25, 2003                 
T 2003-03-25
☑ 2003-03-25
---
Q Tuesday, November 22, 2016   
T 2016-11-22
☑ 2016-11-22
---
Q Saturday, July 18, 1970      
T 1970-07-18
☑ 1970-07-18
---
Q october 6, 1992              
T 1992-10-06
☑ 1992-10-06
---
Q 8/23/08                      
T 2008-08-23
☑ 2008-08-23
---
Q 8/30/07                      
T 2007-08-30
☑ 2007-08-30
---
Q 10/28/13                     
T 2013-10-28
☑ 2013-10-28
---
Q sunday, november 6, 2016     
T 2016-11-06
☑ 2016-11-06
---
val acc 99.900%
| epoch 4 |  iter 1 / 351 | time 0[s] | loss 0.01
| epoch 4 |  iter 21 / 351 | time 13[s] | loss 0.01
| epoch 4 |  iter 41 / 351 | time 25[s] | loss 0.01
| epoch 4 |  iter 61 / 351 | time 38[s] | loss 0.01
| epoch 4 |  iter 81 / 351 | time 50[s] | loss 0.01
| epoch 4 |  iter 101 / 351 | time 62[s] | loss 0.01
| epoch 4 |  iter 121 / 351 | time 75[s] | loss 0.00
| epoch 4 |  iter 141 / 351 | time 87[s] | loss 0.01
| epoch 4 |  iter 161 / 351 | time 100[s] | loss 0.00
| epoch 4 |  iter 181 / 351 | time 112[s] | loss 0.00
| epoch 4 |  iter 201 / 351 | time 125[s] | loss 0.00
| epoch 4 |  iter 221 / 351 | time 137[s] | loss 0.00
| epoch 4 |  iter 241 / 351 | time 150[s] | loss 0.00
| epoch 4 |  iter 261 / 351 | time 162[s] | loss 0.00
| epoch 4 |  iter 281 / 351 | time 175[s] | loss 0.00
| epoch 4 |  iter 301 / 351 | time 187[s] | loss 0.00
| epoch 4 |  iter 321 / 351 | time 200[s] | loss 0.00
| epoch 4 |  iter 341 / 351 | time 212[s] | loss 0.00
Q 10/15/94                     
T 1994-10-15
☑ 1994-10-15
---
Q thursday, november 13, 2008  
T 2008-11-13
☑ 2008-11-13
---
Q Mar 25, 2003                 
T 2003-03-25
☑ 2003-03-25
---
Q Tuesday, November 22, 2016   
T 2016-11-22
☑ 2016-11-22
---
Q Saturday, July 18, 1970      
T 1970-07-18
☑ 1970-07-18
---
Q october 6, 1992              
T 1992-10-06
☑ 1992-10-06
---
Q 8/23/08                      
T 2008-08-23
☑ 2008-08-23
---
Q 8/30/07                      
T 2007-08-30
☑ 2007-08-30
---
Q 10/28/13                     
T 2013-10-28
☑ 2013-10-28
---
Q sunday, november 6, 2016     
T 2016-11-06
☑ 2016-11-06
---
val acc 99.900%
| epoch 5 |  iter 1 / 351 | time 0[s] | loss 0.00
| epoch 5 |  iter 21 / 351 | time 13[s] | loss 0.00
| epoch 5 |  iter 41 / 351 | time 25[s] | loss 0.00
| epoch 5 |  iter 61 / 351 | time 38[s] | loss 0.00
| epoch 5 |  iter 81 / 351 | time 50[s] | loss 0.00
| epoch 5 |  iter 101 / 351 | time 63[s] | loss 0.00
| epoch 5 |  iter 121 / 351 | time 75[s] | loss 0.00
| epoch 5 |  iter 141 / 351 | time 88[s] | loss 0.00
| epoch 5 |  iter 161 / 351 | time 101[s] | loss 0.00
| epoch 5 |  iter 181 / 351 | time 113[s] | loss 0.00
| epoch 5 |  iter 201 / 351 | time 126[s] | loss 0.00
| epoch 5 |  iter 221 / 351 | time 138[s] | loss 0.00
| epoch 5 |  iter 241 / 351 | time 151[s] | loss 0.00
| epoch 5 |  iter 261 / 351 | time 163[s] | loss 0.00
| epoch 5 |  iter 281 / 351 | time 176[s] | loss 0.00
| epoch 5 |  iter 301 / 351 | time 188[s] | loss 0.00
| epoch 5 |  iter 321 / 351 | time 201[s] | loss 0.00
| epoch 5 |  iter 341 / 351 | time 213[s] | loss 0.00
Q 10/15/94                     
T 1994-10-15
☑ 1994-10-15
---
Q thursday, november 13, 2008  
T 2008-11-13
☑ 2008-11-13
---
Q Mar 25, 2003                 
T 2003-03-25
☑ 2003-03-25
---
Q Tuesday, November 22, 2016   
T 2016-11-22
☑ 2016-11-22
---
Q Saturday, July 18, 1970      
T 1970-07-18
☑ 1970-07-18
---
Q october 6, 1992              
T 1992-10-06
☑ 1992-10-06
---
Q 8/23/08                      
T 2008-08-23
☑ 2008-08-23
---
Q 8/30/07                      
T 2007-08-30
☑ 2007-08-30
---
Q 10/28/13                     
T 2013-10-28
☑ 2013-10-28
---
Q sunday, november 6, 2016     
T 2016-11-06
☑ 2016-11-06
---
val acc 99.920%
| epoch 6 |  iter 1 / 351 | time 0[s] | loss 0.00
| epoch 6 |  iter 21 / 351 | time 13[s] | loss 0.00
| epoch 6 |  iter 41 / 351 | time 25[s] | loss 0.00
| epoch 6 |  iter 61 / 351 | time 38[s] | loss 0.00
| epoch 6 |  iter 81 / 351 | time 50[s] | loss 0.00
| epoch 6 |  iter 101 / 351 | time 63[s] | loss 0.00
| epoch 6 |  iter 121 / 351 | time 75[s] | loss 0.00
| epoch 6 |  iter 141 / 351 | time 88[s] | loss 0.00
| epoch 6 |  iter 161 / 351 | time 100[s] | loss 0.00
| epoch 6 |  iter 181 / 351 | time 113[s] | loss 0.00
| epoch 6 |  iter 201 / 351 | time 125[s] | loss 0.00
| epoch 6 |  iter 221 / 351 | time 138[s] | loss 0.00
| epoch 6 |  iter 241 / 351 | time 150[s] | loss 0.00
| epoch 6 |  iter 261 / 351 | time 163[s] | loss 0.00
| epoch 6 |  iter 281 / 351 | time 175[s] | loss 0.00
| epoch 6 |  iter 301 / 351 | time 188[s] | loss 0.00
| epoch 6 |  iter 321 / 351 | time 200[s] | loss 0.00
| epoch 6 |  iter 341 / 351 | time 213[s] | loss 0.00
Q 10/15/94                     
T 1994-10-15
☑ 1994-10-15
---
Q thursday, november 13, 2008  
T 2008-11-13
☑ 2008-11-13
---
Q Mar 25, 2003                 
T 2003-03-25
☑ 2003-03-25
---
Q Tuesday, November 22, 2016   
T 2016-11-22
☑ 2016-11-22
---
Q Saturday, July 18, 1970      
T 1970-07-18
☑ 1970-07-18
---
Q october 6, 1992              
T 1992-10-06
☑ 1992-10-06
---
Q 8/23/08                      
T 2008-08-23
☑ 2008-08-23
---
Q 8/30/07                      
T 2007-08-30
☑ 2007-08-30
---
Q 10/28/13                     
T 2013-10-28
☑ 2013-10-28
---
Q sunday, november 6, 2016     
T 2016-11-06
☑ 2016-11-06
---
val acc 99.940%
| epoch 7 |  iter 1 / 351 | time 0[s] | loss 0.00
| epoch 7 |  iter 21 / 351 | time 13[s] | loss 0.00
| epoch 7 |  iter 41 / 351 | time 25[s] | loss 0.00
| epoch 7 |  iter 61 / 351 | time 38[s] | loss 0.00
| epoch 7 |  iter 81 / 351 | time 50[s] | loss 0.00
| epoch 7 |  iter 101 / 351 | time 63[s] | loss 0.00
| epoch 7 |  iter 121 / 351 | time 75[s] | loss 0.00
| epoch 7 |  iter 141 / 351 | time 88[s] | loss 0.00
| epoch 7 |  iter 161 / 351 | time 100[s] | loss 0.00
| epoch 7 |  iter 181 / 351 | time 113[s] | loss 0.00
| epoch 7 |  iter 201 / 351 | time 125[s] | loss 0.00
| epoch 7 |  iter 221 / 351 | time 138[s] | loss 0.00
| epoch 7 |  iter 241 / 351 | time 150[s] | loss 0.00
| epoch 7 |  iter 261 / 351 | time 163[s] | loss 0.00
| epoch 7 |  iter 281 / 351 | time 175[s] | loss 0.00
| epoch 7 |  iter 301 / 351 | time 188[s] | loss 0.02
| epoch 7 |  iter 321 / 351 | time 200[s] | loss 0.04
| epoch 7 |  iter 341 / 351 | time 213[s] | loss 0.01
Q 10/15/94                     
T 1994-10-15
☑ 1994-10-15
---
Q thursday, november 13, 2008  
T 2008-11-13
☑ 2008-11-13
---
Q Mar 25, 2003                 
T 2003-03-25
☑ 2003-03-25
---
Q Tuesday, November 22, 2016   
T 2016-11-22
☑ 2016-11-22
---
Q Saturday, July 18, 1970      
T 1970-07-18
☑ 1970-07-18
---
Q october 6, 1992              
T 1992-10-06
☑ 1992-10-06
---
Q 8/23/08                      
T 2008-08-23
☑ 2008-08-23
---
Q 8/30/07                      
T 2007-08-30
☑ 2007-08-30
---
Q 10/28/13                     
T 2013-10-28
☑ 2013-10-28
---
Q sunday, november 6, 2016     
T 2016-11-06
☑ 2016-11-06
---
val acc 99.180%
| epoch 8 |  iter 1 / 351 | time 0[s] | loss 0.01
| epoch 8 |  iter 21 / 351 | time 13[s] | loss 0.00
| epoch 8 |  iter 41 / 351 | time 25[s] | loss 0.00
| epoch 8 |  iter 61 / 351 | time 38[s] | loss 0.00
| epoch 8 |  iter 81 / 351 | time 50[s] | loss 0.00
| epoch 8 |  iter 101 / 351 | time 63[s] | loss 0.00
| epoch 8 |  iter 121 / 351 | time 75[s] | loss 0.00
| epoch 8 |  iter 141 / 351 | time 88[s] | loss 0.00
| epoch 8 |  iter 161 / 351 | time 100[s] | loss 0.00
| epoch 8 |  iter 181 / 351 | time 113[s] | loss 0.00
| epoch 8 |  iter 201 / 351 | time 125[s] | loss 0.00
| epoch 8 |  iter 221 / 351 | time 138[s] | loss 0.00
| epoch 8 |  iter 241 / 351 | time 150[s] | loss 0.00
| epoch 8 |  iter 261 / 351 | time 163[s] | loss 0.00
| epoch 8 |  iter 281 / 351 | time 175[s] | loss 0.00
| epoch 8 |  iter 301 / 351 | time 188[s] | loss 0.00
| epoch 8 |  iter 321 / 351 | time 201[s] | loss 0.00
| epoch 8 |  iter 341 / 351 | time 213[s] | loss 0.00
Q 10/15/94                     
T 1994-10-15
☑ 1994-10-15
---
Q thursday, november 13, 2008  
T 2008-11-13
☑ 2008-11-13
---
Q Mar 25, 2003                 
T 2003-03-25
☑ 2003-03-25
---
Q Tuesday, November 22, 2016   
T 2016-11-22
☑ 2016-11-22
---
Q Saturday, July 18, 1970      
T 1970-07-18
☑ 1970-07-18
---
Q october 6, 1992              
T 1992-10-06
☑ 1992-10-06
---
Q 8/23/08                      
T 2008-08-23
☑ 2008-08-23
---
Q 8/30/07                      
T 2007-08-30
☑ 2007-08-30
---
Q 10/28/13                     
T 2013-10-28
☑ 2013-10-28
---
Q sunday, november 6, 2016     
T 2016-11-06
☑ 2016-11-06
---
val acc 100.000%
| epoch 9 |  iter 1 / 351 | time 0[s] | loss 0.00
| epoch 9 |  iter 21 / 351 | time 13[s] | loss 0.00
| epoch 9 |  iter 41 / 351 | time 25[s] | loss 0.00
| epoch 9 |  iter 61 / 351 | time 37[s] | loss 0.00
| epoch 9 |  iter 81 / 351 | time 50[s] | loss 0.00
| epoch 9 |  iter 101 / 351 | time 62[s] | loss 0.00
| epoch 9 |  iter 121 / 351 | time 75[s] | loss 0.00
| epoch 9 |  iter 141 / 351 | time 87[s] | loss 0.00
| epoch 9 |  iter 161 / 351 | time 100[s] | loss 0.00
| epoch 9 |  iter 181 / 351 | time 112[s] | loss 0.00
| epoch 9 |  iter 201 / 351 | time 124[s] | loss 0.00
| epoch 9 |  iter 221 / 351 | time 137[s] | loss 0.00
| epoch 9 |  iter 241 / 351 | time 149[s] | loss 0.00
| epoch 9 |  iter 261 / 351 | time 162[s] | loss 0.00
| epoch 9 |  iter 281 / 351 | time 174[s] | loss 0.00
| epoch 9 |  iter 301 / 351 | time 187[s] | loss 0.00
| epoch 9 |  iter 321 / 351 | time 199[s] | loss 0.00
| epoch 9 |  iter 341 / 351 | time 212[s] | loss 0.00
Q 10/15/94                     
T 1994-10-15
☑ 1994-10-15
---
Q thursday, november 13, 2008  
T 2008-11-13
☑ 2008-11-13
---
Q Mar 25, 2003                 
T 2003-03-25
☑ 2003-03-25
---
Q Tuesday, November 22, 2016   
T 2016-11-22
☑ 2016-11-22
---
Q Saturday, July 18, 1970      
T 1970-07-18
☑ 1970-07-18
---
Q october 6, 1992              
T 1992-10-06
☑ 1992-10-06
---
Q 8/23/08                      
T 2008-08-23
☑ 2008-08-23
---
Q 8/30/07                      
T 2007-08-30
☑ 2007-08-30
---
Q 10/28/13                     
T 2013-10-28
☑ 2013-10-28
---
Q sunday, november 6, 2016     
T 2016-11-06
☑ 2016-11-06
---
val acc 100.000%
| epoch 10 |  iter 1 / 351 | time 0[s] | loss 0.00
| epoch 10 |  iter 21 / 351 | time 13[s] | loss 0.00
| epoch 10 |  iter 41 / 351 | time 25[s] | loss 0.00
| epoch 10 |  iter 61 / 351 | time 38[s] | loss 0.00
| epoch 10 |  iter 81 / 351 | time 50[s] | loss 0.00
| epoch 10 |  iter 101 / 351 | time 63[s] | loss 0.00
| epoch 10 |  iter 121 / 351 | time 76[s] | loss 0.00
| epoch 10 |  iter 141 / 351 | time 88[s] | loss 0.00
| epoch 10 |  iter 161 / 351 | time 101[s] | loss 0.00
| epoch 10 |  iter 181 / 351 | time 114[s] | loss 0.00
| epoch 10 |  iter 201 / 351 | time 126[s] | loss 0.00
| epoch 10 |  iter 221 / 351 | time 139[s] | loss 0.00
| epoch 10 |  iter 241 / 351 | time 151[s] | loss 0.00
| epoch 10 |  iter 261 / 351 | time 164[s] | loss 0.00
| epoch 10 |  iter 281 / 351 | time 176[s] | loss 0.00
| epoch 10 |  iter 301 / 351 | time 189[s] | loss 0.00
| epoch 10 |  iter 321 / 351 | time 202[s] | loss 0.00
| epoch 10 |  iter 341 / 351 | time 214[s] | loss 0.00
Q 10/15/94                     
T 1994-10-15
☑ 1994-10-15
---
Q thursday, november 13, 2008  
T 2008-11-13
☑ 2008-11-13
---
Q Mar 25, 2003                 
T 2003-03-25
☑ 2003-03-25
---
Q Tuesday, November 22, 2016   
T 2016-11-22
☑ 2016-11-22
---
Q Saturday, July 18, 1970      
T 1970-07-18
☑ 1970-07-18
---
Q october 6, 1992              
T 1992-10-06
☑ 1992-10-06
---
Q 8/23/08                      
T 2008-08-23
☑ 2008-08-23
---
Q 8/30/07                      
T 2007-08-30
☑ 2007-08-30
---
Q 10/28/13                     
T 2013-10-28
☑ 2013-10-28
---
Q sunday, november 6, 2016     
T 2016-11-06
☑ 2016-11-06
---
val acc 100.000%

%sh
top -b -n 1
top - 07:51:33 up 53 min,  0 users,  load average: 3.95, 2.30, 0.97
Tasks:  18 total,   2 running,  16 sleeping,   0 stopped,   0 zombie
%Cpu(s):  4.4 us,  4.1 sy,  0.0 ni, 90.4 id,  0.7 wa,  0.0 hi,  0.0 si,  0.3 st
KiB Mem : 16426336 total, 13631944 free,  1586288 used,  1208104 buff/cache
KiB Swap:        0 total,        0 free,        0 used. 14425956 avail Mem 

  PID USER      PR  NI    VIRT    RES    SHR S  %CPU %MEM     TIME+ COMMAND
  573 root      20   0  596620 255420  17564 R 373.3  1.6  16:34.76 python3
    8 root      20   0 4659904 612268  25060 S  26.7  3.7   0:59.91 java
    1 root      20   0    4364    736    672 S   0.0  0.0   0:00.07 tini
   67 root      20   0   19768   3408   3084 S   0.0  0.0   0:00.00 interprete+
   79 root      20   0   19768   2352   2024 S   0.0  0.0   0:00.00 interprete+
   80 root      20   0 4542088 100484  19000 S   0.0  0.6   0:03.89 java
  524 root      20   0   19768   3416   3096 S   0.0  0.0   0:00.00 interprete+
  536 root      20   0   19768   2360   2036 S   0.0  0.0   0:00.00 interprete+
  537 root      20   0 4559460 126160  18468 S   0.0  0.8   0:04.89 java
  593 root      20   0   19768   3416   3092 S   0.0  0.0   0:00.00 interprete+
  605 root      20   0   19768   2360   2032 S   0.0  0.0   0:00.00 interprete+
  606 root      20   0 4636276 151476  18508 S   0.0  0.9   0:05.36 java
  644 root      20   0 1418844  46480  14084 S   0.0  0.3   0:02.03 python
  676 root      20   0  605696  44584  11276 S   0.0  0.3   0:01.12 python
  744 root      20   0   19768   3416   3100 S   0.0  0.0   0:00.00 interprete+
  756 root      20   0   19768   2268   1948 S   0.0  0.0   0:00.00 interprete+
  757 root      20   0 4076040 147440  17972 S   0.0  0.9   0:05.05 java
 1130 root      20   0   38164   3480   3108 R   0.0  0.0   0:00.00 top

%sh
top -b -n 1
top - 08:32:39 up  1:34,  0 users,  load average: 4.04, 4.03, 3.86
Tasks:  18 total,   2 running,  16 sleeping,   0 stopped,   0 zombie
%Cpu(s): 23.2 us, 24.7 sy,  0.0 ni, 51.5 id,  0.4 wa,  0.0 hi,  0.0 si,  0.2 st
KiB Mem : 16426336 total, 13396032 free,  1821236 used,  1209068 buff/cache
KiB Swap:        0 total,        0 free,        0 used. 14190540 avail Mem 

  PID USER      PR  NI    VIRT    RES    SHR S  %CPU %MEM     TIME+ COMMAND
  573 root      20   0  625768 284436  17628 R 373.3  1.7 179:43.56 python3
    8 root      20   0 4664008 613632  25060 S   6.7  3.7   1:08.96 java
    1 root      20   0    4364    736    672 S   0.0  0.0   0:00.12 tini
   67 root      20   0   19768   3408   3084 S   0.0  0.0   0:00.00 interprete+
   79 root      20   0   19768   2352   2024 S   0.0  0.0   0:00.00 interprete+
   80 root      20   0 4543116 105596  19000 S   0.0  0.6   0:05.98 java
  524 root      20   0   19768   3416   3096 S   0.0  0.0   0:00.00 interprete+
  536 root      20   0   19768   2360   2036 S   0.0  0.0   0:00.00 interprete+
  537 root      20   0 4559460 322272  18492 S   0.0  2.0   0:10.76 java
  593 root      20   0   19768   3416   3092 S   0.0  0.0   0:00.00 interprete+
  605 root      20   0   19768   2360   2032 S   0.0  0.0   0:00.00 interprete+
  606 root      20   0 4636276 154240  18508 S   0.0  0.9   0:08.24 java
  644 root      20   0 1418844  46480  14084 S   0.0  0.3   0:03.15 python
  676 root      20   0  605696  44584  11276 S   0.0  0.3   0:01.40 python
  744 root      20   0   19768   3416   3100 S   0.0  0.0   0:00.00 interprete+
  756 root      20   0   19768   2268   1948 S   0.0  0.0   0:00.00 interprete+
  757 root      20   0 4076040 148232  17972 S   0.0  0.9   0:06.95 java
 1148 root      20   0   38164   3404   3032 R   0.0  0.0   0:00.00 top

%python3
import matplotlib.pyplot as plt

plt.ylim(0, 1)
plt.plot(acc_list)
plt.legend(labels=['AttentionSeq2Seq'])
plt.show()

  • Memory consumption stays modest
  • Accuracy approaches 100% from the second epoch
  • Training took about 49 minutes on an EC2 m4.xlarge

Visualizing the attention weights

%sh
chmod go+r AttentionSeq2Seq.pkl

%python3
model.load_params()

%python3
import matplotlib.pyplot as plt

_idx = 0
def visualize(attention_map, row_labels, column_labels):
    fig, ax = plt.subplots()
    ax.pcolor(attention_map, cmap=plt.cm.Greys_r, vmin=0.0, vmax=1.0)
    
    ax.patch.set_facecolor('black')
    ax.set_yticks(np.arange(attention_map.shape[0])+0.5, minor=False)
    ax.set_xticks(np.arange(attention_map.shape[1])+0.5, minor=False)
    ax.invert_yaxis()
    ax.set_xticklabels(row_labels, minor=False)
    ax.set_yticklabels(column_labels, minor=False)
    
    global _idx
    _idx += 1
    plt.show()

np.random.seed(1984)
for _ in range(5):
    idx = [np.random.randint(0, len(x_test))]
    x = x_test[idx]
    t = t_test[idx]
    
    model.forward(x, t)
    d = model.decoder.attention.attention_weights
    d = np.array(d)
    attention_map = d.reshape(d.shape[0], d.shape[2])
    
    attention_map = attention_map[:,::-1]
    x = x[:, ::-1]
    
    row_labels = [id_to_char[i] for i in x[0]]
    column_labels = [id_to_char[i] for i in t[0]]
    column_labels = column_labels[1:]
    
    visualize(attention_map, row_labels, column_labels)

plt.show()
