import torch
torch.__version__
class RNN(object):
def __init__(self,input_size,hidden_size):
super().__init__()
self.W_xh=torch.nn.Linear(input_size,hidden_size) #因爲最後的操作是相加 所以hidden要和output 的shape一致
self.W_hh=torch.nn.Linear(hidden_size,hidden_size)
def __call__(self,x,hidden):
return self.step(x,hidden)
def step(self, x, hidden):
#前向傳播的一步
h1=self.W_hh(hidden)
w1=self.W_xh(x)
out = torch.tanh( h1+w1)
hidden=self.W_hh.weight
return out,hidden
rnn = RNN(20,50)
input = torch.randn( 32 , 20)
h_0 =torch.randn(32 , 50) :
for i in range(seq_len):
output,hn= rnn(input[i, :], h_0)
print(output.size(),h_0.size())
seq_len = input.shape[0]
用 pytorch 實現 一個rnn
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.