pytorch筆記7--批訓練

import torch
import torch.utils.data as Data    #用於小批訓練
torch.manual_seed(1)   #爲cpu設置隨機種子,使多次運行結果一致
# torch.cuda.manual_seed(seed)  #爲當前GPU設置隨機種子
#torch.cuda.manual_seed_all(seed)  #爲所有GPU設置隨機種子

Batch_Size = 4
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

#用DataLoader來包裝數據,用於批訓練(首先將數據轉換爲torch能識別的Dataset形式)
torch_dataset=Data.TensorDataset(x,y)

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=Batch_Size,
    shuffle=True,   #是否打亂數據
)

for epoch in range(3): #將所有數據訓練3次
    for step,(batch_x,batch_y) in enumerate(loader): #每一步loader釋放一小批數據

        ...
        print('Epoch:{}, Step:{}, batch x:{}, batch y:{}'.format(epoch,step,batch_x.numpy(),batch_y.numpy()))

結果:

Epoch:0, Step:0, batch x:[ 5.  7. 10.  3.], batch y:[6. 4. 1. 8.]
Epoch:0, Step:1, batch x:[4. 2. 1. 8.], batch y:[ 7.  9. 10.  3.]
Epoch:0, Step:2, batch x:[9. 6.], batch y:[2. 5.]
Epoch:1, Step:0, batch x:[ 4.  6.  7. 10.], batch y:[7. 5. 4. 1.]
Epoch:1, Step:1, batch x:[8. 5. 3. 2.], batch y:[3. 6. 8. 9.]
Epoch:1, Step:2, batch x:[1. 9.], batch y:[10.  2.]
Epoch:2, Step:0, batch x:[4. 2. 5. 6.], batch y:[7. 9. 6. 5.]
Epoch:2, Step:1, batch x:[10.  3.  9.  1.], batch y:[ 1.  8.  2. 10.]
Epoch:2, Step:2, batch x:[8. 7.], batch y:[3. 4.]

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章