PyTorch Learning Notes, Hands-On Basics: Playing FizzBuzz with PyTorch (Part 2)

Environment

from __future__ import print_function
import torch
torch.__version__
'1.4.0'

FizzBuzz

FizzBuzz is a simple game. The rules: count upward from 1; when you reach a multiple of 3, say "fizz"; a multiple of 5, say "buzz"; a multiple of 15, say "fizzbuzz"; otherwise just say the number itself.

We can write a small program that decides whether to return the plain number, "fizz", "buzz", or "fizzbuzz".

# Encode the desired outputs as class indices: [number, "fizz", "buzz", "fizzbuzz"]
def fizz_buzz_encode(i):
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

for i in range(1, 16):
    print(fizz_buzz_decode(i, fizz_buzz_encode(i)))

Output

1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz

Defining the model's inputs and outputs (training data)

import numpy as np
import torch

NUM_DIGITS = 10

# Represent each input by an array of its binary digits:
# extract each bit with shift-and-mask, then reverse so the
# most significant bit comes first
def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)][::-1])

# Build the full dataset: binary inputs and FizzBuzz class labels
all_data_x = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 2 ** NUM_DIGITS)])
all_data_y = torch.LongTensor([fizz_buzz_encode(i) for i in range(1, 2 ** NUM_DIGITS)])
if torch.cuda.is_available():
    all_data_x = all_data_x.cuda()
    all_data_y = all_data_y.cuda()

# Train on the numbers 101-1023, hold out 1-100 for testing
trX = all_data_x[100:]   # 923 x 10
trY = all_data_y[100:]   # 923
testX = all_data_x[:100] # 100 x 10
testY = all_data_y[:100] # 100
print(trX[0], trX.shape)
print(testX[0], testX.shape)
print(testY[0], testY.shape)
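
As a quick sanity check of the encoding (the expected values in the comments follow directly from the shift-and-mask logic above):

print(binary_encode(3, NUM_DIGITS))    # [0 0 0 0 0 0 0 0 1 1]  (3 = 0b11, MSB first)
print(binary_encode(100, NUM_DIGITS))  # [0 0 0 1 1 0 0 1 0 0]  (100 = 0b1100100)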

Defining the model with PyTorch

# Define the model
NUM_HIDDEN = 100
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)
)
if torch.cuda.is_available():
    model = model.cuda()

• To teach our model the FizzBuzz game, we need to define a loss function and an optimization algorithm.
• The optimizer repeatedly reduces the loss, driving the model toward as low a loss value as possible on this task.
• A low loss value usually means the model is doing well; a high loss value means it is doing poorly.
• Since FizzBuzz is essentially a classification problem, we use the Cross Entropy loss.
• For the optimizer we use Stochastic Gradient Descent.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.05)
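
To make the loss function's interface concrete: CrossEntropyLoss takes raw, unnormalized logits plus integer class labels, and applies log-softmax internally. A minimal toy check (the tensor values here are made up for illustration):

logits = torch.tensor([[3.0, 0.5, 0.1, 0.1]])  # raw scores for one sample, 4 classes
target = torch.tensor([0])                     # the true class index
print(torch.nn.CrossEntropyLoss()(logits, target))  # small loss: class 0 dominates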

Training the model

BATCH_SIZE = 128
for epoch in range(10001):
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]
        # Forward pass: compute predictions and the loss
        y_pred = model(batchX)
        loss = loss_fn(y_pred, batchY)
        # Backward pass: clear old gradients, backpropagate, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch % 1000 == 0:
        loss = loss_fn(model(trX), trY).item()
        print("Epoch {} Loss {}".format(epoch, loss))

Training log

Epoch 0 Loss 1.5171986818313599
Epoch 1000 Loss 0.033820319920778275
Epoch 2000 Loss 0.018580008298158646
Epoch 3000 Loss 0.012838841415941715
Epoch 4000 Loss 0.009773609228432178
Epoch 5000 Loss 0.007854906842112541
Epoch 6000 Loss 0.006554984021931887
Epoch 7000 Loss 0.005604262929409742
Epoch 8000 Loss 0.004887698218226433
Epoch 9000 Loss 0.004323565401136875
Epoch 10000 Loss 0.0038704578764736652

Finally, we use the trained model to play FizzBuzz on the numbers 1 to 100.

with torch.no_grad():
    resultY = model(testX)
# max(1)[1] takes the argmax over the 4 classes, i.e. the predicted label
predictions = zip(range(1, 101), resultY.max(1)[1].data.tolist())
print([(i, fizz_buzz_decode(i, x)) for (i, x) in predictions])

Predictions

[(1, '1'), (2, '2'), (3, 'fizz'), (4, '4'), (5, 'buzz'), (6, 'fizz'), (7, '7'), (8, '8'), (9, 'fizz'), (10, 'buzz'), (11, '11'), (12, 'fizz'), (13, '13'), (14, '14'), (15, 'fizzbuzz'), (16, '16'), (17, '17'), (18, 'fizz'), (19, '19'), (20, 'buzz'), (21, 'fizz'), (22, '22'), (23, '23'), (24, 'fizz'), (25, 'buzz'), (26, '26'), (27, 'fizz'), (28, '28'), (29, '29'), (30, 'fizzbuzz'), (31, '31'), (32, '32'), (33, 'fizz'), (34, '34'), (35, 'buzz'), (36, 'fizz'), (37, '37'), (38, '38'), (39, 'fizz'), (40, 'buzz'), (41, '41'), (42, 'fizz'), (43, '43'), (44, '44'), (45, 'fizzbuzz'), (46, '46'), (47, '47'), (48, 'fizz'), (49, '49'), (50, 'buzz'), (51, 'fizz'), (52, '52'), (53, '53'), (54, 'fizz'), (55, 'buzz'), (56, '56'), (57, 'fizz'), (58, '58'), (59, '59'), (60, 'fizzbuzz'), (61, '61'), (62, '62'), (63, 'fizz'), (64, '64'), (65, '65'), (66, 'fizz'), (67, '67'), (68, '68'), (69, '69'), (70, 'buzz'), (71, '71'), (72, 'fizz'), (73, '73'), (74, '74'), (75, 'fizzbuzz'), (76, '76'), (77, '77'), (78, 'fizz'), (79, '79'), (80, 'buzz'), (81, 'fizz'), (82, '82'), (83, '83'), (84, '84'), (85, 'buzz'), (86, '86'), (87, 'fizz'), (88, '88'), (89, '89'), (90, 'fizzbuzz'), (91, '91'), (92, '92'), (93, '93'), (94, '94'), (95, 'buzz'), (96, 'fizz'), (97, '97'), (98, '98'), (99, 'fizz'), (100, 'buzz')]

Checking the accuracy

# Fraction of the 100 test numbers classified correctly, then the per-number mask
print(np.sum(resultY.cpu().max(1)[1].numpy() == testY.cpu().numpy()) / len(testY))
print(resultY.cpu().max(1)[1].numpy() == testY.cpu().numpy())
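
The same accuracy can be computed without leaving PyTorch, which avoids the CPU round-trips (an equivalent one-liner, not from the original post):

acc = (resultY.max(1)[1] == testY).float().mean().item()  # mean of the correctness mask
print(acc)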

Saving the model parameters

torch.save(model.state_dict(), 'params1.pkl')
# Load the model parameters back
# model.load_state_dict(torch.load('params1.pkl'))
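
If the parameters are later reloaded on a machine without a GPU, map_location prevents a device mismatch (a sketch, assuming the same NUM_DIGITS/NUM_HIDDEN architecture has been rebuilt first):

model.load_state_dict(torch.load('params1.pkl', map_location='cpu'))
model.eval()  # switch to inference mode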