MNIST數據集--學習筆記

前言

昨天開始接觸卷積神經網絡,copy了一個網絡,其中用的數據集是MNIST數據集,我對於此很陌生,所以先查找了MNIST的相關資料,其中CSDN中有位博主寫的特別詳細,所以這邊也參考他的博客並加入自己的理解,這篇博客更多的是作爲個人的一個學習筆記。
參考博客:https://blog.csdn.net/simple_the_best/article/details/75267863

下載數據集

MNIST 數據集已經是一個被”嚼爛”了的數據集, 很多教程都會對它”下手”, 幾乎成爲一個 “典範”. 不過有些人可能對它還不是很瞭解, 下面來介紹一下.

MNIST 數據集可在 http://yann.lecun.com/exdb/mnist/ 獲取。嗯,把官方說明看了幾遍,大概對MNIST數據集有了個瞭解。
在這裏插入圖片描述

MNIST 數據集來自美國國家標準與技術研究所, National Institute of Standards and Technology (NIST). 訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員. 測試集(test set) 也是同樣比例的手寫數字數據.
可以新建一個文件夾 – mnist, 將數據集下載到 mnist 以後, 解壓即可。

還有另一種下載方法,可以通過pytorch庫安裝相應的數據集MNSIT

# Mnist digits dataset
train_data = torchvision.datasets.MNIST(
    root='./mnist/',								# 下載到的地址
    train=True,                                     # 這裏表明是訓練數據
    download=True   				 			    # 設置download爲true,表示要下載該數據集

讀取到NumPy array 中

大家可以去看一下官方說明,其實和用匯編語言顯示字符類似,字符的前景色、背景色、閃爍等這些都記錄在屬性字節中。而這裏,下載的數據集圖片也是以字節的形式進行存儲, 所以我們需要把它們讀取到 NumPy array 中, 以便訓練和測試算法。

import os
import struct
import numpy as np

def load_mnist(path, kind='train'):
    """從地址 `path`中加載MNIST數據集"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte'
                               % kind)
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II',
                                 lbpath.read(8))
        labels = np.fromfile(lbpath,
                             dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII',
                                               imgpath.read(16))
        images = np.fromfile(imgpath,
                             dtype=np.uint8).reshape(len(labels), 784)

    return images, labels

load_mnist 函數返回兩個數組, 第一個是一個 n x m 維的 NumPy array(images), 這裏的 n 是樣本數(行數), m 是特徵數(列數). 訓練數據集包含 60,000 個樣本, 測試數據集包含 10,000 樣本. 在 MNIST 數據集中的每張圖片由 28 x 28 個像素點構成, 每個像素點用一個灰度值表示. 在這裏, 我們將 28 x 28 的像素展開爲一個一維的行向量, 這些行向量就是圖片數組裏的行(每行 784 個值, 或者說每行就是代表了一張圖片). load_mnist 函數返回的第二個數組(labels) 包含了相應的目標變量, 也就是手寫數字的類標籤(整數 0-9)。
這裏可能會對讀取該數據集的方式有疑惑,

magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)

爲了理解這兩行代碼, 可以看一下 MNIST 網站上對數據集的介紹:
在這裏插入圖片描述
通過使用上面兩行代碼, 首先讀入 magic number, 它是一個文件協議的描述, 也是在調用 fromfile 方法將字節讀入 NumPy array 之前在文件緩衝中的 item 數(n). 作爲參數值傳入 struct.unpack>II 有兩個部分:

  • >:這是指大端(用來定義字節是如何存儲的)
  • I:這是指一個無符號整數.
    (這裏不太理解,先mark,參考博主的建議《深入理解計算機系統 – 2.1 節信息存儲》,再學習這塊知識)

可視化處理數據集MNIST

通過執行下面的代碼, 程序將會從剛剛解壓 MNIST 數據集後的 mnist 目錄下加載 60,000 個訓練樣本和 10,000 個測試樣本.

爲了瞭解 MNIST 中的圖片看起來到底是個啥, 對其進行可視化處理。從 feature matrix 中將 784-像素值 的向量 reshape 爲之前的 28*28 的形狀, 然後通過 matplotlib 的 imshow 函數進行繪製:

# refer to https://blog.csdn.net/simple_the_best/article/details/75267863
import matplotlib.pyplot as plt
import os
import struct
import numpy as np

def load_mnist(path, kind='train'):
    """從地址 `path`中加載MNIST數據集"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte'
                               % kind)
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II',
                                 lbpath.read(8))
        labels = np.fromfile(lbpath,
                             dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII',
                                               imgpath.read(16))
        images = np.fromfile(imgpath,
                             dtype=np.uint8).reshape(len(labels), 784)

    return images, labels

X_train,y_train = load_mnist('mnist/MNIST/raw/') # 調用load_mnist函數,這裏用的是相對地址
fig, ax = plt.subplots(
    nrows=2,
    ncols=5,
    sharex=True,
    sharey=True, )

ax = ax.flatten()
for i in range(10):
    img = X_train[y_train == i][0].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

輸出

現在應該可以看到一個 2*5 的圖片, 裏面分別是 0-9 單個數字的圖片.

在這裏插入圖片描述

顯示不同樣本的數字

fig, ax = plt.subplots(
    nrows=5,
    ncols=5,
    sharex=True,
    sharey=True, )

ax = ax.flatten()
for i in range(25):
    img = X_train[y_train == 5][i].reshape(28, 28) 	# 可以顯示不同樣本圖片‘5’
    ax[i].imshow(img, cmap='Greys', interpolation='neare不st')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

還可以繪製某一數字的多個樣本圖片, 來看一下這些手寫樣本到底有多不同:
在這裏插入圖片描述在這裏插入圖片描述
另外, 該博主還有下載CSV版本的MNIST數據集,因爲我還沒實踐,所以在這就先不寫了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章