獲取smbol的內部節點(獲取部分網絡結構)

首先定義一個網絡

data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=1000)
act = mx.sym.Activation(data=fc1, act_type='relu',name='act')
fc2 = mx.sym.FullyConnected(data=act, name='fc2', num_hidden=10)
net = mx.sym.SoftmaxOutput(fc2,name="softmax")
net.save('model.symbol.json')

將該網絡通過save序列化保存爲json文件。 # mx.sym.Symbol 類型自帶save函數

載入json,獲取內部節點

net = mx.sym.load('model.symbol.json')
net.get_internals  #  獲取所有內部節點,將其group起來
<Symbol group [data, fc1_weight, fc1_bias, fc1, act, fc2_weight, fc2_bias, fc2, softmax_label, softmax]>
net.get_internals().list_outputs()  #  獲取所有內部節點的輸出
['data',
 'fc1_weight',
 'fc1_bias',
 'fc1_output',
 'act_output',
 'fc2_weight',
 'fc2_bias',
 'fc2_output',
 'softmax_label',
 'softmax_output']
new_net = net.get_internals()['act_output'] #  通過key value方式,獲取網絡的內部節點數據,後續可以接着對new_net進行操作,構建新的網絡

對於var,其list_outputs()list_inputs()都是他自己,對於其他op,比如+,-,relu等,其輸出要在該op的name後面加上_output後綴

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章