tensorflow2 中關於自定義層的build() 和 call()一點探究

0x00 先上一段代碼

問題: 在自定義網絡層的時候,想要搞清楚build()call() 是用來做什麼的,爲什麼能調用成功,不用外部再定義


# coding=utf-8
'''
@ Summary: test call
@ Update:  

@ file:    test.py
@ version: 1.0.0

@ Author:  [email protected]
@ Date:    2020/6/11 下午3:48
'''
from __future__ import absolute_import, division, print_function
import tensorflow as tf
tf.keras.backend.clear_session()
import tensorflow.keras as keras
import tensorflow.keras.layers as layers

class MyLayer(layers.Layer):
   def __init__(self, unit=32):
       super(MyLayer, self).__init__()
       self.unit = unit

   def build(self, input_shape):
       self.weight = self.add_weight(shape=(input_shape[-1], self.unit),
                                     initializer=keras.initializers.RandomNormal(),
                                     trainable=True)
       self.bias = self.add_weight(shape=(self.unit,),
                                   initializer=keras.initializers.Zeros(),
                                   trainable=True)

   def call(self, inputs):
       return tf.matmul(inputs, self.weight) + self.bias

my_layer = MyLayer(3)
x = tf.ones((3,5))
out = my_layer(x)
print(out)

0x01 庖丁解牛1 - init

定義一個類對象


my_layer = MyLayer(3)

ok, 此處沒有任何問題

上面是僅調用了MyLayer() 類中的__init__() 方法,獲得了self.units = 3 這一個變量

此處尚未調用類中的 build()call() 方法



   def __init__(self, unit=32):

       # 繼承,此處不多說,有一個很有意思的是單繼承和多繼承
       super(MyLayer, self).__init__() 
       self.unit = unit

0x02 庖丁解牛2 – build()

初始化一個輸入對象:


x = tf.ones((3,5))

這一步也是沒有任何問題,繼續往下


out = my_layer(x)

這個地方,問題就來了。

回到最開始的問題:爲什麼不用外部調用就可以運行build()call()等函數?

回答:在Layer() 類中有一個__call__() 魔法方法(上述兩個函數已經被tf集成在該函數下面),會被自動調用,因此不用外部調用,具體怎麼個調用過程,請閱讀源碼

接下來就是對my_layer 輸入,輸入爲x

調用build() 方法:


   def build(self, input_shape):
       self.weight = self.add_weight(shape=(input_shape[-1], self.unit),
                                     initializer=keras.initializers.RandomNormal(),
                                     trainable=True)
       self.bias = self.add_weight(shape=(self.unit,),
                                   initializer=keras.initializers.Zeros(),
                                   trainable=True)

初始化兩個可訓練的值,分別是權重和偏置,ok,此部分問題解決了

順帶解決另外一個問題:爲什麼要有build() 方法

回答: build() 可自定義網絡的權重的維度,可以根據輸入來指定權重的維度,若權重固定,可避免使用build() 方法

另外一個需要注意的地方在於:self.built = True

該參數在build() 運行開始時爲False,爲了保證先調用build() 方法, 再調用call() 方法

結束時會自動賦值爲True,保證build() 方法只被調用一次


class MyLayer(layers.Layer):

    def __init__(self, input_dim=32, unit=32):

        super(MyLayer, self).__init__()

        self.weight = self.add_weight(shape=(input_dim, unit),

                                     initializer=keras.initializers.RandomNormal(),

                                     trainable=True)

        self.bias = self.add_weight(shape=(unit,),

                                   initializer=keras.initializers.Zeros(),

                                   trainable=True)

    

    def call(self, inputs):

        return tf.matmul(inputs, self.weight) + self.bias

0x03 庖丁解牛3 – call()

在調用完了build() 方法之後,獲取到了初始化的權重和偏置值,接下來進行正向傳播,官網上是說實現邏輯功能函數,我更喜歡說成是前者,更好理解


   def call(self, inputs):
       return tf.matmul(inputs, self.weight) + self.bias

返回該層的輸出值,不包含激活函數計算

0x04 最終輸出


print(out)

0x05 總結

Layer的子類一般實現如下:

  • init():super(), 初始化所有與輸入無關的變量

  • build():用於初始化層內的參數和變量

  • call():定義前向傳播

第一次訓練先計算Model(x), 然後計算Model(x).build(input),最後計算Model(x).call(input),第二次往後就跳過了中間步驟

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章