Deep Learning from Scratch 2 / Attention
Chapter 8: Attention
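These are my reading notes on ゼロから作るDeep Learning (2) (Deep Learning from Scratch 2). In Chapter 7 we learned seq2seq, a technique for converting one time series into another, using the prediction of simple additions as the example. Chapter 8 introduces the attention mechanism, which makes seq2seq considerably more powerful.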
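TL;DR
- In the seq2seq implemented in the previous chapter, the Encoder's output is a single fixed-length vector
  - This does not scale with input length: as the input grows, a fixed-size vector can no longer hold all of the information
- Introducing Attention
  - Pass the hidden states of the Encoder's LSTM at every time step to the Decoder, all together
  - Prepare one weight per hidden state, expressing each state's importance as a probability distribution
  - Compute the context vector as the weighted sum of the hidden states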
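A minimal NumPy sketch of the steps summarized above, for a single Decoder time step and without a batch dimension. The toy values of `hs` (Encoder hidden states) and `h` (Decoder hidden state) are made up for illustration. Scores are plain dot products, the same similarity the AttentionWeight layer below uses, but this is my own illustration, not the book's code.

```python
import numpy as np

def softmax(x):
    # subtract the max for numerical stability
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Toy Encoder hidden states: T=4 time steps, H=3 dimensions (made-up values)
hs = np.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 0.0],
               [0.0, 0.0, 1.0],
               [1.0, 1.0, 0.0]])
# Toy Decoder hidden state at one time step
h = np.array([0.0, 5.0, 0.0])

scores = hs @ h       # dot-product similarity per Encoder step: shape (T,)
a = softmax(scores)   # attention weights: non-negative, sum to 1
c = a @ hs            # context vector = weighted sum of the rows of hs: shape (H,)
print(np.round(a, 3))
print(np.round(c, 3))
```

Because `h` points along the second axis, almost all of the weight goes to Encoder steps 1 and 3, whose hidden states share that direction, so `c` ends up close to the average of those two rows.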
Reference implementation
```
%sh
cd /tmp
git clone https://github.com/oreilly-japan/deep-learning-from-scratch-2
```
Cloning into 'deep-learning-from-scratch-2'...
Python version
%sh python3 --version
Python 3.5.2
Install numpy and matplotlib
%sh pip3 install numpy matplotlib
bash: pip3: command not found
Oops, pip3 isn't installed, so install it with apt first.
%sh apt install -y python3-pip
```
WARNING: apt does not have a stable CLI interface. Use with caution in scripts.
...
The following NEW packages will be installed:
  libpython3-dev libpython3.5 libpython3.5-dev python3-dev python3-pip python3-pkg-resources python3-setuptools python3-wheel python3.5-dev
0 upgraded, 9 newly installed, 0 to remove and 0 not upgraded.
...
Setting up python3-pip (8.1.1-2ubuntu0.4) ...
...
```
Retry
%sh pip3 install numpy matplotlib
```
Collecting numpy
Collecting matplotlib
...
Successfully installed cycler-0.10.0 kiwisolver-1.0.1 matplotlib-3.0.2 numpy-1.16.1 pyparsing-2.3.1 python-dateutil-2.7.5 six-1.12.0
You are using pip version 8.1.1, however version 19.0.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.
```
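WeightSum layer
Multiplies the hidden states received from the Encoder by the attention weights and sums them over time; the result (the context vector) is passed downstream.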
```
%python3
import numpy as np

class WeightSum:
    def __init__(self):
        self.params, self.grads = [], []
        self.cache = None

    def forward(self, hs, a):
        # hs: (N, T, H) Encoder hidden states, a: (N, T) attention weights
        N, T, H = hs.shape

        ar = a.reshape(N, T, 1).repeat(H, axis=2)
        t = hs * ar
        c = np.sum(t, axis=1)  # context vector: (N, H)

        self.cache = (hs, ar)
        return c

    def backward(self, dc):
        hs, ar = self.cache
        N, T, H = hs.shape
        # the backward of a sum is a repeat, the backward of a repeat is a sum
        dt = dc.reshape(N, 1, H).repeat(T, axis=1)
        dar = dt * hs
        dhs = dt * ar
        da = np.sum(dar, axis=2)

        return dhs, da
```
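WeightSum above builds an explicit (N, T, H) copy of `a` with `repeat`, which keeps the computation graph easy to follow. As a sanity check, here is a small sketch (not part of the book's code) showing that NumPy broadcasting or `np.einsum` give the same forward and backward results on random inputs.

```python
import numpy as np

rng = np.random.RandomState(0)
N, T, H = 2, 5, 4
hs = rng.randn(N, T, H)
a = rng.rand(N, T)
a /= a.sum(axis=1, keepdims=True)   # make each row a probability distribution

# repeat-based forward, as in WeightSum.forward
ar = a.reshape(N, T, 1).repeat(H, axis=2)
c_repeat = np.sum(hs * ar, axis=1)

# broadcasting and einsum forms of the same weighted sum over time
c_broadcast = np.sum(hs * a[:, :, None], axis=1)
c_einsum = np.einsum('nth,nt->nh', hs, a)

# backward for an arbitrary upstream gradient dc of shape (N, H)
dc = rng.randn(N, H)
dt = dc.reshape(N, 1, H).repeat(T, axis=1)
dhs_repeat = dt * ar
da_repeat = np.sum(dt * hs, axis=2)
dhs_einsum = np.einsum('nh,nt->nth', dc, a)
da_einsum = np.einsum('nh,nth->nt', dc, hs)
```

The broadcasting form avoids materializing the repeated array; the repeat form mirrors the computation graph one node at a time, which is easier to differentiate by hand.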
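AttentionWeight layer
For each vector coming from the Encoder, takes its dot product with the hidden state vector of the Decoder's LSTM layer as a similarity score, then applies the Softmax function to turn the scores into weights.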
```
%python3
import sys
# make the book's common/ and dataset/ packages importable
sys.path.append('/tmp/deep-learning-from-scratch-2')
```
```
%python3
import numpy as np
from common.layers import Softmax

class AttentionWeight:
    def __init__(self):
        self.params, self.grads = [], []
        self.softmax = Softmax()
        self.cache = None

    def forward(self, hs, h):
        # hs: (N, T, H) Encoder hidden states, h: (N, H) Decoder hidden state
        N, T, H = hs.shape

        hr = h.reshape(N, 1, H).repeat(T, axis=1)
        t = hs * hr
        s = np.sum(t, axis=2)        # dot-product scores: (N, T)
        a = self.softmax.forward(s)  # attention weights: (N, T)

        self.cache = (hs, hr)
        return a

    def backward(self, da):
        hs, hr = self.cache
        N, T, H = hs.shape

        ds = self.softmax.backward(da)
        dt = ds.reshape(N, T, 1).repeat(H, axis=2)
        dhs = dt * hr
        dhr = dt * hs
        dh = np.sum(dhr, axis=1)

        return dhs, dh
```
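The repeat-and-sum in AttentionWeight.forward is a batched dot product. A short sketch (my own, not the book's) confirming that `np.einsum` and a batched `@` give identical scores, and that a row-wise softmax turns them into weights that sum to 1. The local `softmax` function here is an assumption standing in for `common.layers.Softmax` on 2-D input.

```python
import numpy as np

def softmax(x):
    # row-wise softmax over the last axis, shifted for stability
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.RandomState(1)
N, T, H = 3, 6, 4
hs = rng.randn(N, T, H)   # Encoder hidden states
h = rng.randn(N, H)       # Decoder hidden state at one time step

# repeat-based scores, as in AttentionWeight.forward
hr = h.reshape(N, 1, H).repeat(T, axis=1)
s_repeat = np.sum(hs * hr, axis=2)

# the same scores as a batched matrix-vector product
s_einsum = np.einsum('nth,nh->nt', hs, h)
s_matmul = (hs @ h[:, :, None])[:, :, 0]

a = softmax(s_einsum)     # attention weights: shape (N, T)
```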
Attention layer
```
%python3
class Attention:
    def __init__(self):
        self.params, self.grads = [], []
        self.attention_weight_layer = AttentionWeight()
        self.weight_sum_layer = WeightSum()
        self.attention_weight = None

    def forward(self, hs, h):
        a = self.attention_weight_layer.forward(hs, h)
        out = self.weight_sum_layer.forward(hs, a)
        self.attention_weight = a  # kept for visualization later
        return out

    def backward(self, dout):
        dhs0, da = self.weight_sum_layer.backward(dout)
        dhs1, dh = self.attention_weight_layer.backward(da)
        # hs feeds both layers, so its two gradients are summed
        dhs = dhs0 + dhs1
        return dhs, dh
```
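In Attention.backward, `hs` is used by both AttentionWeight and WeightSum, which is why `dhs0 + dhs1` is returned. The sketch below, a re-implementation with `np.einsum` rather than the book's layer classes, checks those analytic gradients against central-difference numerical gradients; the helper names are my own.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def attention_forward(hs, h):
    s = np.einsum('nth,nh->nt', hs, h)   # AttentionWeight: dot-product scores
    a = softmax(s)                        # attention weights
    c = np.einsum('nth,nt->nh', hs, a)    # WeightSum: context vector
    return c, a

def attention_backward(hs, h, a, dc):
    # WeightSum backward
    dhs0 = np.einsum('nh,nt->nth', dc, a)
    da = np.einsum('nh,nth->nt', dc, hs)
    # Softmax backward: ds = a * (da - sum(a * da))
    ds = a * da
    ds -= a * ds.sum(axis=1, keepdims=True)
    # AttentionWeight backward
    dhs1 = np.einsum('nt,nh->nth', ds, h)
    dh = np.einsum('nt,nth->nh', ds, hs)
    # hs feeds both branches, so the gradients add up
    return dhs0 + dhs1, dh

def numerical_grad(f, x, eps=1e-6):
    # central differences, perturbing x in place one element at a time
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        f_plus = f()
        x[idx] = orig - eps
        f_minus = f()
        x[idx] = orig
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad

rng = np.random.RandomState(2)
N, T, H = 2, 4, 3
hs = rng.randn(N, T, H)
h = rng.randn(N, H)
dc = rng.randn(N, H)   # arbitrary upstream gradient

c, a = attention_forward(hs, h)
dhs, dh = attention_backward(hs, h, a, dc)

def loss():
    # scalar L = sum(c * dc), so dL/dc = dc
    c_now, _ = attention_forward(hs, h)
    return np.sum(c_now * dc)

dhs_num = numerical_grad(loss, hs)
dh_num = numerical_grad(loss, h)
```

Dropping either branch (for example returning only `dhs0`) makes the check fail, which is a quick way to convince yourself that the sum is needed.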
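TimeAttention layer
Runs one Attention layer per Decoder time step; each receives all of the Encoder's hidden states `hs_enc` together with the Decoder hidden state at that step.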
```
%python3
class TimeAttention:
    def __init__(self):
        self.params, self.grads = [], []
        self.layers = None
        self.attention_weights = None

    def forward(self, hs_enc, hs_dec):
        N, T, H = hs_dec.shape
        out = np.empty_like(hs_dec)
        self.layers = []
        self.attention_weights = []

        # one Attention layer per Decoder time step, each seeing all of hs_enc
        for t in range(T):
            layer = Attention()
            out[:, t, :] = layer.forward(hs_enc, hs_dec[:, t, :])
            self.layers.append(layer)
            self.attention_weights.append(layer.attention_weight)

        return out

    def backward(self, dout):
        N, T, H = dout.shape
        dhs_enc = 0
        dhs_dec = np.empty_like(dout)

        for t in range(T):
            layer = self.layers[t]
            dhs, dh = layer.backward(dout[:, t, :])
            dhs_enc += dhs  # hs_enc is shared by every time step
            dhs_dec[:, t, :] = dh

        return dhs_enc, dhs_dec
```
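TimeAttention loops over Decoder time steps. Since every step attends to the same `hs_enc`, the forward pass can also be written as two batched products. This is a sketch under my own naming (`time_attention_loop`, `time_attention_batched` are hypothetical helpers, not book code) that checks the two agree; Encoder and Decoder lengths are deliberately different.

```python
import numpy as np

def softmax(x):
    # softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_step(hs_enc, h):
    # one Decoder time step; h has shape (N, H)
    a = softmax(np.einsum('nsh,nh->ns', hs_enc, h))
    return np.einsum('nsh,ns->nh', hs_enc, a), a

def time_attention_loop(hs_enc, hs_dec):
    N, T, H = hs_dec.shape
    out = np.empty_like(hs_dec)
    weights = []
    for t in range(T):
        out[:, t, :], a = attention_step(hs_enc, hs_dec[:, t, :])
        weights.append(a)
    return out, np.stack(weights, axis=1)   # weights: (N, T_dec, T_enc)

def time_attention_batched(hs_enc, hs_dec):
    # scores for all Decoder steps at once: (N, T_dec, T_enc)
    s = np.einsum('nth,nsh->nts', hs_dec, hs_enc)
    a = softmax(s)
    out = np.einsum('nts,nsh->nth', a, hs_enc)
    return out, a

rng = np.random.RandomState(3)
hs_enc = rng.randn(2, 7, 5)   # N=2, T_enc=7, H=5
hs_dec = rng.randn(2, 4, 5)   # N=2, T_dec=4, H=5
out_loop, a_loop = time_attention_loop(hs_enc, hs_dec)
out_batch, a_batch = time_attention_batched(hs_enc, hs_dec)
```

The per-step loop has the advantage of reusing the Attention layer (and its backward) unchanged, which is presumably why the book builds TimeAttention that way.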
AttentionEncoder
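Changes from the Encoder class in the previous chapter:
- `forward` used to return only the last hidden state vector; it now returns the hidden states for all time steps
- `backward` is changed accordingly, so it accepts gradients for all time steps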
```
%python3
from common.time_layers import TimeEmbedding, TimeLSTM

# The Encoder from Chapter 7, for comparison
class Encoder:
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V, D, H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn

        embed_W = (rn(V, D) / 100).astype('f')
        lstm_Wx = (rn(D, 4 * H) / np.sqrt(D)).astype('f')
        lstm_Wh = (rn(H, 4 * H) / np.sqrt(H)).astype('f')
        lstm_b = np.zeros(4 * H).astype('f')

        self.embed = TimeEmbedding(embed_W)
        self.lstm = TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful=False)

        self.params = self.embed.params + self.lstm.params
        self.grads = self.embed.grads + self.lstm.grads
        self.hs = None

    def forward(self, xs):
        xs = self.embed.forward(xs)
        hs = self.lstm.forward(xs)
        self.hs = hs
        return hs[:, -1, :]  # only the last hidden state

    def backward(self, dh):
        dhs = np.zeros_like(self.hs)
        dhs[:, -1, :] = dh
        dout = self.lstm.backward(dhs)
        dout = self.embed.backward(dout)
        return dout
```
```
%python3
from common.time_layers import *

class AttentionEncoder(Encoder):
    def forward(self, xs):
        xs = self.embed.forward(xs)
        hs = self.lstm.forward(xs)
        return hs  # all hidden states: (N, T, H)

    def backward(self, dhs):
        dout = self.lstm.backward(dhs)
        dout = self.embed.backward(dout)
        return dout
```
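The difference between the two Encoders, reduced to array shapes. The arrays here are stand-ins (an `arange` in place of real TimeLSTM output, ones in place of real gradients), so this only illustrates where gradients can flow, not actual values.

```python
import numpy as np

N, T, H = 2, 5, 3
hs = np.arange(N * T * H, dtype=float).reshape(N, T, H)  # stand-in for TimeLSTM output

# Chapter 7 Encoder: only the last hidden state reaches the Decoder
h_last = hs[:, -1, :]               # (N, H)
dh = np.ones_like(h_last)           # stand-in gradient arriving from the Decoder
dhs_old = np.zeros_like(hs)
dhs_old[:, -1, :] = dh              # every earlier time step receives zero gradient

# AttentionEncoder: all hidden states reach the Decoder
hs_out = hs                         # (N, T, H)
dhs_new = np.ones_like(hs_out)      # the Decoder can send gradient to every time step
```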
AttentionDecoder
```
%python3
class AttentionDecoder:
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V, D, H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn

        embed_W = (rn(V, D) / 100).astype('f')
        lstm_Wx = (rn(D, 4 * H) / np.sqrt(D)).astype('f')
        lstm_Wh = (rn(H, 4 * H) / np.sqrt(H)).astype('f')
        lstm_b = np.zeros(4 * H).astype('f')
        # the Affine input is [context vector, LSTM hidden state], hence 2H rows
        affine_W = (rn(2*H, V) / np.sqrt(2*H)).astype('f')
        affine_b = np.zeros(V).astype('f')

        self.embed = TimeEmbedding(embed_W)
        self.lstm = TimeLSTM(lstm_Wx, lstm_Wh, lstm_b, stateful=True)
        self.attention = TimeAttention()
        self.affine = TimeAffine(affine_W, affine_b)
        layers = [self.embed, self.lstm, self.attention, self.affine]

        self.params, self.grads = [], []
        for layer in layers:
            self.params += layer.params
            self.grads += layer.grads

    def forward(self, xs, enc_hs):
        h = enc_hs[:, -1]
        self.lstm.set_state(h)

        out = self.embed.forward(xs)
        dec_hs = self.lstm.forward(out)
        c = self.attention.forward(enc_hs, dec_hs)
        out = np.concatenate((c, dec_hs), axis=2)
        score = self.affine.forward(out)

        return score

    def backward(self, dscore):
        dout = self.affine.backward(dscore)
        N, T, H2 = dout.shape
        H = H2 // 2

        dc, ddec_hs0 = dout[:,:,:H], dout[:,:,H:]
        denc_hs, ddec_hs1 = self.attention.backward(dc)
        ddec_hs = ddec_hs0 + ddec_hs1
        dout = self.lstm.backward(ddec_hs)
        dh = self.lstm.dh
        denc_hs[:, -1] += dh  # the last Encoder state also initialized the LSTM
        self.embed.backward(dout)

        return denc_hs

    def generate(self, enc_hs, start_id, sample_size):
        sampled = []
        sample_id = start_id
        h = enc_hs[:, -1]
        self.lstm.set_state(h)

        for _ in range(sample_size):
            x = np.array([sample_id]).reshape((1, 1))

            out = self.embed.forward(x)
            dec_hs = self.lstm.forward(out)
            c = self.attention.forward(enc_hs, dec_hs)
            out = np.concatenate((c, dec_hs), axis=2)
            score = self.affine.forward(out)

            sample_id = np.argmax(score.flatten())  # greedy: most likely character
            sampled.append(sample_id)

        return sampled
```
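AttentionDecoder feeds the Affine layer the concatenation of the context vector and the LSTM hidden state, and its backward splits the gradient at `H`. A shape-only sketch with random stand-in arrays; a plain `@` is used here in place of TimeAffine, which I'm assuming computes the same product per time step.

```python
import numpy as np

N, T, H, V = 2, 3, 4, 10
rng = np.random.RandomState(4)
c = rng.randn(N, T, H)        # context vectors from TimeAttention
dec_hs = rng.randn(N, T, H)   # Decoder LSTM hidden states

out = np.concatenate((c, dec_hs), axis=2)   # (N, T, 2H): input to the Affine layer
W = rng.randn(2 * H, V) / np.sqrt(2 * H)    # which is why affine_W has 2H rows
score = out @ W                              # (N, T, V)

# backward: the gradient w.r.t. the concatenation is split back at H,
# context part first because c came first in the concatenation
dout = rng.randn(N, T, 2 * H)
dc, ddec_hs0 = dout[:, :, :H], dout[:, :, H:]
```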
AttentionSeq2Seq
Same as the Chapter 7 Seq2Seq, but with the Encoder and Decoder replaced by their Attention versions.
```
%python3
from common.base_model import BaseModel

# The Seq2Seq class from Chapter 7. Decoder is not defined in this notebook,
# but AttentionSeq2Seq overrides __init__, so Decoder is never instantiated.
class Seq2Seq(BaseModel):
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V, D, H = vocab_size, wordvec_size, hidden_size
        self.encoder = Encoder(V, D, H)
        self.decoder = Decoder(V, D, H)
        self.softmax = TimeSoftmaxWithLoss()

        self.params = self.encoder.params + self.decoder.params
        self.grads = self.encoder.grads + self.decoder.grads

    def forward(self, xs, ts):
        # teacher forcing: the Decoder input is the target shifted by one
        decoder_xs, decoder_ts = ts[:, :-1], ts[:, 1:]

        h = self.encoder.forward(xs)
        score = self.decoder.forward(decoder_xs, h)
        loss = self.softmax.forward(score, decoder_ts)
        return loss

    def backward(self, dout=1):
        dout = self.softmax.backward(dout)
        dh = self.decoder.backward(dout)
        dout = self.encoder.backward(dh)
        return dout

    def generate(self, xs, start_id, sample_size):
        h = self.encoder.forward(xs)
        sampled = self.decoder.generate(h, start_id, sample_size)
        return sampled
```
```
%python3
class AttentionSeq2Seq(Seq2Seq):
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        args = vocab_size, wordvec_size, hidden_size
        self.encoder = AttentionEncoder(*args)
        self.decoder = AttentionDecoder(*args)
        self.softmax = TimeSoftmaxWithLoss()

        self.params = self.encoder.params + self.decoder.params
        self.grads = self.encoder.grads + self.decoder.grads
```
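Seq2Seq.forward slices the target into `decoder_xs = ts[:, :-1]` and `decoder_ts = ts[:, 1:]`. A small sketch on one answer string from the date dataset, where the leading `_` appears to act as the start symbol. The character vocabulary here is a hypothetical one built from that single string, not the vocabulary returned by `sequence.get_vocab()`.

```python
import numpy as np

# hypothetical vocabulary built from one target string
target = '_1994-09-27'
chars = sorted(set(target))
char_to_id = {ch: i for i, ch in enumerate(chars)}
id_to_char = {i: ch for ch, i in char_to_id.items()}

ts = np.array([[char_to_id[ch] for ch in target]])   # shape (1, 11)

# same slicing as Seq2Seq.forward
decoder_xs, decoder_ts = ts[:, :-1], ts[:, 1:]

def to_str(ids):
    return ''.join(id_to_char[int(i)] for i in ids)

print(to_str(decoder_xs[0]))  # → _1994-09-2   (what the Decoder reads)
print(to_str(decoder_ts[0]))  # → 1994-09-27   (what it must predict)
```

At each step the Decoder sees the correct previous character and is trained to predict the next one; at generation time it instead feeds back its own predictions, starting from `start_id`.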
Training the seq2seq with Attention
This time the task is converting date formats rather than addition.
The dataset
%sh cat /tmp/deep-learning-from-scratch-2/dataset/date.txt | head
```
september 27, 1994 _1994-09-27
August 19, 2003 _2003-08-19
2/10/93 _1993-02-10
10/31/90 _1990-10-31
TUESDAY, SEPTEMBER 25, 1984 _1984-09-25
JUN 17, 2013 _2013-06-17
april 3, 1996 _1996-04-03
October 24, 1974 _1974-10-24
AUGUST 11, 1986 _1986-08-11
February 16, 2015 _2015-02-16
```
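Each line holds a free-form date, then `_` followed by the ISO date. Below is a minimal sketch of how such lines could be split into questions and answers and turned into ID arrays. It is my guess at the idea behind `sequence.load_data`, not its actual code; in particular, padding the questions with spaces to a common width is an assumption (the `head` output above doesn't show the padding clearly).

```python
import numpy as np

# two lines in the format above; padding questions to a fixed width is assumed
raw = [('september 27, 1994', '1994-09-27'), ('2/10/93', '1993-02-10')]
width = max(len(q) for q, _ in raw) + 1
lines = [q.ljust(width) + '_' + a + '\n' for q, a in raw]

questions, answers = [], []
for line in lines:
    idx = line.find('_')                      # '_' marks where the answer starts
    questions.append(line[:idx])              # keeps the trailing padding spaces
    answers.append(line[idx:].rstrip('\n'))   # keeps the leading '_'

# character vocabulary over both sides, IDs in order of first appearance
char_to_id = {}
for s in questions + answers:
    for ch in s:
        if ch not in char_to_id:
            char_to_id[ch] = len(char_to_id)

x = np.array([[char_to_id[ch] for ch in q] for q in questions])
t = np.array([[char_to_id[ch] for ch in a] for a in answers])
```

Fixed-width padding is what lets all questions stack into one 2-D array, so the model can be trained in mini-batches without masking.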
```
%python3
from dataset import sequence
from common.optimizer import Adam
from common.trainer import Trainer
from common.util import eval_seq2seq
```
```
%python3
# Load the data
(x_train, t_train), (x_test, t_test) = sequence.load_data('date.txt')
char_to_id, id_to_char = sequence.get_vocab()

# Reverse the input sequences
x_train, x_test = x_train[:, ::-1], x_test[:, ::-1]
```
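Reversing the input is the trick carried over from Chapter 7. `[:, ::-1]` flips each row along the time axis, and applying it a second time restores the original order, which is what the visualization code at the end of this post does before labeling the axes. Toy IDs only:

```python
import numpy as np

x = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8]])
x_rev = x[:, ::-1]          # reverse each sequence (time axis), batch axis untouched
print(x_rev.tolist())       # → [[4, 3, 2, 1], [8, 7, 6, 5]]
restored = x_rev[:, ::-1]   # reversing again restores the original order
```

`[:, ::-1]` returns a view rather than a copy, so the reversal itself costs nothing.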
```
%python3
# Hyperparameters
vocab_size = len(char_to_id)
wordvec_size = 16
hidden_size = 256
batch_size = 128
max_epoch = 10
max_grad = 5.0
```
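`max_grad` is passed to `trainer.fit` as the gradient-clipping threshold. A minimal sketch of clipping by global norm, which is my assumption about what the book's helper does: if the L2 norm over all gradients exceeds `max_norm`, every gradient is scaled down by the same factor. The function name `clip_by_global_norm` is hypothetical.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    # L2 norm over all gradient arrays taken together
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    rate = max_norm / (total_norm + 1e-6)
    if rate < 1:
        for g in grads:
            g *= rate   # scale in place, keeping the overall direction
    return total_norm

grads = [np.full((2, 2), 3.0), np.full(3, 4.0)]
norm_before = clip_by_global_norm(grads, max_norm=5.0)   # sqrt(4*9 + 3*16) = sqrt(84)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in grads))
```

Scaling all gradients by one common factor, rather than clipping each element, keeps the update direction unchanged and only limits its size.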
```
%python3
model = AttentionSeq2Seq(vocab_size, wordvec_size, hidden_size)
optimizer = Adam()
trainer = Trainer(model, optimizer)
```
Run the training
```
%python3
acc_list = []
for epoch in range(max_epoch):
    trainer.fit(x_train, t_train, max_epoch=1,
                batch_size=batch_size, max_grad=max_grad)

    correct_num = 0
    for i in range(len(x_test)):
        question, correct = x_test[[i]], t_test[[i]]
        verbose = i < 10  # print the first 10 test questions
        correct_num += eval_seq2seq(model, question, correct,
                                    id_to_char, verbose, is_reverse=True)

    acc = float(correct_num) / len(x_test)
    acc_list.append(acc)
    print('val acc %.3f%%' % (acc * 100))

model.save_params()
```
Training log, trimmed to the first and last iterations of each epoch. From epoch 3 onward all 10 sample questions are answered correctly, so only their accuracy is kept.

```
| epoch 1 | iter 1 / 351 | time 0[s] | loss 4.08
| epoch 1 | iter 21 / 351 | time 12[s] | loss 3.09
...
| epoch 1 | iter 341 / 351 | time 211[s] | loss 1.00
Q 10/15/94                      T 1994-10-15  ☒ 1978-08-11
Q thursday, november 13, 2008   T 2008-11-13  ☒ 1978-08-11
... (the other 8 samples are also ☒ 1978-08-11)
val acc 0.000%

| epoch 2 | iter 1 / 351 | time 0[s] | loss 1.00
...
| epoch 2 | iter 341 / 351 | time 213[s] | loss 0.46
Q 10/15/94                      T 1994-10-15  ☑ 1994-10-15
Q thursday, november 13, 2008   T 2008-11-13  ☒ 2006-11-13
Q Mar 25, 2003                  T 2003-03-25  ☑ 2003-03-25
Q Tuesday, November 22, 2016    T 2016-11-22  ☑ 2016-11-22
Q Saturday, July 18, 1970       T 1970-07-18  ☑ 1970-07-18
Q october 6, 1992               T 1992-10-06  ☑ 1992-10-06
Q 8/23/08                       T 2008-08-23  ☑ 2008-08-23
Q 8/30/07                       T 2007-08-30  ☒ 2007-08-09
Q 10/28/13                      T 2013-10-28  ☒ 1983-10-28
Q sunday, november 6, 2016      T 2016-11-06  ☒ 2016-11-08
val acc 51.640%

| epoch 3 | iter 1 / 351 | time 0[s] | loss 0.35
| epoch 3 | iter 341 / 351 | time 213[s] | loss 0.01
val acc 99.900%

| epoch 4 | iter 1 / 351 | time 0[s] | loss 0.01
| epoch 4 | iter 341 / 351 | time 212[s] | loss 0.00
val acc 99.900%

| epoch 5 | iter 1 / 351 | time 0[s] | loss 0.00
| epoch 5 | iter 341 / 351 | time 213[s] | loss 0.00
val acc 99.920%

| epoch 6 | iter 1 / 351 | time 0[s] | loss 0.00
| epoch 6 | iter 341 / 351 | time 213[s] | loss 0.00
val acc 99.940%

| epoch 7 | iter 1 / 351 | time 0[s] | loss 0.00
| epoch 7 | iter 301 / 351 | time 188[s] | loss 0.02
| epoch 7 | iter 321 / 351 | time 200[s] | loss 0.04
| epoch 7 | iter 341 / 351 | time 213[s] | loss 0.01
val acc 99.180%

| epoch 8 | iter 1 / 351 | time 0[s] | loss 0.01
| epoch 8 | iter 341 / 351 | time 213[s] | loss 0.00
val acc 100.000%

| epoch 9 | iter 1 / 351 | time 0[s] | loss 0.00
| epoch 9 | iter 341 / 351 | time 212[s] | loss 0.00
val acc 100.000%

| epoch 10 | iter 1 / 351 | time 0[s] | loss 0.00
| epoch 10 | iter 341 / 351 | time 214[s] | loss 0.00
val acc 100.000%
```
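The validation accuracy in the training loop is sequence-level: a question counts as correct only if the whole generated date matches. Below is a minimal sketch of that metric (the helper name is hypothetical), using four of the epoch-2 samples from the log above. `2006-11-13` differs from the answer in only one character, yet it still counts as wrong.

```python
def exact_match_accuracy(predictions, targets):
    # a prediction counts only if every character matches
    correct = sum(1 for p, t in zip(predictions, targets) if p == t)
    return correct / len(targets)

# taken from the epoch-2 sample output in the training log
targets = ['1994-10-15', '2008-11-13', '2007-08-30', '2016-11-06']
predictions = ['1994-10-15', '2006-11-13', '2007-08-09', '2016-11-08']
acc = exact_match_accuracy(predictions, targets)   # 1 of 4 exact matches
```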
%sh top -b -n 1
```
top - 07:51:33 up 53 min,  0 users,  load average: 3.95, 2.30, 0.97
Tasks:  18 total,   2 running,  16 sleeping,   0 stopped,   0 zombie
%Cpu(s):  4.4 us,  4.1 sy,  0.0 ni, 90.4 id,  0.7 wa,  0.0 hi,  0.0 si,  0.3 st
KiB Mem : 16426336 total, 13631944 free,  1586288 used,  1208104 buff/cache
KiB Swap:        0 total,        0 free,        0 used. 14425956 avail Mem

  PID USER  PR NI    VIRT    RES    SHR S  %CPU %MEM     TIME+ COMMAND
  573 root  20  0  596620 255420  17564 R 373.3  1.6  16:34.76 python3
    8 root  20  0 4659904 612268  25060 S  26.7  3.7   0:59.91 java
...
```
%sh top -b -n 1
```
top - 08:32:39 up  1:34,  0 users,  load average: 4.04, 4.03, 3.86
Tasks:  18 total,   2 running,  16 sleeping,   0 stopped,   0 zombie
%Cpu(s): 23.2 us, 24.7 sy,  0.0 ni, 51.5 id,  0.4 wa,  0.0 hi,  0.0 si,  0.2 st
KiB Mem : 16426336 total, 13396032 free,  1821236 used,  1209068 buff/cache
KiB Swap:        0 total,        0 free,        0 used. 14190540 avail Mem

  PID USER  PR NI    VIRT    RES    SHR S  %CPU %MEM     TIME+ COMMAND
  573 root  20  0  625768 284436  17628 R 373.3  1.7 179:43.56 python3
    8 root  20  0 4664008 613632  25060 S   6.7  3.7   1:08.96 java
...
```
```
%python3
import matplotlib.pyplot as plt

plt.ylim(0, 1)
plt.plot(acc_list)
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.legend(labels=['AttentionSeq2Seq'])
plt.show()
```
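- Memory use is modest: the training process (python3) stayed under 300 MB resident (RES), about 1.7% of memory
- Validation accuracy is 51.6% after epoch 2 and 99.9% after epoch 3, and reaches 100% from epoch 8 onward (after a dip to 99.18% in epoch 7)
- Training took 49 minutes on an EC2 m4.xlarge (top shows python3 at about 373% CPU)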
Visualizing Attention
%sh chmod go+r AttentionSeq2Seq.pkl
%python3 model.load_params()
```
%python3
import matplotlib.pyplot as plt

_idx = 0
def visualize(attention_map, row_labels, column_labels):
    fig, ax = plt.subplots()
    ax.pcolor(attention_map, cmap=plt.cm.Greys_r, vmin=0.0, vmax=1.0)

    ax.patch.set_facecolor('black')
    ax.set_yticks(np.arange(attention_map.shape[0])+0.5, minor=False)
    ax.set_xticks(np.arange(attention_map.shape[1])+0.5, minor=False)
    ax.invert_yaxis()
    # x axis: input characters, y axis: output characters
    ax.set_xticklabels(row_labels, minor=False)
    ax.set_yticklabels(column_labels, minor=False)

    global _idx
    _idx += 1
    plt.show()

np.random.seed(1984)
for _ in range(5):
    idx = [np.random.randint(0, len(x_test))]
    x = x_test[idx]
    t = t_test[idx]

    model.forward(x, t)
    d = model.decoder.attention.attention_weights
    d = np.array(d)
    attention_map = d.reshape(d.shape[0], d.shape[2])

    # undo the input reversal
    attention_map = attention_map[:,::-1]
    x = x[:, ::-1]

    row_labels = [id_to_char[i] for i in x[0]]
    column_labels = [id_to_char[i] for i in t[0]]
    column_labels = column_labels[1:]  # align with decoder_ts = t[:, 1:]

    visualize(attention_map, row_labels, column_labels)
```
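The visualization code turns `attention_weights` (a list with one `(1, T_enc)` array per Decoder step, because the batch size is 1) into a 2-D map. A sketch of the same reshaping on random stand-in weights, with toy sizes `T_dec=4`, `T_enc=6`. Each row is one output character's distribution over input characters, so every row should still sum to 1 after the column flip.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

T_dec, T_enc = 4, 6
rng = np.random.RandomState(5)
# stand-in for model.decoder.attention.attention_weights with batch size 1
attention_weights = [softmax(rng.randn(1, T_enc)) for _ in range(T_dec)]

d = np.array(attention_weights)                      # (T_dec, 1, T_enc)
attention_map = d.reshape(d.shape[0], d.shape[2])    # (T_dec, T_enc)
attention_map = attention_map[:, ::-1]               # undo the input reversal on the Encoder axis
```

`d[:, 0, :]` would give the same 2-D map as the reshape; the reshape form just makes the intended shape explicit.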