Deep Learning from Scratch 2 / RNN


Chapter 5 - Recurrent Neural Networks (RNN)

In the previous chapter we implemented word2vec using the CBOW model. By also introducing Negative Sampling, a technique that approximates multi-class classification as binary classification, we were able to handle a large corpus. In this chapter we implement a Recurrent Neural Network (RNN), which can handle time-series data.

%sh
git clone https://github.com/oreilly-japan/deep-learning-from-scratch-2 /tmp/deep-learning-from-scratch-2
Cloning into '/tmp/deep-learning-from-scratch-2'...

%sh
pip3 install matplotlib numpy pandas
Collecting matplotlib
  Downloading https://files.pythonhosted.org/packages/ad/4c/0415f15f96864c3a2242b1c74041a806c100c1b21741206c5d87684437c6/matplotlib-3.0.2-cp35-cp35m-manylinux1_x86_64.whl (12.9MB)
Collecting numpy
  Downloading https://files.pythonhosted.org/packages/86/04/bd774106ae0ae1ada68c67efe89f1a16b2aa373cc2db15d974002a9f136d/numpy-1.15.4-cp35-cp35m-manylinux1_x86_64.whl (13.8MB)
Collecting pandas
  Downloading https://files.pythonhosted.org/packages/5d/d4/6e9c56a561f1d27407bf29318ca43f36ccaa289271b805a30034eb3a8ec4/pandas-0.23.4-cp35-cp35m-manylinux1_x86_64.whl (8.7MB)
Collecting kiwisolver>=1.0.1 (from matplotlib)
  Downloading https://files.pythonhosted.org/packages/7e/31/d6fedd4fb2c94755cd101191e581af30e1650ccce7a35bddb7930fed6574/kiwisolver-1.0.1-cp35-cp35m-manylinux1_x86_64.whl (949kB)
Collecting python-dateutil>=2.1 (from matplotlib)
  Downloading https://files.pythonhosted.org/packages/74/68/d87d9b36af36f44254a8d512cbfc48369103a3b9e474be9bdfe536abfc45/python_dateutil-2.7.5-py2.py3-none-any.whl (225kB)
Collecting pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 (from matplotlib)
  Downloading https://files.pythonhosted.org/packages/71/e8/6777f6624681c8b9701a8a0a5654f3eb56919a01a78e12bf3c73f5a3c714/pyparsing-2.3.0-py2.py3-none-any.whl (59kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached https://files.pythonhosted.org/packages/f7/d2/e07d3ebb2bd7af696440ce7e754c59dd546ffe1bbe732c8ab68b9c834e61/cycler-0.10.0-py2.py3-none-any.whl
Collecting pytz>=2011k (from pandas)
  Downloading https://files.pythonhosted.org/packages/f8/0e/2365ddc010afb3d79147f1dd544e5ee24bf4ece58ab99b16fbb465ce6dc0/pytz-2018.7-py2.py3-none-any.whl (506kB)
Requirement already satisfied (use --upgrade to upgrade): setuptools in /usr/lib/python3/dist-packages (from kiwisolver>=1.0.1->matplotlib)
Collecting six>=1.5 (from python-dateutil>=2.1->matplotlib)
  Downloading https://files.pythonhosted.org/packages/67/4b/141a581104b1f6397bfa78ac9d43d8ad29a7ca43ea90a2d863fe3056e86a/six-1.11.0-py2.py3-none-any.whl
Installing collected packages: kiwisolver, numpy, six, python-dateutil, pyparsing, cycler, matplotlib, pytz, pandas
Successfully installed cycler-0.10.0 kiwisolver-1.0.1 matplotlib-3.0.2 numpy-1.15.4 pandas-0.23.4 pyparsing-2.3.0 python-dateutil-2.7.5 pytz-2018.7 six-1.11.0
You are using pip version 8.1.1, however version 18.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.

5.1.2 - Language models

  • LM; Language Model
  • Assigns a probability to a sequence of words
    • you say goodbye => 0.092
    • you say good die => 0.0000000000000032
    • i.e. it scores, as a probability, how natural a word sequence is
  • Joint probability
    • The probability that multiple events occur together

$$
\begin{eqnarray}
P(w_1,\ \dots,\ w_m) &=& P(w_m|w_1,\ \dots,\ w_{m-1}) P(w_{m-1}|w_1,\ \dots,\ w_{m-2}) \dots P(w_{3}|w_1, w_2) P(w_2|w_1) P(w_1) \\
&=& \prod^m_{t=1}P(w_t|w_1,\ \dots,\ w_{t-1})
\end{eqnarray}
$$

  • \(P(x|a,\ b)\) is the probability that event x occurs given that events a and b have already occurred (a posterior probability)
  • The language-model probability follows from repeated application of the multiplication rule
  • Multiplication rule

$$
P(A, B) = P(A|B)P(B)
$$

  • Applying it once gives

$$
P(\underbrace{w_1, \dots, w_{m-1}}_A, w_m) = P(A,w_m) = P(w_m|A)P(A)
$$

  • Repeating the same decomposition gives

$$
P(A) = P(\underbrace{w_1, \dots, w_{m-2}}_{A'}, w_{m-1}) = P(A', w_{m-1}) = P(w_{m-1}|A')P(A')
$$
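
As a concrete illustration of the chain rule, here is a tiny sketch with made-up conditional probabilities (the numbers are invented for the example, not model outputs):

%python3
# Toy illustration of the chain rule for a language model.
# The conditional probabilities below are made-up numbers.
p = {
    ('you',): 0.05,                       # P(you)
    ('say', 'you'): 0.2,                  # P(say | you)
    ('goodbye', 'you', 'say'): 0.0092,    # P(goodbye | you, say)
}
# P(you, say, goodbye) = P(you) * P(say | you) * P(goodbye | you, say)
joint = p[('you',)] * p[('say', 'you')] * p[('goodbye', 'you', 'say')]
print(joint)  # 0.000092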

5.2.2 - Unrolling the loop

Reusing the same layer over and over is what distinguishes this network from the ones built in the word2vec chapters.

  • The output (hidden state) is given by the following formula

$$
h_t = \tanh(h_{t-1} W_h + x_t W_x + b)
$$

  • One term transforms the current input, another transforms the previous RNN output, plus a bias term (a minimal numpy sketch of a single step follows)
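
A minimal sketch of one RNN step; the sizes are toy values chosen only for the example:

%python3
import numpy as np

# One RNN step: h_t = tanh(h_{t-1} Wh + x_t Wx + b)
N, D, H = 2, 3, 4                  # batch size, input dim, hidden dim
x_t = np.random.randn(N, D)        # input at time t
h_prev = np.random.randn(N, H)     # hidden state from time t-1
Wx = np.random.randn(D, H)
Wh = np.random.randn(H, H)
b = np.zeros(H)

h_t = np.tanh(np.dot(h_prev, Wh) + np.dot(x_t, Wx) + b)
print(h_t.shape)                   # (2, 4)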

5.2.3 - Backpropagation Through Time

  • BPTT
  • Backpropagation applied to the network unrolled along the time axis
  • Truncated BPTT
    • When handling long time-series data, the network connections have to be cut at a suitable length
    • The forward pass is left as is
    • The backward pass is split into fixed-size blocks and learning is performed block by block (see the sketch below)
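
A rough sketch of the block handling in Truncated BPTT, using toy data and the forward pass only: the hidden state is carried across blocks, while backpropagation would run separately inside each block, so gradients never cross a block boundary.

%python3
import numpy as np

# Sketch of Truncated BPTT block handling (forward pass only, toy data).
N, T, D, H = 2, 5, 3, 4                       # batch, truncation length, dims
xs = np.random.randn(N, 4 * T, D)             # a "long" sequence of 4 blocks
Wx, Wh, b = np.random.randn(D, H), np.random.randn(H, H), np.zeros(H)

h = np.zeros((N, H))
for start in range(0, xs.shape[1], T):
  for t in range(start, start + T):
    h = np.tanh(np.dot(h, Wh) + np.dot(xs[:, t, :], Wx) + b)   # state carried over
  # here the loss for this block would be computed and backpropagated
  # through these T steps only, followed by a weight update
  print('processed block starting at t=%d' % start)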

5.3.1 - Implementing the RNN layer

  • Build the computational graph for the following formula and differentiate through it

$$
h_t = \tanh(h_{t-1} W_h + x_t W_x + b)
$$

%python3
class RNN:
  def __init__(self, Wx, Wh, b):
    self.params = [Wx, Wh, b]
    self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
    self.cache = None
    
  def forward(self, x, h_prev):
    Wx, Wh, b = self.params
    t = np.dot(h_prev, Wh) + np.dot(x, Wx) + b
    h_next = np.tanh(t)
    
    self.cache = (x, h_prev, h_next)
    return h_next
 
  def backward(self, dh_next):
    Wx, Wh, b = self.params
    x, h_prev, h_next = self.cache
    
    dt = dh_next * (1 - h_next ** 2)  # tanh backward: dL/dt = dL/dh_next * (1 - tanh(t)^2)
    db = np.sum(dt, axis=0)
    dWh = np.dot(h_prev.T, dt)
    dh_prev = np.dot(dt, Wh.T)
    dWx = np.dot(x.T, dt)
    dx = np.dot(dt, Wx.T)
    
    self.grads[0][...] = dWx
    self.grads[1][...] = dWh
    self.grads[2][...] = db
    
    return dx, dh_prev
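
As a quick sanity check (my own addition, not from the book), the analytic `db` from `backward` can be compared against a finite-difference estimate of the gradient of `sum(h_next)` with respect to `b`:

%python3
import numpy as np

# Numerical gradient check for the RNN cell above.
np.random.seed(0)
N, D, H = 2, 3, 4
Wx, Wh, b = np.random.randn(D, H), np.random.randn(H, H), np.random.randn(H)
x, h_prev = np.random.randn(N, D), np.random.randn(N, H)

layer = RNN(Wx, Wh, b)
layer.forward(x, h_prev)
layer.backward(np.ones((N, H)))          # dL/dh_next = 1  =>  L = sum(h_next)
db_analytic = layer.grads[2]

eps = 1e-5
db_numeric = np.zeros_like(b)
for i in range(len(b)):
  b_plus, b_minus = b.copy(), b.copy()
  b_plus[i] += eps
  b_minus[i] -= eps
  h_plus = RNN(Wx, Wh, b_plus).forward(x, h_prev)
  h_minus = RNN(Wx, Wh, b_minus).forward(x, h_prev)
  db_numeric[i] = (h_plus.sum() - h_minus.sum()) / (2 * eps)

print(np.max(np.abs(db_analytic - db_numeric)))  # should be close to zero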

5.3.2 - Implementing the TimeRNN layer

%python3
class TimeRNN:
  def __init__(self, Wx, Wh, b, stateful=False):
    self.params = [Wx, Wh, b]
    self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
    self.layers = None
    
    self.h, self.dh = None, None
    self.stateful = stateful
  
  def set_state(self, h):
    self.h = h
   
  def reset_state(self):
    self.h = None
    
  def forward(self, xs):
    Wx, Wh, b = self.params
    N, T, D = xs.shape
    D, H = Wx.shape
    
    self.layers = []
    hs = np.empty((N, T, H), dtype='f')
    
    if not self.stateful or self.h is None:
      self.h = np.zeros((N, H), dtype='f')
     
    for t in range(T):
      layer = RNN(*self.params)
      self.h = layer.forward(xs[:, t, :], self.h)
      hs[:, t, :] = self.h
      self.layers.append(layer)
    
    return hs

  def backward(self, dhs):
    Wx, Wh, b = self.params
    N, T, H = dhs.shape
    D, H = Wx.shape
    
    dxs = np.empty((N, T, D), dtype='f')
    dh = 0
    grads = [0, 0, 0]
    
    for t in reversed(range(T)):
      layer = self.layers[t]
      dx, dh = layer.backward(dhs[:, t, :] + dh)
      dxs[:, t, :] = dx
      
      for i, grad in enumerate(layer.grads):
        grads[i] += grad
    
    for i, grad in enumerate(grads):
      self.grads[i][...] = grad
    
    self.dh = dh
    
    return dxs

  • TimeRNN: T RNN layers connected in series (a quick shape check follows below)
  • When processing long time-series data, the hidden state has to be carried over (stateful=True)
  • The full RNNLM stacks Embedding => RNN => Affine => Softmax
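
A quick shape check of the TimeRNN layer with assumed toy dimensions (N=2 samples, T=3 time steps, D=4 input dims, H=5 hidden dims):

%python3
import numpy as np

# Shape check of TimeRNN with toy dimensions (not from the book).
N, T, D, H = 2, 3, 4, 5
Wx = np.random.randn(D, H).astype('f')
Wh = np.random.randn(H, H).astype('f')
b = np.zeros(H, dtype='f')

layer = TimeRNN(Wx, Wh, b, stateful=True)
xs = np.random.randn(N, T, D).astype('f')
hs = layer.forward(xs)            # hidden states for every time step
dxs = layer.backward(np.ones_like(hs))
print(hs.shape, dxs.shape)        # (2, 3, 5) (2, 3, 4)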

5.5.1 - Implementing the RNNLM

%python3
import sys
sys.path.append('/tmp/deep-learning-from-scratch-2')
import numpy as np
from common.time_layers import *

class SimpleRnnlm:
  def __init__(self, vocab_size, wordvec_size, hidden_size):
    V, D, H = vocab_size, wordvec_size, hidden_size
    rn = np.random.randn
    
    # weight
    embed_W = (rn(V, D) / 100).astype('f')
    rnn_Wx = (rn(D, H) / np.sqrt(D)).astype('f')
    rnn_Wh = (rn(H, H) / np.sqrt(H)).astype('f')
    rnn_b = np.zeros(H).astype('f')
    affine_W = (rn(H, V) / np.sqrt(H)).astype('f')
    affine_b = np.zeros(V).astype('f')
    
    # layer
    self.layers = [
        TimeEmbedding(embed_W),
        TimeRNN(rnn_Wx, rnn_Wh, rnn_b, stateful=True),
        TimeAffine(affine_W, affine_b)
    ]
    self.loss_layer = TimeSoftmaxWithLoss()
    self.rnn_layer = self.layers[1]
    
    self.params, self.grads = [], []
    for layer in self.layers:
      self.params += layer.params
      self.grads += layer.grads
  
  def forward(self, xs, ts):
    for layer in self.layers:
      xs = layer.forward(xs)
    loss = self.loss_layer.forward(xs, ts)
    return loss
  
  def backward(self, dout=1):
    dout = self.loss_layer.backward(dout)
    for layer in reversed(self.layers):
      dout = layer.backward(dout)
    return dout
 
  def reset_state(self):
    self.rnn_layer.reset_state()

  • The RNN and Affine weights are initialized from a normal distribution with standard deviation 1/√n, where n is the number of input units (Xavier-style initialization); the embedding weights are simply scaled down by 1/100

Preparing the dataset

%python3
# RNNLM
import sys
sys.path.append('/tmp/deep-learning-from-scratch-2')
import matplotlib.pyplot as plt
import numpy as np
from common.optimizer import SGD
from dataset import ptb

# parameters
batch_size = 10
wordvec_size = 100
hidden_size = 100
time_size = 5
lr = 0.1
max_epoch = 100

# load dataset
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_size = 1000
corpus = corpus[:corpus_size]
vocab_size = int(max(corpus) + 1)

xs = corpus[:-1]
ts = corpus[1:]
data_size = len(xs)
print('corpus size: %d, vocabulary size: %d' % (corpus_size, vocab_size))
Downloading ptb.train.txt ... 
Done
corpus size: 1000, vocabulary size: 418
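
The targets `ts` are simply the inputs `xs` shifted by one position, so each pair is (current word, next word to predict). A quick look at the first few pairs (continuing from the cell above):

%python3
# Print the first few (input word -> target word) pairs.
for x, t in zip(xs[:5], ts[:5]):
  print('%s -> %s' % (id_to_word[int(x)], id_to_word[int(t)]))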

Running the training loop

%python3
# variables
max_iters = data_size // (batch_size * time_size)
time_idx = 0
total_loss = 0
loss_count = 0
ppl_list = []

# model
model = SimpleRnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)


jump = (corpus_size - 1) // batch_size
offsets = [i * jump for i in range(batch_size)]

for epoch in range(max_epoch):
  for iter in range(max_iters):
    batch_x = np.empty((batch_size, time_size), dtype='i')
    batch_t = np.empty((batch_size, time_size), dtype='i')
    for t in range(time_size):
      for i, offset in enumerate(offsets):
        batch_x[i, t] = xs[(offset + time_idx) % data_size]
        batch_t[i, t] = ts[(offset + time_idx) % data_size]
      time_idx += 1
    
    loss = model.forward(batch_x, batch_t)
    model.backward()
    
    optimizer.update(model.params, model.grads)
    total_loss += loss
    loss_count += 1
  
  ppl = np.exp(total_loss / loss_count)
  print('| epoch %d | perplexity %.2f' % (epoch+1, ppl))
  ppl_list.append(float(ppl))
  total_loss, loss_count = 0, 0
| epoch 1 | perplexity 382.86
| epoch 2 | perplexity 254.91
| epoch 3 | perplexity 222.65
| epoch 4 | perplexity 213.66
| epoch 5 | perplexity 203.90
| epoch 6 | perplexity 200.80
| epoch 7 | perplexity 197.62
| epoch 8 | perplexity 195.92
| epoch 9 | perplexity 190.47
| epoch 10 | perplexity 191.88
| epoch 11 | perplexity 187.74
| epoch 12 | perplexity 191.16
| epoch 13 | perplexity 188.79
| epoch 14 | perplexity 189.19
| epoch 15 | perplexity 188.00
| epoch 16 | perplexity 184.00
| epoch 17 | perplexity 182.54
| epoch 18 | perplexity 179.67
| epoch 19 | perplexity 179.78
| epoch 20 | perplexity 180.40
| epoch 21 | perplexity 178.52
| epoch 22 | perplexity 174.72
| epoch 23 | perplexity 171.37
| epoch 24 | perplexity 171.24
| epoch 25 | perplexity 169.41
| epoch 26 | perplexity 168.92
| epoch 27 | perplexity 163.26
| epoch 28 | perplexity 161.68
| epoch 29 | perplexity 158.17
| epoch 30 | perplexity 152.58
| epoch 31 | perplexity 152.60
| epoch 32 | perplexity 147.84
| epoch 33 | perplexity 146.38
| epoch 34 | perplexity 141.89
| epoch 35 | perplexity 141.89
| epoch 36 | perplexity 135.98
| epoch 37 | perplexity 133.04
| epoch 38 | perplexity 126.65
| epoch 39 | perplexity 121.99
| epoch 40 | perplexity 117.50
| epoch 41 | perplexity 118.78
| epoch 42 | perplexity 111.64
| epoch 43 | perplexity 106.22
| epoch 44 | perplexity 101.10
| epoch 45 | perplexity 98.20
| epoch 46 | perplexity 96.53
| epoch 47 | perplexity 92.08
| epoch 48 | perplexity 87.16
| epoch 49 | perplexity 81.94
| epoch 50 | perplexity 79.14
| epoch 51 | perplexity 76.60
| epoch 52 | perplexity 73.37
| epoch 53 | perplexity 68.65
| epoch 54 | perplexity 65.25
| epoch 55 | perplexity 62.70
| epoch 56 | perplexity 58.56
| epoch 57 | perplexity 56.49
| epoch 58 | perplexity 51.97
| epoch 59 | perplexity 48.67
| epoch 60 | perplexity 46.76
| epoch 61 | perplexity 46.75
| epoch 62 | perplexity 43.02
| epoch 63 | perplexity 39.14
| epoch 64 | perplexity 37.46
| epoch 65 | perplexity 36.44
| epoch 66 | perplexity 33.77
| epoch 67 | perplexity 32.85
| epoch 68 | perplexity 30.99
| epoch 69 | perplexity 28.49
| epoch 70 | perplexity 27.31
| epoch 71 | perplexity 25.99
| epoch 72 | perplexity 24.15
| epoch 73 | perplexity 22.66
| epoch 74 | perplexity 21.70
| epoch 75 | perplexity 20.86
| epoch 76 | perplexity 19.16
| epoch 77 | perplexity 17.79
| epoch 78 | perplexity 16.40
| epoch 79 | perplexity 15.65
| epoch 80 | perplexity 14.65
| epoch 81 | perplexity 14.83
| epoch 82 | perplexity 14.02
| epoch 83 | perplexity 12.82
| epoch 84 | perplexity 12.09
| epoch 85 | perplexity 11.14
| epoch 86 | perplexity 11.09
| epoch 87 | perplexity 10.64
| epoch 88 | perplexity 10.05
| epoch 89 | perplexity 9.22
| epoch 90 | perplexity 8.67
| epoch 91 | perplexity 8.89
| epoch 92 | perplexity 8.27
| epoch 93 | perplexity 7.99
| epoch 94 | perplexity 7.18
| epoch 95 | perplexity 6.80
| epoch 96 | perplexity 6.47
| epoch 97 | perplexity 6.24
| epoch 98 | perplexity 6.16
| epoch 99 | perplexity 5.53
| epoch 100 | perplexity 5.31

  • Perplexity decreases steadily as training proceeds (see the note below)
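
Perplexity is the exponential of the average cross-entropy loss, so lower is better and a value of 1 would mean the next word is always predicted with probability 1:

$$
\mathrm{perplexity} = e^{L}, \qquad L = -\frac{1}{N}\sum_{n} \log P(w_n\,|\,w_1,\ \dots,\ w_{n-1})
$$

The recorded ppl_list can be plotted with the matplotlib import already in place (a small sketch, not from the book):

%python3
# Plot the perplexity recorded after each epoch.
x = np.arange(len(ppl_list))
plt.plot(x, ppl_list, label='train')
plt.xlabel('epoch')
plt.ylabel('perplexity')
plt.legend()
plt.show()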

%python3
model.rnn_layer.h[0]
array([ 0.99481535,  0.84290826, -0.9133508 , -0.7206012 ,  0.10123568,
        0.56462896,  0.4276802 ,  0.6728324 , -0.89788616, -0.36290118,
        0.93407196, -0.7388965 , -0.02504405,  0.59688294,  0.99928   ,
        0.94263035,  0.42114997, -0.5341952 ,  0.96084946,  0.8601841 ,
       -0.72191596,  0.9779833 , -0.5641036 , -0.94421744, -0.55707437,
       -0.7154803 , -0.3369527 , -0.90063494,  0.36353034, -0.6829183 ,
        0.74614024,  0.8568548 , -0.88780123,  0.96129876, -0.41962677,
        0.15925045, -0.10753727,  0.98698086,  0.950758  , -0.16255635,
        0.5273536 , -0.06916112,  0.30672634, -0.8988131 , -0.9447174 ,
        0.6972682 , -0.6615945 ,  0.17318887,  0.8872323 ,  0.1876182 ,
        0.9102601 ,  0.72029966, -0.04973648, -0.99363595, -0.9881019 ,
        0.76433754, -0.9236933 ,  0.81884986, -0.6499983 , -0.94718146,
        0.6227717 , -0.9419574 , -0.454079  ,  0.9009093 ,  0.94792247,
       -0.8391417 ,  0.9766069 , -0.2950728 ,  0.5510518 , -0.83473146,
        0.1847856 ,  0.8578545 ,  0.39274502, -0.8276459 , -0.6760022 ,
       -0.4585271 ,  0.68668073, -0.42346096, -0.9791731 , -0.27119333,
       -0.33470106, -0.00705867,  0.61266047,  0.26273778,  0.78953934,
       -0.8580364 , -0.9747648 ,  0.6128933 ,  0.9650971 ,  0.91345453,
       -0.8091221 , -0.9471533 ,  0.7015872 , -0.5921406 , -0.3563244 ,
        0.27482057,  0.5106253 ,  0.5722379 , -0.98263913, -0.30361378],
      dtype=float32)

  • Interesting: looking at the TimeRNN layer's hidden state after running forward and backward, it seems we could derive predicted next words from it by passing it through the Affine layer and a softmax (a rough sketch follows)
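
A rough sketch of that idea (my own, not from the book), assuming the SimpleRnnlm layer structure above: run a few word ids through Embedding -> RNN -> Affine, take a softmax over the vocabulary, and map the argmax back to words.

%python3
# Predict the next word for the first few words of the corpus.
model.reset_state()
sample = np.array(corpus[:5]).reshape(1, -1)      # (N=1, T=5) word ids

out = sample
for layer in model.layers:
  out = layer.forward(out)                        # ends as (1, T, vocab_size) scores

probs = np.exp(out - out.max(axis=-1, keepdims=True))   # softmax over the vocabulary
probs /= probs.sum(axis=-1, keepdims=True)
pred_ids = probs.argmax(axis=-1)[0]

for x, p in zip(sample[0], pred_ids):
  print('%s -> %s' % (id_to_word[int(x)], id_to_word[int(p)]))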
