MNIST 數據集簡介

轉自:https://blog.csdn.net/simple_the_best/article/details/75267863

MNIST 數據集已經是一個被”嚼爛”了的數據集, 很多教程都會對它”下手”, 幾乎成爲一個 “典範”. 不過有些人可能對它還不是很瞭解, 下面來介紹一下.

MNIST 數據集可在 http://yann.lecun.com/exdb/mnist/ 獲取, 它包含了四個部分:

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解壓後 47 MB, 包含 60,000 個樣本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解壓後 60 KB, 包含 60,000 個標籤)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解壓後 7.8 MB, 包含 10,000 個樣本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解壓後 10 KB, 包含 10,000 個標籤)

MNIST 數據集來自美國國家標準與技術研究所, National Institute of Standards and Technology (NIST). 訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員. 測試集(test set) 也是同樣比例的手寫數字數據.

不妨新建一個文件夾 – mnist, 將數據集下載到 mnist 以後, 解壓即可:

dataset

圖片是以字節的形式進行存儲, 我們需要把它們讀取到 NumPy array 中, 以便訓練和測試算法.

import os
import struct
import numpy as np

def load_mnist(path, kind='train'):
    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte'
                               % kind)
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II',
                                 lbpath.read(8))
        labels = np.fromfile(lbpath,
                             dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII',
                                               imgpath.read(16))
        images = np.fromfile(imgpath,
                             dtype=np.uint8).reshape(len(labels), 784)

    return images, labels

load_mnist 函數返回兩個數組, 第一個是一個 n x m 維的 NumPy array(images), 這裏的 n 是樣本數(行數), m 是特徵數(列數). 訓練數據集包含 60,000 個樣本, 測試數據集包含 10,000 樣本. 在 MNIST 數據集中的每張圖片由 28 x 28 個像素點構成, 每個像素點用一個灰度值表示. 在這裏, 我們將 28 x 28 的像素展開爲一個一維的行向量, 這些行向量就是圖片數組裏的行(每行 784 個值, 或者說每行就是代表了一張圖片). load_mnist 函數返回的第二個數組(labels) 包含了相應的目標變量, 也就是手寫數字的類標籤(整數 0-9).

第一次見的話, 可能會覺得我們讀取圖片的方式有點奇怪:

magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)

爲了理解這兩行代碼, 我們先來看一下 MNIST 網站上對數據集的介紹:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000801(2049) magic number (MSB first) 
0004     32 bit integer  60000            number of items 
0008     unsigned byte   ??               label 
0009     unsigned byte   ??               label 
........ 
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

通過使用上面兩行代碼, 我們首先讀入 magic number, 它是一個文件協議的描述, 也是在我們調用 fromfile 方法將字節讀入 NumPy array 之前在文件緩衝中的 item 數(n). 作爲參數值傳入 struct.unpack 的 >II 有兩個部分:

  • >: 這是指大端(用來定義字節是如何存儲的); 如果你還不知道什麼是大端和小端, Endianness 是一個非常好的解釋. (關於大小端, 更多內容可見<<深入理解計算機系統 – 2.1 節信息存儲>>)
  • I: 這是指一個無符號整數.

通過執行下面的代碼, 我們將會從剛剛解壓 MNIST 數據集後的 mnist 目錄下加載 60,000 個訓練樣本和 10,000 個測試樣本.

爲了瞭解 MNIST 中的圖片看起來到底是個啥, 讓我們來對它們進行可視化處理. 從 feature matrix 中將 784-像素值 的向量 reshape 爲之前的 28*28 的形狀, 然後通過 matplotlib 的 imshow 函數進行繪製:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(
    nrows=2,
    ncols=5,
    sharex=True,
    sharey=True, )

ax = ax.flatten()
for i in range(10):
    img = X_train[y_train == i][0].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

我們現在應該可以看到一個 2*5 的圖片, 裏面分別是 0-9 單個數字的圖片.

0-9

此外, 我們還可以繪製某一數字的多個樣本圖片, 來看一下這些手寫樣本到底有多不同:

fig, ax = plt.subplots(
    nrows=5,
    ncols=5,
    sharex=True,
    sharey=True, )

ax = ax.flatten()
for i in range(25):
    img = X_train[y_train == 7][i].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

執行上面的代碼後, 我們應該看到數字 7 的 25 個不同形態:

7

另外, 我們也可以選擇將 MNIST 圖片數據和標籤保存爲 CSV 文件, 這樣就可以在不支持特殊的字節格式的程序中打開數據集. 但是, 有一點要說明, CSV 的文件格式將會佔用更多的磁盤空間, 如下所示:

  • train_img.csv: 109.5 MB
  • train_labels.csv: 120 KB
  • test_img.csv: 18.3 MB
  • test_labels: 20 KB

如果我們打算保存這些 CSV 文件, 在將 MNIST 數據集加載入 NumPy array 以後, 我們應該執行下列代碼:

np.savetxt('train_img.csv', X_train,
           fmt='%i', delimiter=',')
np.savetxt('train_labels.csv', y_train,
           fmt='%i', delimiter=',')
np.savetxt('test_img.csv', X_test,
           fmt='%i', delimiter=',')
np.savetxt('test_labels.csv', y_test,
           fmt='%i', delimiter=',')

一旦將數據集保存爲 CSV 文件, 我們也可以用 NumPy 的 genfromtxt 函數重新將它們加載入程序中:

X_train = np.genfromtxt('train_img.csv',
                        dtype=int, delimiter=',')
y_train = np.genfromtxt('train_labels.csv',
                        dtype=int, delimiter=',')
X_test = np.genfromtxt('test_img.csv',
                       dtype=int, delimiter=',')
y_test = np.genfromtxt('test_labels.csv',
                       dtype=int, delimiter=',')

不過, 從 CSV 文件中加載 MNIST 數據將會顯著發給更長的時間, 因此如果可能的話, 還是建議你維持數據集原有的字節格式.

參考: 
- Book , Python Machine Learning.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章