# MNIST handwritten digit autoencoder + classification experiment

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import matplotlib.pyplot as plt
import torchvision
class AutoEncodeNet(nn.Module):
    """Autoencoder for flattened 28x28 MNIST images with an auxiliary classifier.

    The encoder compresses a 784-value image to a 3-dimensional code
    (small enough for 3-D visualization); the decoder reconstructs the
    image from that code, and a classifier head predicts the digit class
    from the same 3-D code.
    """

    def __init__(self):
        super(AutoEncodeNet, self).__init__()
        # Encoder: 784 -> 128 -> 64 -> 12 -> 3
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),   # compress to 3 features for 3-D visualization
        )
        # Decoder: 3 -> 12 -> 64 -> 128 -> 784
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),       # keep reconstructed pixels in (0, 1)
        )
        # Classifier head on the 3-D code.
        # NOTE: the attribute keeps the original (misspelled) name
        # "classfier" so previously pickled models still load.
        # BUG FIX: the original head ended with nn.Sigmoid(), but this
        # network is trained with nn.CrossEntropyLoss, which expects raw
        # logits (it applies log-softmax internally) — the Sigmoid is
        # removed.
        self.classfier = nn.Sequential(
            nn.Linear(3, 128),
            nn.Tanh(),
            nn.Linear(128, 10),  # raw logits for the 10 digit classes
        )

    def forward(self, x):
        """Return (encoded, decoded, logits) for a (batch, 784) input."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        logits = self.classfier(encoded)
        return encoded, decoded, logits

def train():
    """Train the autoencoder + classifier on MNIST and visualize progress.

    Side effects: reads (or downloads) MNIST under ./mnist/, opens an
    interactive matplotlib window, and pickles the whole model to
    autoencoder_3<epoch>.pkl after every epoch.
    """
    # Hyper-parameters
    EPOCH = 20
    BATCH_SIZE = 64
    LR = 0.005
    DOWNLOAD_MNIST = False   # set to True on the first run to fetch the data
    N_TEST_IMG = 5           # number of sample images shown while training

    # MNIST digits dataset
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,                                     # training split
        transform=torchvision.transforms.ToTensor(),    # (C,H,W) floats in [0, 1]
        download=DOWNLOAD_MNIST,
    )
    autoencoder = AutoEncodeNet()
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
    loss_func = nn.MSELoss()             # reconstruction loss
    loss_func1 = nn.CrossEntropyLoss()   # classification loss
    # BUG FIX: the loader previously hard-coded batch_size=128, silently
    # ignoring the BATCH_SIZE hyper-parameter declared above.
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=BATCH_SIZE, shuffle=True)

    fig, ax = plt.subplots(2, N_TEST_IMG)
    plt.ion()   # interactive mode: update the figure while training
    # Top row: the fixed original images used for visual comparison
    # (previously hard-coded 5 instead of N_TEST_IMG).
    testImg = train_data.data[:N_TEST_IMG].view(-1, 28, 28).type(torch.FloatTensor) / 255.
    for i in range(N_TEST_IMG):
        ax[0][i].imshow(testImg[i])

    for epoch in range(EPOCH):
        for step, (x, b_label) in enumerate(train_loader):
            b_x = x.view(-1, 28 * 28)   # input, shape (batch, 784)
            b_y = x.view(-1, 28 * 28)   # reconstruction target (same pixels)

            encoded, decoded, logits = autoencoder(b_x)

            # Joint objective: pixel reconstruction + digit classification.
            loss = loss_func(decoded, b_y) + loss_func1(logits, b_label)
            optimizer.zero_grad()               # clear gradients for this step
            loss.backward()                     # backpropagation
            optimizer.step()                    # apply gradients

            # Bottom row: current reconstructions of the fixed samples.
            # no_grad: visualization must not build an autograd graph.
            with torch.no_grad():
                en, de, ll = autoencoder(testImg.view(-1, 28 * 28))
            dded = de.view(-1, 28, 28)
            for i in range(N_TEST_IMG):
                ax[1][i].clear()
                ax[1][i].imshow(dded[i].data.numpy())
                lll = list(ll.data[i])
                # argmax over the 10 logits = predicted digit
                print(lll.index(max(lll)), end=" , ")
            print("------")

            plt.draw()
            plt.pause(0.01)

        print(loss)
        # NOTE: this pickles the whole module; torch.load of such a file
        # executes arbitrary pickle code — only load files you trust.
        torch.save(autoencoder, "autoencoder_3" + str(epoch) + ".pkl")
    # BUG FIX: interactive mode must be turned off BEFORE the final
    # blocking show(); the original called them in the opposite order.
    plt.ioff()
    plt.show()
        
    
    
def test():
    """Load a saved model, pick a random training image, and display the
    original next to its reconstruction.

    Side effects: reads ./mnist/ and a pickled model file, opens a
    matplotlib window.
    """
    DOWNLOAD_MNIST = False
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,                                     # training split
        transform=torchvision.transforms.ToTensor(),    # (C,H,W) floats in [0, 1]
        download=DOWNLOAD_MNIST,
    )

    # SECURITY NOTE: torch.load unpickles arbitrary code — only load model
    # files you created yourself.
    # NOTE(review): train() saves files named "autoencoder_3<epoch>.pkl";
    # confirm "autoencoder_8.pkl" matches an actually saved checkpoint.
    net = torch.load("autoencoder_8.pkl")
    # BUG FIX: randint's upper bound is inclusive and valid indices run
    # 0..len-1, so the original randint(0, 60000) could index one past
    # the end of the 60000-item dataset.
    index = random.randint(0, len(train_data) - 1)
    img = train_data[index][0].view(-1, 28 * 28)
    print(img.shape)
    # BUG FIX: forward returns three values (encoded, decoded, logits);
    # the original two-value unpacking raised ValueError at runtime.
    en, de, _ = net(img)
    fig, (ax1, ax2) = plt.subplots(1, 2)
    print(en)
    ax1.imshow(img.view(28, 28).data.numpy())
    ax2.imshow(de.view(28, 28).data.numpy())
    plt.show()
if __name__ == "__main__":
    # Entry point: train by default; uncomment test() (and comment out
    # train()) to visualize a previously saved checkpoint instead.
    # test()
    train()

 

# --- Blog-page scrape residue below, kept as comments so the file parses ---
# "Post a comment / All comments / No one has commented yet — be the first!
#  Enter a comment above and click publish. / Related articles"