Pytorch學習(五)--實現神經網絡實例一

1.準備數據

# 導入pytorch內置的minst數據
from torchvision.datasets import mnist
# 導入預處理模塊
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 定義一些超參數
train_batch_size = 64
test_batch_size = 128
learning_rate = 0.01
num_epoches = 20
lr = 0.01
momentum = 0.5

# 下載數據並對數據進行預處理
# 定義預處理函數,這些預處理函數一次放在Compose 函數中
trasform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])
# 下載數據並對數據進行預處理
train_dataset = mnist.MNIST("./data",train=True,transform=trasform,download=False)
test_dataset = mnist.MNIST('./data',train=False,transform=trasform,download=False)
#dataloader是一個可迭代對象 可以使用迭代器一樣使用
train_loader = DataLoader(train_dataset,batch_size=train_batch_size,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=test_batch_size,shuffle=False)

說明:
1.transforms.Compose可以把一些轉換函數組合在一起
2.Normalize([0.5],[0.5])對張量進行歸一化,這裏兩個0.5分別表示對張量進行歸一化的均值和方差。因圖像是灰色的只有一個通道,如果有多個通道,需要有多個數字,如3個通道,應該是([0.5,0.5,0.5],[0.5,0.5,0.5])
3.download參數控制是否需要下載,如果./data下已有MINST,可選擇False

2.可視化數據源

#可視化數據源
import matplotlib.pyplot as plt

examples = enumerate(test_loader)
batch_idx,(example_data,example_targets)=next(examples)
print((example_data,example_targets))
print(type(example_data),example_data.shape)
fig = plt.figure()
for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout() # tight_layout會自動調整子圖參數
    # print(example_data[i].shape)
    plt.imshow(example_data[i].squeeze(),cmap='gray',interpolation = 'none') # example_data[i][0]
    # print(example_data[i][0].shape)
    # a = example_data[i].squeeze()
    # print(a.shape)
    plt.title("Ground Truth:{}".format(example_targets[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()

說明:

  1. enumerate()函數用於將一個可遍歷的數據對象組合未一個索引序列,同時列出數據和數據下標
  2. 因爲example_data的size=[128,1,28,28],而example_data[i] = [1,28,28],而imshow輸入灰度圖必須是二維格式,不能是(CHW)。
  3. 而[1,28,28]裏[28,28]只有一個,所以可用[i][0]取到
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章