@tf.custom_gradient

@tf.custom_gradient

初衷

網上資料較少,而且官方文檔比較ambigious(也許有誤),花了比較久的時間搞懂這個修飾器,記此貼防止大家走彎路。

官方文檔
參考文檔

介紹

@tf.custom_gradient

裝飾器允許控制對梯度的一連串操作,這樣做的好處是對梯度操作提供一種更有效率更穩定方式。

考慮一種情況
在這裏插入圖片描述
由於數值不穩定性,x=100處的梯度(f=fxi\bigtriangledown f=\frac{\partial f}{\partial x}\vec i)由函數得到的值爲NanNan

在這裏插入圖片描述

解決方法

使用@custom_gradient,梯度表達式可以被解析簡化,以提供數值穩定性
在這裏插入圖片描述
可以推斷@tf.custom_gradient的
  args爲xx,
  returns爲y,yxy,\frac{\partial y}{\partial x}的函數形式
一方面調用y=log1exp(x),可以得到y=y
另一方面調用grady=gradient(y,x),可以得到grady=yx\frac{\partial y}{\partial x}

於是對於二階導
只需要定義一階導的嵌套形式,使用@custom_gradient修飾一階導並使其返回y對x的一階導以及y對x二階導對應的函數

代碼如下
@tf.custom_gradient
def log1pexp2(x):
    e = tf.exp(x)
    y = tf.math.log(1 + e)
    x_grad = 1 - 1 / (1 + e)
    def first_order_gradient(dy):
        @tf.custom_gradient
        def first_order_custom(unused_x):
            def second_order_gradient(ddy):
                # Let's define the second-order gradient to be (1 - e)
                return ddy * (1 - e) 
            return x_grad, second_order_gradient
        return dy * first_order_custom(x)
    return y, first_order_gradient

以上二階導不是真實的二階導(爲了便於檢測)

測試代碼如下
import tensorflow as tf

@tf.custom_gradient
def log1pexp2(x):
    e = tf.exp(x)
    y = tf.math.log(1 + e)
    x_grad = 1 - 1 / (1 + e)
    def first_order_gradient(dy):
        @tf.custom_gradient
        def first_order_custom(unused_x):
            def second_order_gradient(ddy):
                # Let's define the second-order graidne to be (1 - e)
                return ddy * (1 - e) 
            return x_grad, second_order_gradient
        return dy * first_order_custom(x)
    return y, first_order_gradient

x1 = tf.constant(1.)
y1 = log1pexp2(x1)
dy1 = tf.gradients(y1, x1)
ddy1 = tf.gradients(dy1, x1)

x2 = tf.constant(100.)
y2 = log1pexp2(x2)
dy2 = tf.gradients(y2, x2)
ddy2 = tf.gradients(dy2, x2)

with tf.Session() as sess:
    print('x=1, dy1:', dy1[0].eval(session=sess))
    print('x=1, ddy1:', ddy1[0].eval(session=sess))
    print('x=100, dy2:', dy2[0].eval(session=sess))
    print('x=100, ddy2:', ddy2[0].eval(session=sess))
運行結果
x=1, dy1: 0.7310586
x=1, ddy1: -1.7182817
x=100, dy2: 1.0
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章