參考文章:知乎-Pytorch筆記04-自定義torch.autograd.Function
import torch
from torch.autograd import Variable
class MyReLU(torch.autograd.Function):
def forward(self, input_):
# 在forward中,需要定義MyReLU這個運算的forward計算過程
# 同時可以保存任何在後向傳播中需要使用的變量值
self.save_for_backward(input_) # 將輸入保存起來,在backward時使用
output = input_.clamp(min=0) # relu就是截斷負數,讓所有負數等於0
return output
def backward(self, grad_output):
# 根據BP算法的推導(鏈式法則),dloss / dx = (dloss / doutput) * (doutput / dx)
# dloss / doutput就是輸入的參數grad_output、
# 因此只需求relu的導數,在乘以grad_outpu
input_, = self.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0 # 上訴計算的結果就是左式。即ReLU在反向傳播中可以看做一個通道選擇函數,所有未達到閾值(激活值<0)的單元的梯度都爲0
return grad_input
我給上面代碼加了一個形象的圖:
vsd文件備份