深度學習筆記--Fashion_mnist+softmax的簡單實現

1. 主文件

import torch
import All_function as func
import torch.nn as nn

# Load the data: mini-batch iterators over Fashion-MNIST.
batch_size=256
# Pass the variable instead of repeating the literal 256, so changing
# batch_size in one place actually takes effect.
train_iter,test_iter=func.load_data_fashion_minist(batch_size)

#定義和初始化模型
class LinearNet(nn.Module):
    """Softmax-regression network: a single fully-connected layer.

    Each batch from the data loader has shape (batch_size, 1, 28, 28);
    it is flattened to (batch_size, 784) before entering the linear layer.
    """

    def __init__(self, inputs, outputs):
        super(LinearNet, self).__init__()
        self.linear = nn.Linear(inputs, outputs)

    def forward(self, x):
        # Keep the batch dimension, collapse everything else into one axis.
        flattened = x.view(x.shape[0], -1)
        return self.linear(flattened)
    
# Model dimensions: 28*28 = 784 flattened pixels in, 10 clothing classes out.
num_inputs,num_outputs=784,10
net=LinearNet(num_inputs,num_outputs)

# Train the model: softmax regression = linear layer + cross-entropy loss.
epochs=5
loss_func=nn.CrossEntropyLoss()   # combines log-softmax and negative log-likelihood
optimizer=torch.optim.SGD(net.parameters(),lr=0.01)
func.train_model(net,train_iter,test_iter,loss_func,epochs,optimizer)

# Show a few test-set predictions next to the true labels.
# Bugfix: Python 3 iterators have no .next() method; use the built-in next().
X,y=next(iter(test_iter))
true_labels=func.get_fashion_mnist_labels(y.numpy())
pred_labels=func.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
# Title format: true label on the first line, predicted label on the second.
titles=[true+'\n'+pred for true,pred in zip(true_labels,pred_labels)]
func.show_fashion_mnist(X[0:10],titles[0:10])

2. 封裝函數的py文件:All_function

import torch
import torchvision
import torchvision.transforms as transforms
import sys
from IPython import display
import matplotlib.pyplot as plt

#標籤轉換,因爲數據集裏的標籤是0-9的數字
def get_fashion_mnist_labels(labels):
    """Map numeric Fashion-MNIST class indices (0-9) to their text names."""
    names = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
             'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
    result = []
    for idx in labels:
        result.append(names[int(idx)])
    return result

#獲取數據
def load_data_fashion_minist(batch_size,
                             root='D:/Tasks/DeepLearningInAction_pytorch/DraftNumberDetection/Datasets/FashionMNIST',
                             download=False,
                             num_workers=None):
    """Return (train_iter, test_iter) DataLoaders over Fashion-MNIST.

    Args:
        batch_size: number of samples per mini-batch.
        root: dataset directory (default keeps the original hard-coded path
              for backward compatibility; pass your own path elsewhere).
        download: set True to fetch the dataset if it is not present at root.
        num_workers: loader worker processes; default picks 0 on Windows
              (extra workers are problematic there) and 4 otherwise.
    """
    transform = transforms.ToTensor()
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=download, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=download, transform=transform)
    if num_workers is None:
        # 0 means the data is read in the main process, no worker subprocesses.
        num_workers = 0 if sys.platform.startswith('win') else 4
    # shuffle=True reshuffles the training data at the start of every epoch.
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter

#測試精度
def evaluate_accuarcy(data_iter, net):
    """Return classification accuracy of `net` over all batches in `data_iter`.

    Args:
        data_iter: iterable of (X, y) batches.
        net: callable mapping X to per-class scores of shape (batch, classes).

    Returns:
        float in [0, 1]: fraction of samples whose argmax prediction equals y.
    """
    acc_sum, n = 0.0, 0
    # Evaluation needs no gradients; no_grad avoids building the autograd
    # graph (the original version silently tracked gradients here).
    with torch.no_grad():
        for X, y in data_iter:
            # .item() converts the 0-dim tensor sum to a Python number.
            acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n
#訓練模型
def train_model(net, train_iter, test_iter, loss_func, num_epochs, optimizer=None):
    """Train `net` with SGD-style updates and print per-epoch metrics.

    Args:
        net: the model to train.
        train_iter / test_iter: iterables of (X, y) batches.
        loss_func: criterion returning the batch-mean loss (e.g. CrossEntropyLoss).
        num_epochs: number of passes over the training data.
        optimizer: a torch optimizer; required (the None default is kept only
            for signature compatibility).

    Raises:
        ValueError: if no optimizer is supplied (the original code crashed
            later with an opaque AttributeError on `None.zero_grad()`).
    """
    if optimizer is None:
        raise ValueError('an optimizer is required to train the model')
    print('epoch is %d' % num_epochs)
    for epoch in range(num_epochs):
        train_loss_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            # CrossEntropyLoss already reduces to the batch mean; the old
            # trailing .sum() on a scalar was a no-op and is dropped.
            loss = loss_func(y_hat, y)

            optimizer.zero_grad()
            loss.backward()          # back-propagate the error
            optimizer.step()

            # Weight the batch-mean loss by the batch size so the epoch
            # average is a true per-sample loss. (The original divided a
            # sum of batch MEANS by the SAMPLE count, understating the
            # loss by roughly a factor of batch_size.)
            train_loss_sum += loss.item() * y.shape[0]
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]

        test_acc = evaluate_accuarcy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_loss_sum / n, train_acc_sum / n, test_acc))

#圖片顯示
def use_svg_display():
    # SVG: Scalable Vector Graphics.
    """Use svg format to display plot in jupyter.

    NOTE(review): IPython's display.set_matplotlib_formats was deprecated in
    IPython 7.23 (the replacement lives in matplotlib_inline.backend_inline) —
    confirm the installed IPython version still provides it.
    """
    display.set_matplotlib_formats('svg')   # render figures as vector graphics

def show_fashion_mnist(images, labels):   # draw several images with their labels in one row
    """Plot each 28x28 image in a single row, titled with its label."""
    use_svg_display()
    # The first return value (the Figure) is not needed here, only the axes.
    _, axes = plt.subplots(1, len(images), figsize=(15, 15))
    for ax, img, lbl in zip(axes, images, labels):
        ax.imshow(img.view((28, 28)).numpy())
        ax.set_title(lbl)
        # Hide both coordinate axes for a cleaner gallery look.
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
    plt.show()

運行結果:

epoch is 5
epoch 1, loss 0.0054, train acc 0.631, test acc 0.684
epoch 2, loss 0.0036, train acc 0.717, test acc 0.725
epoch 3, loss 0.0031, train acc 0.751, test acc 0.747
epoch 4, loss 0.0029, train acc 0.769, test acc 0.761
epoch 5, loss 0.0028, train acc 0.781, test acc 0.771

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章