0x00 先上一段代碼
問題: 在自定義網絡層的時候,想要搞清楚
build()
和call()
是用來做什麼的,爲什麼能調用成功,不用外部再定義
# coding=utf-8
'''
@ Summary: test call
@ Update:
@ file: test.py
@ version: 1.0.0
@ Author: [email protected]
@ Date: 2020/6/11 下午3:48
'''
from __future__ import absolute_import, division, print_function
import tensorflow as tf
tf.keras.backend.clear_session()
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
class MyLayer(layers.Layer):
def __init__(self, unit=32):
super(MyLayer, self).__init__()
self.unit = unit
def build(self, input_shape):
self.weight = self.add_weight(shape=(input_shape[-1], self.unit),
initializer=keras.initializers.RandomNormal(),
trainable=True)
self.bias = self.add_weight(shape=(self.unit,),
initializer=keras.initializers.Zeros(),
trainable=True)
def call(self, inputs):
return tf.matmul(inputs, self.weight) + self.bias
my_layer = MyLayer(3)
x = tf.ones((3,5))
out = my_layer(x)
print(out)
0x01 庖丁解牛1 - init
定義一個類對象
my_layer = MyLayer(3)
ok, 此處沒有任何問題
上面是僅調用了MyLayer()
類中的__init__()
方法,獲得了self.units = 3
這一個變量
此處尚未調用類中的 build()
和 call()
方法
def __init__(self, unit=32):
# 繼承,此處不多說,有一個很有意思的是單繼承和多繼承
super(MyLayer, self).__init__()
self.unit = unit
0x02 庖丁解牛2 – build()
初始化一個輸入對象:
x = tf.ones((3,5))
這一步也是沒有任何問題,繼續往下
out = my_layer(x)
這個地方,問題就來了。
回到最開始的問題:爲什麼不用外部調用就可以運行
build()
、call()
等函數?
回答:在Layer()
類中有一個__call__()
魔法方法(上述兩個函數已經被tf
集成在該函數下面),會被自動調用,因此不用外部調用,具體怎麼個調用過程,請閱讀源碼
接下來就是對my_layer
輸入,輸入爲x
調用build()
方法:
def build(self, input_shape):
self.weight = self.add_weight(shape=(input_shape[-1], self.unit),
initializer=keras.initializers.RandomNormal(),
trainable=True)
self.bias = self.add_weight(shape=(self.unit,),
initializer=keras.initializers.Zeros(),
trainable=True)
初始化兩個可訓練的值,分別是權重和偏置,ok,此部分問題解決了
順帶解決另外一個問題:爲什麼要有
build()
方法
回答: build()
可自定義網絡的權重的維度,可以根據輸入來指定權重的維度,若權重固定,可避免使用build()
方法
另外一個需要注意的地方在於:self.built = True
該參數在build()
運行開始時爲False,爲了保證先調用build()
方法, 再調用call()
方法
結束時會自動賦值爲True
,保證build()
方法只被調用一次
class MyLayer(layers.Layer):
def __init__(self, input_dim=32, unit=32):
super(MyLayer, self).__init__()
self.weight = self.add_weight(shape=(input_dim, unit),
initializer=keras.initializers.RandomNormal(),
trainable=True)
self.bias = self.add_weight(shape=(unit,),
initializer=keras.initializers.Zeros(),
trainable=True)
def call(self, inputs):
return tf.matmul(inputs, self.weight) + self.bias
0x03 庖丁解牛3 – call()
在調用完了build()
方法之後,獲取到了初始化的權重和偏置值,接下來進行正向傳播,官網上是說實現邏輯功能函數,我更喜歡說成是前者,更好理解
def call(self, inputs):
return tf.matmul(inputs, self.weight) + self.bias
返回該層的輸出值,不包含激活函數計算
0x04 最終輸出
print(out)
0x05 總結
Layer的子類一般實現如下:
-
init():super(), 初始化所有與輸入無關的變量
-
build():用於初始化層內的參數和變量
-
call():定義前向傳播
第一次訓練先計算Model(x)
, 然後計算Model(x).build(input)
,最後計算Model(x).call(input)
,第二次往後就跳過了中間步驟