pytorch maxout實現

簡述

看了半天,在網上沒有看到pytorch關於maxout的實現。(雖然看到其他的模型也是可以用的,但是爲了更好的復現論文,這裏還是打算實現下)。

(不一定保證完全正確,估計很快pytorch就會自己更新,對應的maxout激活函數了吧?我看到github上好像有對應的issue了都)

maxout的原理也很簡單:簡單來說,就是多個線性函數的組合。然後在每個定義域上都取數值最大的那個線性函數,看起來就是折很多次的折線。(初中數學emmm)

實現

from torch.nn import init
import torch.nn.functional as F
from torch._jit_internal import weak_module, weak_script_method
from torch.nn.parameter import Parameter
import math


@weak_module
class Maxout(nn.Module):
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, pieces, bias=True):
        super(Maxout, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.pieces = pieces
        self.weight = Parameter(torch.Tensor(pieces, out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(pieces, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        output = input.matmul(self.weight.permute(0, 2, 1)).permute((1, 0, 2)) + self.bias
        output = torch.max(output, dim=1)[0]
        return output

如果也喜歡研究源碼的小夥伴就會發現了,我就是在原來的Linear()的源碼基礎上多改進了一個維度而已。

技巧還是在那個維度切換那裏,其他都沒啥,用自己這個試了下,效果還行(不虧是我,叉腰.jpg)

調用的方式也很簡單,就是平常寫的那些nn.Linear() 的方式很像。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章