用 pytorch 實現 一個rnn

原文鏈接

import torch
torch.__version__

class RNN(object):
    def __init__(self,input_size,hidden_size):
        super().__init__()
        self.W_xh=torch.nn.Linear(input_size,hidden_size) #因爲最後的操作是相加 所以hidden要和output 的shape一致
        self.W_hh=torch.nn.Linear(hidden_size,hidden_size)
        
    def __call__(self,x,hidden):
        return self.step(x,hidden)
    def step(self, x, hidden):
        #前向傳播的一步
        h1=self.W_hh(hidden)
        w1=self.W_xh(x)
        out = torch.tanh( h1+w1)
        hidden=self.W_hh.weight
        return out,hidden
rnn = RNN(20,50)
input = torch.randn( 32 , 20)
h_0 =torch.randn(32 , 50) :

for i in range(seq_len):
    output,hn= rnn(input[i, :], h_0)
print(output.size(),h_0.size())
seq_len = input.shape[0]


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章