前言
昨天開始接觸卷積神經網絡,copy了一個網絡,其中用的數據集是MNIST數據集,我對於此很陌生,所以先查找了MNIST的相關資料,其中CSDN中有位博主寫的特別詳細,所以這邊也參考他的博客並加入自己的理解,這篇博客更多的是作爲個人的一個學習筆記。
參考博客:https://blog.csdn.net/simple_the_best/article/details/75267863
下載數據集
MNIST 數據集已經是一個被”嚼爛”了的數據集, 很多教程都會對它”下手”, 幾乎成爲一個 “典範”. 不過有些人可能對它還不是很瞭解, 下面來介紹一下.
MNIST 數據集可在 http://yann.lecun.com/exdb/mnist/ 獲取。嗯,把官方說明看了幾遍,大概對MNIST數據集有了個瞭解。
MNIST 數據集來自美國國家標準與技術研究所, National Institute of Standards and Technology (NIST). 訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員. 測試集(test set) 也是同樣比例的手寫數字數據.
可以新建一個文件夾 – mnist, 將數據集下載到 mnist 以後, 解壓即可。
還有另一種下載方法,可以通過pytorch庫安裝相應的數據集MNSIT
# Mnist digits dataset
train_data = torchvision.datasets.MNIST(
root='./mnist/', # 下載到的地址
train=True, # 這裏表明是訓練數據
download=True # 設置download爲true,表示要下載該數據集
讀取到NumPy array 中
大家可以去看一下官方說明,其實和用匯編語言顯示字符類似,字符的前景色、背景色、閃爍等這些都記錄在屬性字節中。而這裏,下載的數據集圖片也是以字節的形式進行存儲, 所以我們需要把它們讀取到 NumPy array 中, 以便訓練和測試算法。
import os
import struct
import numpy as np
def load_mnist(path, kind='train'):
"""從地址 `path`中加載MNIST數據集"""
labels_path = os.path.join(path,
'%s-labels-idx1-ubyte'
% kind)
images_path = os.path.join(path,
'%s-images-idx3-ubyte'
% kind)
with open(labels_path, 'rb') as lbpath:
magic, n = struct.unpack('>II',
lbpath.read(8))
labels = np.fromfile(lbpath,
dtype=np.uint8)
with open(images_path, 'rb') as imgpath:
magic, num, rows, cols = struct.unpack('>IIII',
imgpath.read(16))
images = np.fromfile(imgpath,
dtype=np.uint8).reshape(len(labels), 784)
return images, labels
load_mnist 函數返回兩個數組, 第一個是一個 n x m 維的 NumPy array(images), 這裏的 n 是樣本數(行數), m 是特徵數(列數). 訓練數據集包含 60,000 個樣本, 測試數據集包含 10,000 樣本. 在 MNIST 數據集中的每張圖片由 28 x 28 個像素點構成, 每個像素點用一個灰度值表示. 在這裏, 我們將 28 x 28 的像素展開爲一個一維的行向量, 這些行向量就是圖片數組裏的行(每行 784 個值, 或者說每行就是代表了一張圖片). load_mnist 函數返回的第二個數組(labels) 包含了相應的目標變量, 也就是手寫數字的類標籤(整數 0-9)。
這裏可能會對讀取該數據集的方式有疑惑,
magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)
爲了理解這兩行代碼, 可以看一下 MNIST 網站上對數據集的介紹:
通過使用上面兩行代碼, 首先讀入 magic number, 它是一個文件協議的描述, 也是在調用 fromfile
方法將字節讀入 NumPy array 之前在文件緩衝中的 item 數(n). 作爲參數值傳入 struct.unpack
的 >II
有兩個部分:
>
:這是指大端(用來定義字節是如何存儲的)I
:這是指一個無符號整數.
(這裏不太理解,先mark,參考博主的建議《深入理解計算機系統 – 2.1 節信息存儲》,再學習這塊知識)
可視化處理數據集MNIST
通過執行下面的代碼, 程序將會從剛剛解壓 MNIST 數據集後的 mnist 目錄下加載 60,000 個訓練樣本和 10,000 個測試樣本.
爲了瞭解 MNIST 中的圖片看起來到底是個啥, 對其進行可視化處理。從 feature matrix 中將 784-像素值 的向量 reshape 爲之前的 28*28 的形狀, 然後通過 matplotlib 的 imshow 函數進行繪製:
# refer to https://blog.csdn.net/simple_the_best/article/details/75267863
import matplotlib.pyplot as plt
import os
import struct
import numpy as np
def load_mnist(path, kind='train'):
"""從地址 `path`中加載MNIST數據集"""
labels_path = os.path.join(path,
'%s-labels-idx1-ubyte'
% kind)
images_path = os.path.join(path,
'%s-images-idx3-ubyte'
% kind)
with open(labels_path, 'rb') as lbpath:
magic, n = struct.unpack('>II',
lbpath.read(8))
labels = np.fromfile(lbpath,
dtype=np.uint8)
with open(images_path, 'rb') as imgpath:
magic, num, rows, cols = struct.unpack('>IIII',
imgpath.read(16))
images = np.fromfile(imgpath,
dtype=np.uint8).reshape(len(labels), 784)
return images, labels
X_train,y_train = load_mnist('mnist/MNIST/raw/') # 調用load_mnist函數,這裏用的是相對地址
fig, ax = plt.subplots(
nrows=2,
ncols=5,
sharex=True,
sharey=True, )
ax = ax.flatten()
for i in range(10):
img = X_train[y_train == i][0].reshape(28, 28)
ax[i].imshow(img, cmap='Greys', interpolation='nearest')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
輸出
現在應該可以看到一個 2*5 的圖片, 裏面分別是 0-9 單個數字的圖片.
顯示不同樣本的數字
fig, ax = plt.subplots(
nrows=5,
ncols=5,
sharex=True,
sharey=True, )
ax = ax.flatten()
for i in range(25):
img = X_train[y_train == 5][i].reshape(28, 28) # 可以顯示不同樣本圖片‘5’
ax[i].imshow(img, cmap='Greys', interpolation='neare不st')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
還可以繪製某一數字的多個樣本圖片, 來看一下這些手寫樣本到底有多不同:
另外, 該博主還有下載CSV版本的MNIST數據集,因爲我還沒實踐,所以在這就先不寫了。