機器如何計算微分/偏微分(上)

提到近年來熱門的類神經網路模型,反傳導(back propagation)是一個相當重要的過程,其中應用的就是計算梯度(gradient,就是每個參數的偏微分)修正神經參數來達到「學習」的效果,類神經網路的結構越來越複雜,使得人工推導梯度公式越來越不可行,如何應用電腦計算微分變成了一個重要的問題。

這張圖簡單地解釋了計算微分的幾種作法。摘自“Automatic Differentiation of Algorithms for Machine Learning”
這張圖簡單地解釋了計算微分的幾種作法。摘自“Automatic Differentiation of Algorithms for Machine Learning”

Numerical Differential

{-} 數值微分, By Olivier Cleynen (Own work) [CC0], via Wikimedia Commons

數值微分就是用 \(\frac{f(x + h) - f(x)}{h}\) 跟一個很小的 \(h\)近似 \(\lim_{h \rightarrow 0}\frac{f(x + h) - f(x)}{h}\),實作上也只是計算用不同的參數來計算函數 \(f\) 的差異,但是 \(f\) 的計算成本可能不低,而且很容易遇到浮點數計算的 round-off error 跟 truncating error,所以實務上並不實用。

實作

Symbolic Differential

符號微分是符號計算 (Symbolic Computing) 的一個經典例子:不去計算數值,而是 跟人類一樣 直接處理這些符號。 我最早是在 SICP 一書看見這個作法,可以忍受很多括號的讀者可以去看一下 §2.3.2 的介紹。

在修習微積分課程時,一定會提到如何利用 Differentiation rules 來求出導數,像是:

而 Symbolic Differential 就是採用跟人類一樣的方法,對著 expression tree 不斷的做 pattern matching 跟 rewrite。這個方法很好實作,只需要照著 Differential rules 改寫便是,而且可以得到真正的Exact Derivative,也就是跟手動推導的結果一致,Numerical Differential 因為浮點數的特性,會產生誤差而無法算出真正的導數/偏導數

導數/偏導數。

但是這個方法最大的問題是在改寫的過程中,會出現很多冗於的式子: \(x \times 1\), \(x + 0\) 之類的,因此很容易產生出巨大的(沒效率)的結果。“Automatic Differentiation in Machine Learning: a Survey” 一文中舉例到: \[f(x) = 64x(1 - x)(1 - 2x)^2(1 - 8x + 8x^2)^2\] 在沒有化簡的情況做符號微分會得到 \[\begin{gathered} 128x(1 - x)(-8 + 16x)(1 - 2x)^2(1 - 8x+8x^2) \\ + 64(1-x)(1-2x)^2 (1-8x+ 8x^2)^2\\ - 64x(1-2x)^2 (1-8x+8x^2)^2 \\ - 256x(1 - x)(1 - 2x)(1 - 8x + 8x^2)^2 \end{gathered} \] 這條很複雜的式子,但是經過化簡則只剩下 \[\begin{gathered} 64(1 - 42x + 504x^2 - 2640x^3 \\ + 7040x^4 -9984x^5 + 7168x^6 -2048x^7) \end{gathered}\],而隨著原式\(f\)的複雜度上升,沒做好簡化的符號微分會膨脹非常快。

以 Python 來說,SymPy 就實作了符號積分,可以看到 diff 會產出一個新的運算式。

Forward-Mode Automatic Differential

自動微分(AD)的名稱取得很容易讓人誤解,其實就是應用 chain rule 跟一些語言特性或是原始碼轉換工具,來達成不用手寫導數的程式,而可以在計算函數的同時(也就是 overhead 不大)得出真正的導數/偏導數。

自動微分根據計算的順序,可以分為 Forward Mode 跟 Reverse Mode,這次先介紹比較簡單的 Forward Mode。

{-} Forward Mode AD 將一個函數 \(y = f(... , x, ...) = (w_m \circ w_{m-1} ... \circ w_1)(..., x, ...)\)\(x\) 的微分,拆解成 \[ \begin{aligned} \frac {\partial y}{\partial x} =&\frac {\partial y}{\partial w_{m-1}}{\frac {\partial w_{m-1}}{\partial x}} \\ =&\frac {\partial y}{\partial w_{m-1}}\left({\frac {\partial w_{m-1}}{\partial w_{m-2}}}{\frac {\partial w_{m-2}}{\partial x}}\right) \\ =&\frac {\partial y}{\partial w_{m-1}}\left({\frac {\partial w_{m-1}}{\partial w_{m-2}}}\left({\frac {\partial w_{m-2}}{\partial w_{m-3}}}{\frac {\partial w_{m-3}}{\partial x}}\right)\right) \\ =&{\frac {\partial y}{\partial w_{m-1}}} \left( \frac {\partial w_{m-1}}{\partial w_{m-2}} \cdots \left(\frac {\partial w_{2}}{\partial w_{1}} \frac {\partial w_{1}}{\partial x} \right)\right)\\ \end{aligned} \] 然後從 \(\frac{\partial x}{\partial x} = 1\) 開始, 由 \(\frac{\partial w_1}{\partial x}\) 開始計算到 \(\frac{\partial w_m}{x}\)(由內而外)。

這邊繼續使用上面的 \(Y = f(x, y) = sin(x) + x × y\) 的例子,

我們可以把 \(Y\) 拆解成下圖(左)

\[ \begin{array}{rl|rl} Y & = w_5 & \dot{Y} & = \dot{w_5} \\ w_5 & = w_3 + w_4 & \dot{w_5} & = \dot{w_3} + \dot{w_4} \\ w_4 & = w_1 × w_2 & \dot{w_4} & = w_2 × \dot{w_1} + w_1 × \dot{w_2}\\ w_3 & = sin(w_1) & \dot{w_3} & = cos(w_1) \\ w_2 & = y & \dot{w_2} & = 0\\ w_1 & = x & \dot{w_1} & = 1\\ \end{array} \]

然後我們就可以從 \(\frac{\partial x}{\partial x} = 1, \frac{\partial y}{\partial x} = 0\) 開始,求出 \(\frac{\partial w_i}{\partial x} \big \vert_{i = 1 \cdots 5}\),最後求出\(\frac{\partial Y}{\partial x}\) (如上圖(右),這邊用 \(\dot{X}\) 來表示 \(\frac{\partial X}{\partial x}\))。

從計算順序上,Forward-Mode AD 跟原先函數是一樣的,所以被稱為 Forward Mode。

實作

一個常見的實作方法是定義一個新的資料結構 Dual Number,跟利用 operator overloading 來實作其算術系統。下面使用 Python 舉例較完整的 Dual Number 定義可以在 Wikipedia 找到。如同文中寫到,我們可以使用 \(g(\langle u,u' \rangle , \langle v,v' \rangle ) = \langle g(u,v) , g_u(u,v) u' + g_v(u,v) v' \rangle\) 來定義這些 primitive functions(\(sin, cos, \cdots\))。

為了可以寫出像是 4 * Dual(1, 1) 這樣混合著 floatDual 的程式,可以參考 Python Reference §3.3.7 使用 __radd__ 系列的 method 來實作。

舉例來說,我們可以透過控制Dual(x, dx)dx來分別計算 \(\frac{\partial f}{\partial x}\)\(\frac{\partial f}{\partial y}\)

這是 Wikipedia 上面對 \(Y = f(x,y) = sin(x) + x × y\) 做 forward-mode AD 的示意圖。
By Berland at English Wikipedia [Public domain], via Wikimedia Commons

\[f(x,y) = sin(x) + x*y\]

Dual(x, 1) 對應到 \(\frac{dx}{dx} = 1\)Dual(c, 0) 對應到 \(\frac{dc}{dx} = 0\)

這是怎麼辦到的?我們可以用類似 Equational ReasoningEquational Reasoning 就像數學證明一樣,比較常用在 Pure Functional 的程式語言上,不過既然這個例子也幾乎是 Pure Functional 的,在這邊就使用這個技巧來說明。

的方式來展開:

順帶一提,我們不需要額外實作 Chain Rule,因為它已經包含在 Dual Number 的算術系統內了。

可以看到第 3 行的 dx 正好等於應用 Chain Rule 求出的 \(\frac{\partial sin(sin(x))}{\partial x} = cos(x) × cos(sin(x))\)
這是因為在 sin(dual) 的定義中 Dual(math.sin(dual.x), dual.dx * math.cos(dual.x)) 裡面就已經 inline chain rule 了

這個方法的問題是每一個參數\(x_i\)都必須算過一次,類神經的梯度來說需要\(O(n)\)的時間才能算完(\(n\)是參數的個數),並不是很適合類神經網路計算梯度(因為\(n\)通常都不小)。

小結

本文介紹了

下一篇文章會介紹 Reverse-Mode Automatic Differential,是 Forward-Mode 相反方向的版本,同時也是 back propagation 的 general 版本。

References

Appendix

Example: Logistic Regression

這邊用挑戰者號的O型環失效資料 選這份資料是因為我最近在讀 “Probabilistic Programming & Bayesian Methods for Hackers” 剛好介紹到這一份資料集

作為範例,以上述的 Forward-Mode AD 注意: 這邊的 forwardad 已經有實作 __radd__, exp, log…等滿足下面程式所需的最小需求。

實作 Gradient Descent 來進行 Logistic Regression。

這邊使用 logistic function 做模型,binary cross-entropy 作為 loss function。

\[ \begin{gathered} f(x) = \frac{1}{1+ exp(\beta x + \alpha)} \\ Cost(y, \hat y) = -\frac{1}{N}\sum_{i=1}^N \bigg[ y_i \log(\hat{y}_i)+(1-y_i) \log(1-\hat{y}_i) \bigg] \end{gathered} \] 這邊可以看到,寫出來的程式幾乎沒有為了計算微分而改變,而是照著本來的定義在寫。

最後使用 Gradient Descent 來找出 \(\alpha, \beta\)

alpha, beta, lr, n_epoch 這些參數都是可以設定的。

受限於 Forward-Mode AD 先天上的限制,\(\alpha\)\(\beta\) 的微分需要分開計算。

這邊是上面程式執行的結果(重新排版過),可以看到 Learning Rate 設定的有點太大了。有興趣的讀者可以試著自己實作看看。

    0 loss 0.693147180559945  alpha  0                   beta  0
  100 loss 0.966931426916483  alpha -0.0201008838271800  beta  0.0465418617995232
  200 loss 1.114664696718874  alpha -0.0417346760460772  beta -0.0170795861942109
  300 loss 3.558155059566457  alpha -0.0595466871292758  beta  0.1800165252346373
  400 loss 0.961973323155755  alpha -0.0821799396148940  beta  0.0472951650365717
  500 loss 1.110644944663376  alpha -0.1037522632990092  beta -0.0161308641984885
  600 loss 3.555985812065263  alpha -0.1215047609258659  beta  0.1809462252622945
  700 loss 0.957097575658038  alpha -0.1440822016889158  beta  0.0480494092224557
  800 loss 1.106712443731962  alpha -0.1655931706446121  beta -0.0151865701942648
  900 loss 3.553772437842705  alpha -0.1832862200585766  beta  0.1818707335807102