[ pytorch ] —— 基本使用:(9) 自定義反向傳播

參考文章:知乎-Pytorch筆記04-自定義torch.autograd.Function

import torch
from torch.autograd import Variable

class MyReLU(torch.autograd.Function):

    def forward(self, input_):
        # 在forward中,需要定義MyReLU這個運算的forward計算過程
        # 同時可以保存任何在後向傳播中需要使用的變量值
        self.save_for_backward(input_)         # 將輸入保存起來,在backward時使用
        output = input_.clamp(min=0)               # relu就是截斷負數,讓所有負數等於0
        return output

    def backward(self, grad_output):
        # 根據BP算法的推導(鏈式法則),dloss / dx = (dloss / doutput) * (doutput / dx)
        # dloss / doutput就是輸入的參數grad_output、
        # 因此只需求relu的導數,在乘以grad_outpu    
        input_, = self.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0                # 上訴計算的結果就是左式。即ReLU在反向傳播中可以看做一個通道選擇函數,所有未達到閾值(激活值<0)的單元的梯度都爲0
        return grad_input

我給上面代碼加了一個形象的圖:
在這裏插入圖片描述
vsd文件備份

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章