在使用梯度下降法进行回归时,需要频繁的进行偏导数的计算。在很多的相关介绍中会展示使用计算图进行偏导数的计算。这里简述对该方法的一些理解。
1. 概述
- 计算图求导,可以理解为是对求导的链式法则的图表示
- 在计算图中,在一个单向路径上的算子,求导时,将各个导数相乘即可
- 在计算图中,在一个单向路径上,上一个节点的输出,是下一个节点的输入;函数关系上,就是 \(f(g(x)) \),即 \(g(x) \) 的输出是,\(f(x) \)的输入
- 一个节点的两个入度(分支),求导时,将各个导数相乘即可
- 多元函数求偏导时,只需要关注其偏导的变量即可
2. 链式法则的典型形式
这里对求导的链式法则的典型形式做一个简单的回顾。
在对复杂的表达式求导/微分时,有时候看起来会很复杂。如果能够灵活的使用链式法则可以巧妙将复杂函数的求导转换为简单函数的求导。
2.1 法则1
$$
f(x) = g(h(x))
\\
f'(x) = \frac{\partial f}{\partial x} = \frac{\partial g}{\partial h} \frac{\partial h}{\partial x}
$$
例如,使用该法则可以很简单对如下函数求导:
$$
f(x) = e^{(x^2)}
\\
g(h) = e^h \, h(x) = x^2
\\
f'(x) = \frac{\partial g}{\partial h} \frac{\partial h}{\partial x} = e^h * 2 * x = 2x*e^h = 2xe^{x^2}
$$
如果使用计算图的方式表达如上的求导,如下:

$$
f(x) = f(g(x))
\\
\frac{\partial f}{\partial x} = \frac{\partial f}{\partial g}\frac{\partial g}{\partial h}
$$
所以:在计算图中,在一个单向路径上的算子,求导时,将各个导数相乘即可。
(more…)