Computational Graph

  •   2019-08-23(金)
  •  

計算グラフ(Computational Graph)

  • 計算の過程をグラフによって表現したもの。
    • ノード:オペコード
    • エッジ:オペランド
  • 途中の計算結果を全て保持できる。
  • 局所的な計算を伝播することによって最終的な計算結果を得ることが可能であり、それによって問題を単純化して理解できる。
  • 特に微分の計算は「計算の局所化」=「連鎖律(chain rule)」であり、相性が良い。

計算グラフの逆伝播

連鎖律と計算グラフ

連鎖律(chain rule)

ここでは例として $z=(x+y)^2$ という式を扱う。この式は、次の $2$ 式から構成されていると考えることができる。

$$ \begin{cases} z = t^2\\ t = x+y \end{cases} $$

この時、連鎖律から「$x$ に関する $z$ の微分 $\frac{\partial z}{\partial x}$」「$t$ に関する $z$ の微分 $\frac{\partial z}{\partial t}$」「$x$ に関する $t$ の微分 $\frac{\partial t}{\partial x}$」の積によって表すことができるので、以下のように書くことができる。

$$ \frac{\partial z}{\partial x} = \frac{\partial z}{\partial t} \times \frac{\partial t}{\partial x} $$

計算グラフ

上で数式的に考えた連鎖律の計算を計算グラフで表すと、以下のようになる。

連鎖律

これによって、局所的な微分計算の積によって合成関数の微分が行われていることが視覚的に理解できる。

色々なノードにおける逆伝播

ここから、様々なオペコードにおける逆伝播を考える。

加算ノード

$z=x+y$ という数式を考えと、この式の微分は次のように計算できる。

$$ \begin{cases} \begin{aligned} \frac{\partial z}{\partial x} = 1\\ \frac{\partial z}{\partial y} = 1\\ \end{aligned} \end{cases} $$

したがって、加算ノードの逆伝播は以下のように表される。

加算ノード

つまり、加算ノードの逆伝播は「上流から伝わった微分の値をそのまま次のノードへと流すだけ」になる。

乗算ノード

$z=xy$ という数式を考えと、この式の微分は次のように計算できる。

$$ \begin{cases} \begin{aligned} \frac{\partial z}{\partial x} = y\\ \frac{\partial z}{\partial y} = x\\ \end{aligned} \end{cases} $$

したがって、乗算ノードの逆伝播は以下のように表される。

乗算ノード

つまり、乗算ノードの逆伝播は「上流から伝わった微分の値に、順伝播の際の入力信号を『ひっくり返した値』を乗算して下流のノードへと流す」ことになる。

逆数ノード

$y=\frac{1}{x}$ という数式を考えと、この式の微分は次のように計算できる。

$$ \frac{\partial y}{\partial x} = -\frac{1}{x^2} = -y^2 $$

したがって、逆数ノードの逆伝播は以下のように表される。

逆数ノード

つまり、逆数ノードの逆伝播は「上流から伝わった微分の値に『順伝播の出力の二乗にマイナスを付けた値』を乗算して下流のノードへと流す」ことになる。

expノード

$y=\exp(x)$ という数式を考えと、この式の微分は次のように計算できる。

$$ \frac{\partial y}{\partial x} = \exp\left(x\right) = y $$

したがって、expノードの逆伝播は以下のように表される。

expノード

つまり、逆数ノードの逆伝播は「上流から伝わった微分の値に『順伝播の出力』を乗算して下流のノードへと流す」ことになる。

 ここまでのノードを用いることで、SoftmaxCross Entropy Errorの逆伝播も計算することができる。 Softmax-with-Loss

dotノード

$\mathbf{Y} = \mathbf{X}\cdot\mathbf{W} + \mathbf{B}$ という数式を考えると、この式の微分は以下のようになる。

$$ \begin{cases} \begin{aligned} \frac{\partial L}{\partial \mathbf{X}} = \frac{\partial L}{\partial \mathbf{X}}\cdot \mathbf{W}^T\\ \frac{\partial L}{\partial \mathbf{W}} = \mathbf{X}^T\cdot\frac{\partial L}{\partial \mathbf{Y}}\\ \end{aligned} \end{cases} $$

したがって、dotノードの逆伝播は以下のように表される。

dotノード

天下り式にこの逆伝播を理解しても良いが、行列を対象とした逆伝播を求める場合は、行列の要素ごとに書き下すことで、これまでのスカラ値を対象とした計算グラフと同様に考えることができる。

実際に書き下すと、以下の計算グラフに分解できる。

dotノード(要素ごとに分解)

Pythonによる実装

加算ノード

In [1]:
class AddLayer:
    def __init__(self):
        pass

    def forward(self, x, y):
        out = x+y
        return out

    def backward(self, dout):
        dx = dout * 1
        dy = dout * 1

        return dx, dy

乗算ノード

In [2]:
class MulLayer:
    def __init__(self):
        self.x = None
        self.y = None

    def forward(self, x, y):
        self.x = x
        self.y = y                
        out = x * y

        return out

    def backward(self, dout):
        dx = dout * self.y
        dy = dout * self.x

        return dx, dy

逆数ノード

In [3]:
class InverseLayer:
    def __init__(self):
        self.out = None

    def forward(self, x):
        out = 1/x
        self.out = out

        return out

    def backward(self, dout):
        dx = dout * (-self.out**2)

        return dx

expノード

In [4]:
class ExpLayer:
    def __init__(self):
        self.out = None

    def forward(self, x):
        out = np.exp(x)
        self.out = out

        return out

    def backward(self, dout):
        dx = dout * self.out

        return dx

dotノード

In [5]:
class DotLayer:
    def __init__(self):
        self.X = None
        self.W = None

    def forward(self, X, W):
        """
        @param X   : shape=(1,a)
        @param W   : shape=(a,b)
        @return out: shape=(1,b)
        """
        self.X = X
        self.W = W
        out = np.dot(X, W) # shape=(1,b)

        return out

    def backward(self, dout):
        """
        @param dout: shape=(1,b)
        @return dX : shape=(1,a)
        @return dW : shape=(a,b)
        """
        dX = np.dot(dout, self.W.T)
        dW = np.dot(self.X.T, dout)

        return dX,dW

参考

In [ ]: