PyTorch中使用遷移訓練(Transfer Learning)進行圖像分類

PyTorch使用方便,易於學習,開發效率很高。在這篇文章中,我們描述瞭如何在PyTorch中進行圖像分類。我們將使用CalTech256數據集的子集對10種不同種類動物的圖像進行分類。我們將介紹數據集準備,數據增強,然後逐步來構建分類器。預訓練模型ResNet50已經學習了低層次的圖像特徵,如邊緣、紋理等,我們使用遷移學習來複用這些低層次特徵,然後訓練我們的分類器來學習我們的數據集中圖像中的更高層次的細節,如眼睛、腿等。ResNet50已經使用ImageNet數據集的數百萬張圖像進行了訓練。 文章末尾有該實驗的完整的代碼,並且在文章中分析重要的片段,這樣讀者就可以理解它是如何工作的。

1、數據集準備

Caltech-256數據集有30607幅圖像,分爲256個不同的標記類和另一個“雜亂”類。 訓練整個數據集需要幾個小時,所以我們將研究包含10種動物的數據集的子集-熊、黑猩猩、長頸鹿、大猩猩、駱駝、鴕鳥、豪豬、臭鼬、三角龍和斑馬。這樣我們就能更快地進行實驗。當然,代碼也可以用來訓練整個數據集。 這些文件夾中的圖像數量從81(對於臭鼬)到212(對於大猩猩)不等。我們在每個類別中使用前60幅圖像進行訓練,使用後面的10幅圖像進行驗證,其餘的用於測試。 因此,最後我們有600個訓練圖像,100個驗證圖像,409個測試圖像和10類動物。

如果你想自己做這些實驗,請按照下面的步驟進行

1、下載CalTech256數據集,下載種子地址: http://shujujishi.com/dataset/4f9027a7-79c9-4612-9e02-044b679106fb 或者 https://hyper.ai/datasets/5261

2、創建三個目錄,名稱分別是train,valid和test。

3、在train,valid和test目錄中各創建10個子目錄。這些子目錄應該被命名爲bear, chimp, giraffe, gorilla, llama, ostrich, porcupine, skunk, triceratops and zebra

4、將Caltech256數據集中的bear的前60幅圖像移動到目錄train/bear,並對每隻動物重複這一點。

5、將Caltech256數據集中的bear的接下來10張圖像移動到目錄valid/bear,並對每隻動物重複這一點。

6、將Caltech256數據集中的bear剩下的圖片複製給目錄test/bear。每隻動物都重複這個。

最終結果如下:

caltech_10
├── test
│   ├── bear
│   │   ├── 009_0071.jpg
│   │   ├── 009_0072.jpg
│   │   ├── 009_0073.jpg
│   │   ├── 009_0074.jpg
│   │   ├── 009_0075.jpg
│   │   ├── 009_0076.jpg
│   │   ├── 009_0077.jpg
│   │   ├── 009_0078.jpg
│   │   ├── 009_0079.jpg
│   │   ├── 009_0080.jpg
│   │   ├── 009_0081.jpg
│   │   ├── 009_0082.jpg
│   │   ├── 009_0083.jpg
│   │   ├── 009_0084.jpg
│   │   ├── 009_0085.jpg
│   │   ├── 009_0086.jpg
│   │   ├── 009_0087.jpg
│   │   ├── 009_0088.jpg
│   │   ├── 009_0089.jpg
│   │   ├── 009_0090.jpg
│   │   ├── 009_0091.jpg
│   │   ├── 009_0092.jpg
│   │   ├── 009_0093.jpg
│   │   ├── 009_0094.jpg
│   │   ├── 009_0095.jpg
│   │   ├── 009_0096.jpg
│   │   ├── 009_0097.jpg
│   │   ├── 009_0098.jpg
│   │   ├── 009_0099.jpg
│   │   ├── 009_0100.jpg
│   │   ├── 009_0101.jpg
│   │   └── 009_0102.jpg
│   ├── chimp
│   │   ├── 038_0071.jpg
│   │   ├── 038_0072.jpg
│   │   ├── 038_0073.jpg
│   │   ├── 038_0074.jpg
│   │   ├── 038_0075.jpg
│   │   ├── 038_0076.jpg
│   │   ├── 038_0077.jpg
│   │   ├── 038_0078.jpg
│   │   ├── 038_0079.jpg
│   │   ├── 038_0080.jpg
│   │   ├── 038_0081.jpg
│   │   ├── 038_0082.jpg
│   │   ├── 038_0083.jpg
│   │   ├── 038_0084.jpg
│   │   ├── 038_0085.jpg
│   │   ├── 038_0086.jpg
│   │   ├── 038_0087.jpg
│   │   ├── 038_0088.jpg
│   │   ├── 038_0089.jpg
│   │   ├── 038_0090.jpg
│   │   ├── 038_0091.jpg
│   │   ├── 038_0092.jpg
│   │   ├── 038_0093.jpg
│   │   ├── 038_0094.jpg
│   │   ├── 038_0095.jpg
│   │   ├── 038_0096.jpg
│   │   ├── 038_0097.jpg
│   │   ├── 038_0098.jpg
│   │   ├── 038_0099.jpg
│   │   ├── 038_0100.jpg
│   │   ├── 038_0101.jpg
│   │   ├── 038_0102.jpg
│   │   ├── 038_0103.jpg
│   │   ├── 038_0104.jpg
│   │   ├── 038_0105.jpg
│   │   ├── 038_0106.jpg
│   │   ├── 038_0107.jpg
│   │   ├── 038_0108.jpg
│   │   ├── 038_0109.jpg
│   │   └── 038_0110.jpg
│   ├── giraffe
│   │   ├── 084_0071.jpg
│   │   ├── 084_0072.jpg
│   │   ├── 084_0073.jpg
│   │   ├── 084_0074.jpg
│   │   ├── 084_0075.jpg
│   │   ├── 084_0076.jpg
│   │   ├── 084_0077.jpg
│   │   ├── 084_0078.jpg
│   │   ├── 084_0079.jpg
│   │   ├── 084_0080.jpg
│   │   ├── 084_0081.jpg
│   │   ├── 084_0082.jpg
│   │   ├── 084_0083.jpg
│   │   └── 084_0084.jpg
│   ├── gorilla
│   │   ├── 090_0071.jpg
│   │   ├── 090_0072.jpg
│   │   ├── 090_0073.jpg
│   │   ├── 090_0074.jpg
│   │   ├── 090_0075.jpg
│   │   ├── 090_0076.jpg
│   │   ├── 090_0077.jpg
│   │   ├── 090_0078.jpg
│   │   ├── 090_0079.jpg
│   │   ├── 090_0080.jpg
│   │   ├── 090_0081.jpg
│   │   ├── 090_0082.jpg
│   │   ├── 090_0083.jpg
│   │   ├── 090_0084.jpg
│   │   ├── 090_0085.jpg
│   │   ├── 090_0086.jpg
│   │   ├── 090_0087.jpg
│   │   ├── 090_0088.jpg
│   │   ├── 090_0089.jpg
│   │   ├── 090_0090.jpg
│   │   ├── 090_0091.jpg
│   │   ├── 090_0092.jpg
│   │   ├── 090_0093.jpg
│   │   ├── 090_0094.jpg
│   │   ├── 090_0095.jpg
│   │   ├── 090_0096.jpg
│   │   ├── 090_0097.jpg
│   │   ├── 090_0098.jpg
│   │   ├── 090_0099.jpg
│   │   ├── 090_0100.jpg
│   │   ├── 090_0101.jpg
│   │   ├── 090_0102.jpg
│   │   ├── 090_0103.jpg
│   │   ├── 090_0104.jpg
│   │   ├── 090_0105.jpg
│   │   ├── 090_0106.jpg
│   │   ├── 090_0107.jpg
│   │   ├── 090_0108.jpg
│   │   ├── 090_0109.jpg
│   │   ├── 090_0110.jpg
│   │   ├── 090_0111.jpg
│   │   ├── 090_0112.jpg
│   │   ├── 090_0113.jpg
│   │   ├── 090_0114.jpg
│   │   ├── 090_0115.jpg
│   │   ├── 090_0116.jpg
│   │   ├── 090_0117.jpg
│   │   ├── 090_0118.jpg
│   │   ├── 090_0119.jpg
│   │   ├── 090_0120.jpg
│   │   ├── 090_0121.jpg
│   │   ├── 090_0122.jpg
│   │   ├── 090_0123.jpg
│   │   ├── 090_0124.jpg
│   │   ├── 090_0125.jpg
│   │   ├── 090_0126.jpg
│   │   ├── 090_0127.jpg
│   │   ├── 090_0128.jpg
│   │   ├── 090_0129.jpg
│   │   ├── 090_0130.jpg
│   │   ├── 090_0131.jpg
│   │   ├── 090_0132.jpg
│   │   ├── 090_0133.jpg
│   │   ├── 090_0134.jpg
│   │   ├── 090_0135.jpg
│   │   ├── 090_0136.jpg
│   │   ├── 090_0137.jpg
│   │   ├── 090_0138.jpg
│   │   ├── 090_0139.jpg
│   │   ├── 090_0140.jpg
│   │   ├── 090_0141.jpg
│   │   ├── 090_0142.jpg
│   │   ├── 090_0143.jpg
│   │   ├── 090_0144.jpg
│   │   ├── 090_0145.jpg
│   │   ├── 090_0146.jpg
│   │   ├── 090_0147.jpg
│   │   ├── 090_0148.jpg
│   │   ├── 090_0149.jpg
│   │   ├── 090_0150.jpg
│   │   ├── 090_0151.jpg
│   │   ├── 090_0152.jpg
│   │   ├── 090_0153.jpg
│   │   ├── 090_0154.jpg
│   │   ├── 090_0155.jpg
│   │   ├── 090_0156.jpg
│   │   ├── 090_0157.jpg
│   │   ├── 090_0158.jpg
│   │   ├── 090_0159.jpg
│   │   ├── 090_0160.jpg
│   │   ├── 090_0161.jpg
│   │   ├── 090_0162.jpg
│   │   ├── 090_0163.jpg
│   │   ├── 090_0164.jpg
│   │   ├── 090_0165.jpg
│   │   ├── 090_0166.jpg
│   │   ├── 090_0167.jpg
│   │   ├── 090_0168.jpg
│   │   ├── 090_0169.jpg
│   │   ├── 090_0170.jpg
│   │   ├── 090_0171.jpg
│   │   ├── 090_0172.jpg
│   │   ├── 090_0173.jpg
│   │   ├── 090_0174.jpg
│   │   ├── 090_0175.jpg
│   │   ├── 090_0176.jpg
│   │   ├── 090_0177.jpg
│   │   ├── 090_0178.jpg
│   │   ├── 090_0179.jpg
│   │   ├── 090_0180.jpg
│   │   ├── 090_0181.jpg
│   │   ├── 090_0182.jpg
│   │   ├── 090_0183.jpg
│   │   ├── 090_0184.jpg
│   │   ├── 090_0185.jpg
│   │   ├── 090_0186.jpg
│   │   ├── 090_0187.jpg
│   │   ├── 090_0188.jpg
│   │   ├── 090_0189.jpg
│   │   ├── 090_0190.jpg
│   │   ├── 090_0191.jpg
│   │   ├── 090_0192.jpg
│   │   ├── 090_0193.jpg
│   │   ├── 090_0194.jpg
│   │   ├── 090_0195.jpg
│   │   ├── 090_0196.jpg
│   │   ├── 090_0197.jpg
│   │   ├── 090_0198.jpg
│   │   ├── 090_0199.jpg
│   │   ├── 090_0200.jpg
│   │   ├── 090_0201.jpg
│   │   ├── 090_0202.jpg
│   │   ├── 090_0203.jpg
│   │   ├── 090_0204.jpg
│   │   ├── 090_0205.jpg
│   │   ├── 090_0206.jpg
│   │   ├── 090_0207.jpg
│   │   ├── 090_0208.jpg
│   │   ├── 090_0209.jpg
│   │   ├── 090_0210.jpg
│   │   ├── 090_0211.jpg
│   │   └── 090_0212.jpg
│   ├── llama
│   │   ├── 134_0071.jpg
│   │   ├── 134_0072.jpg
│   │   ├── 134_0073.jpg
│   │   ├── 134_0074.jpg
│   │   ├── 134_0075.jpg
│   │   ├── 134_0076.jpg
│   │   ├── 134_0077.jpg
│   │   ├── 134_0078.jpg
│   │   ├── 134_0079.jpg
│   │   ├── 134_0080.jpg
│   │   ├── 134_0081.jpg
│   │   ├── 134_0082.jpg
│   │   ├── 134_0083.jpg
│   │   ├── 134_0084.jpg
│   │   ├── 134_0085.jpg
│   │   ├── 134_0086.jpg
│   │   ├── 134_0087.jpg
│   │   ├── 134_0088.jpg
│   │   ├── 134_0089.jpg
│   │   ├── 134_0090.jpg
│   │   ├── 134_0091.jpg
│   │   ├── 134_0092.jpg
│   │   ├── 134_0093.jpg
│   │   ├── 134_0094.jpg
│   │   ├── 134_0095.jpg
│   │   ├── 134_0096.jpg
│   │   ├── 134_0097.jpg
│   │   ├── 134_0098.jpg
│   │   ├── 134_0099.jpg
│   │   ├── 134_0100.jpg
│   │   ├── 134_0101.jpg
│   │   ├── 134_0102.jpg
│   │   ├── 134_0103.jpg
│   │   ├── 134_0104.jpg
│   │   ├── 134_0105.jpg
│   │   ├── 134_0106.jpg
│   │   ├── 134_0107.jpg
│   │   ├── 134_0108.jpg
│   │   ├── 134_0109.jpg
│   │   ├── 134_0110.jpg
│   │   ├── 134_0111.jpg
│   │   ├── 134_0112.jpg
│   │   ├── 134_0113.jpg
│   │   ├── 134_0114.jpg
│   │   ├── 134_0115.jpg
│   │   ├── 134_0116.jpg
│   │   ├── 134_0117.jpg
│   │   ├── 134_0118.jpg
│   │   └── 134_0119.jpg
│   ├── ostrich
│   │   ├── 151_0071.jpg
│   │   ├── 151_0072.jpg
│   │   ├── 151_0073.jpg
│   │   ├── 151_0074.jpg
│   │   ├── 151_0075.jpg
│   │   ├── 151_0076.jpg
│   │   ├── 151_0077.jpg
│   │   ├── 151_0078.jpg
│   │   ├── 151_0079.jpg
│   │   ├── 151_0080.jpg
│   │   ├── 151_0081.jpg
│   │   ├── 151_0082.jpg
│   │   ├── 151_0083.jpg
│   │   ├── 151_0084.jpg
│   │   ├── 151_0085.jpg
│   │   ├── 151_0086.jpg
│   │   ├── 151_0087.jpg
│   │   ├── 151_0088.jpg
│   │   ├── 151_0089.jpg
│   │   ├── 151_0090.jpg
│   │   ├── 151_0091.jpg
│   │   ├── 151_0092.jpg
│   │   ├── 151_0093.jpg
│   │   ├── 151_0094.jpg
│   │   ├── 151_0095.jpg
│   │   ├── 151_0096.jpg
│   │   ├── 151_0097.jpg
│   │   ├── 151_0098.jpg
│   │   ├── 151_0099.jpg
│   │   ├── 151_0100.jpg
│   │   ├── 151_0101.jpg
│   │   ├── 151_0102.jpg
│   │   ├── 151_0103.jpg
│   │   ├── 151_0104.jpg
│   │   ├── 151_0105.jpg
│   │   ├── 151_0106.jpg
│   │   ├── 151_0107.jpg
│   │   ├── 151_0108.jpg
│   │   └── 151_0109.jpg
│   ├── porcupine
│   │   ├── 164_0071.jpg
│   │   ├── 164_0072.jpg
│   │   ├── 164_0073.jpg
│   │   ├── 164_0074.jpg
│   │   ├── 164_0075.jpg
│   │   ├── 164_0076.jpg
│   │   ├── 164_0077.jpg
│   │   ├── 164_0078.jpg
│   │   ├── 164_0079.jpg
│   │   ├── 164_0080.jpg
│   │   ├── 164_0081.jpg
│   │   ├── 164_0082.jpg
│   │   ├── 164_0083.jpg
│   │   ├── 164_0084.jpg
│   │   ├── 164_0085.jpg
│   │   ├── 164_0086.jpg
│   │   ├── 164_0087.jpg
│   │   ├── 164_0088.jpg
│   │   ├── 164_0089.jpg
│   │   ├── 164_0090.jpg
│   │   ├── 164_0091.jpg
│   │   ├── 164_0092.jpg
│   │   ├── 164_0093.jpg
│   │   ├── 164_0094.jpg
│   │   ├── 164_0095.jpg
│   │   ├── 164_0096.jpg
│   │   ├── 164_0097.jpg
│   │   ├── 164_0098.jpg
│   │   ├── 164_0099.jpg
│   │   ├── 164_0100.jpg
│   │   └── 164_0101.jpg
│   ├── skunk
│   │   ├── 186_0071.jpg
│   │   ├── 186_0072.jpg
│   │   ├── 186_0073.jpg
│   │   ├── 186_0074.jpg
│   │   ├── 186_0075.jpg
│   │   ├── 186_0076.jpg
│   │   ├── 186_0077.jpg
│   │   ├── 186_0078.jpg
│   │   ├── 186_0079.jpg
│   │   ├── 186_0080.jpg
│   │   └── 186_0081.jpg
│   ├── triceratops
│   │   ├── 228_0071.jpg
│   │   ├── 228_0072.jpg
│   │   ├── 228_0073.jpg
│   │   ├── 228_0074.jpg
│   │   ├── 228_0075.jpg
│   │   ├── 228_0076.jpg
│   │   ├── 228_0077.jpg
│   │   ├── 228_0078.jpg
│   │   ├── 228_0079.jpg
│   │   ├── 228_0080.jpg
│   │   ├── 228_0081.jpg
│   │   ├── 228_0082.jpg
│   │   ├── 228_0083.jpg
│   │   ├── 228_0084.jpg
│   │   ├── 228_0085.jpg
│   │   ├── 228_0086.jpg
│   │   ├── 228_0087.jpg
│   │   ├── 228_0088.jpg
│   │   ├── 228_0089.jpg
│   │   ├── 228_0090.jpg
│   │   ├── 228_0091.jpg
│   │   ├── 228_0092.jpg
│   │   ├── 228_0093.jpg
│   │   ├── 228_0094.jpg
│   │   └── 228_0095.jpg
│   └── zebra
│       ├── 250_0071.jpg
│       ├── 250_0072.jpg
│       ├── 250_0073.jpg
│       ├── 250_0074.jpg
│       ├── 250_0075.jpg
│       ├── 250_0076.jpg
│       ├── 250_0077.jpg
│       ├── 250_0078.jpg
│       ├── 250_0079.jpg
│       ├── 250_0080.jpg
│       ├── 250_0081.jpg
│       ├── 250_0082.jpg
│       ├── 250_0083.jpg
│       ├── 250_0084.jpg
│       ├── 250_0085.jpg
│       ├── 250_0086.jpg
│       ├── 250_0087.jpg
│       ├── 250_0088.jpg
│       ├── 250_0089.jpg
│       ├── 250_0090.jpg
│       ├── 250_0091.jpg
│       ├── 250_0092.jpg
│       ├── 250_0093.jpg
│       ├── 250_0094.jpg
│       ├── 250_0095.jpg
│       └── 250_0096.jpg
├── train
│   ├── bear
│   │   ├── 009_0001.jpg
│   │   ├── 009_0002.jpg
│   │   ├── 009_0003.jpg
│   │   ├── 009_0004.jpg
│   │   ├── 009_0005.jpg
│   │   ├── 009_0006.jpg
│   │   ├── 009_0007.jpg
│   │   ├── 009_0008.jpg
│   │   ├── 009_0009.jpg
│   │   ├── 009_0010.jpg
│   │   ├── 009_0011.jpg
│   │   ├── 009_0012.jpg
│   │   ├── 009_0013.jpg
│   │   ├── 009_0014.jpg
│   │   ├── 009_0015.jpg
│   │   ├── 009_0016.jpg
│   │   ├── 009_0017.jpg
│   │   ├── 009_0018.jpg
│   │   ├── 009_0019.jpg
│   │   ├── 009_0020.jpg
│   │   ├── 009_0021.jpg
│   │   ├── 009_0022.jpg
│   │   ├── 009_0023.jpg
│   │   ├── 009_0024.jpg
│   │   ├── 009_0025.jpg
│   │   ├── 009_0026.jpg
│   │   ├── 009_0027.jpg
│   │   ├── 009_0028.jpg
│   │   ├── 009_0029.jpg
│   │   ├── 009_0030.jpg
│   │   ├── 009_0031.jpg
│   │   ├── 009_0032.jpg
│   │   ├── 009_0033.jpg
│   │   ├── 009_0034.jpg
│   │   ├── 009_0035.jpg
│   │   ├── 009_0036.jpg
│   │   ├── 009_0037.jpg
│   │   ├── 009_0038.jpg
│   │   ├── 009_0039.jpg
│   │   ├── 009_0040.jpg
│   │   ├── 009_0041.jpg
│   │   ├── 009_0042.jpg
│   │   ├── 009_0043.jpg
│   │   ├── 009_0044.jpg
│   │   ├── 009_0045.jpg
│   │   ├── 009_0046.jpg
│   │   ├── 009_0047.jpg
│   │   ├── 009_0048.jpg
│   │   ├── 009_0049.jpg
│   │   ├── 009_0050.jpg
│   │   ├── 009_0051.jpg
│   │   ├── 009_0052.jpg
│   │   ├── 009_0053.jpg
│   │   ├── 009_0054.jpg
│   │   ├── 009_0055.jpg
│   │   ├── 009_0056.jpg
│   │   ├── 009_0057.jpg
│   │   ├── 009_0058.jpg
│   │   ├── 009_0059.jpg
│   │   └── 009_0060.jpg
│   ├── chimp
│   │   ├── 038_0001.jpg
│   │   ├── 038_0002.jpg
│   │   ├── 038_0003.jpg
│   │   ├── 038_0004.jpg
│   │   ├── 038_0005.jpg
│   │   ├── 038_0006.jpg
│   │   ├── 038_0007.jpg
│   │   ├── 038_0008.jpg
│   │   ├── 038_0009.jpg
│   │   ├── 038_0010.jpg
│   │   ├── 038_0011.jpg
│   │   ├── 038_0012.jpg
│   │   ├── 038_0013.jpg
│   │   ├── 038_0014.jpg
│   │   ├── 038_0015.jpg
│   │   ├── 038_0016.jpg
│   │   ├── 038_0017.jpg
│   │   ├── 038_0018.jpg
│   │   ├── 038_0019.jpg
│   │   ├── 038_0020.jpg
│   │   ├── 038_0021.jpg
│   │   ├── 038_0022.jpg
│   │   ├── 038_0023.jpg
│   │   ├── 038_0024.jpg
│   │   ├── 038_0025.jpg
│   │   ├── 038_0026.jpg
│   │   ├── 038_0027.jpg
│   │   ├── 038_0028.jpg
│   │   ├── 038_0029.jpg
│   │   ├── 038_0030.jpg
│   │   ├── 038_0031.jpg
│   │   ├── 038_0032.jpg
│   │   ├── 038_0033.jpg
│   │   ├── 038_0034.jpg
│   │   ├── 038_0035.jpg
│   │   ├── 038_0036.jpg
│   │   ├── 038_0037.jpg
│   │   ├── 038_0038.jpg
│   │   ├── 038_0039.jpg
│   │   ├── 038_0040.jpg
│   │   ├── 038_0041.jpg
│   │   ├── 038_0042.jpg
│   │   ├── 038_0043.jpg
│   │   ├── 038_0044.jpg
│   │   ├── 038_0045.jpg
│   │   ├── 038_0046.jpg
│   │   ├── 038_0047.jpg
│   │   ├── 038_0048.jpg
│   │   ├── 038_0049.jpg
│   │   ├── 038_0050.jpg
│   │   ├── 038_0051.jpg
│   │   ├── 038_0052.jpg
│   │   ├── 038_0053.jpg
│   │   ├── 038_0054.jpg
│   │   ├── 038_0055.jpg
│   │   ├── 038_0056.jpg
│   │   ├── 038_0057.jpg
│   │   ├── 038_0058.jpg
│   │   ├── 038_0059.jpg
│   │   └── 038_0060.jpg
│   ├── giraffe
│   │   ├── 084_0001.jpg
│   │   ├── 084_0002.jpg
│   │   ├── 084_0003.jpg
│   │   ├── 084_0004.jpg
│   │   ├── 084_0005.jpg
│   │   ├── 084_0006.jpg
│   │   ├── 084_0007.jpg
│   │   ├── 084_0008.jpg
│   │   ├── 084_0009.jpg
│   │   ├── 084_0010.jpg
│   │   ├── 084_0011.jpg
│   │   ├── 084_0012.jpg
│   │   ├── 084_0013.jpg
│   │   ├── 084_0014.jpg
│   │   ├── 084_0015.jpg
│   │   ├── 084_0016.jpg
│   │   ├── 084_0017.jpg
│   │   ├── 084_0018.jpg
│   │   ├── 084_0019.jpg
│   │   ├── 084_0020.jpg
│   │   ├── 084_0021.jpg
│   │   ├── 084_0022.jpg
│   │   ├── 084_0023.jpg
│   │   ├── 084_0024.jpg
│   │   ├── 084_0025.jpg
│   │   ├── 084_0026.jpg
│   │   ├── 084_0027.jpg
│   │   ├── 084_0028.jpg
│   │   ├── 084_0029.jpg
│   │   ├── 084_0030.jpg
│   │   ├── 084_0031.jpg
│   │   ├── 084_0032.jpg
│   │   ├── 084_0033.jpg
│   │   ├── 084_0034.jpg
│   │   ├── 084_0035.jpg
│   │   ├── 084_0036.jpg
│   │   ├── 084_0037.jpg
│   │   ├── 084_0038.jpg
│   │   ├── 084_0039.jpg
│   │   ├── 084_0040.jpg
│   │   ├── 084_0041.jpg
│   │   ├── 084_0042.jpg
│   │   ├── 084_0043.jpg
│   │   ├── 084_0044.jpg
│   │   ├── 084_0045.jpg
│   │   ├── 084_0046.jpg
│   │   ├── 084_0047.jpg
│   │   ├── 084_0048.jpg
│   │   ├── 084_0049.jpg
│   │   ├── 084_0050.jpg
│   │   ├── 084_0051.jpg
│   │   ├── 084_0052.jpg
│   │   ├── 084_0053.jpg
│   │   ├── 084_0054.jpg
│   │   ├── 084_0055.jpg
│   │   ├── 084_0056.jpg
│   │   ├── 084_0057.jpg
│   │   ├── 084_0058.jpg
│   │   ├── 084_0059.jpg
│   │   └── 084_0060.jpg
│   ├── gorilla
│   │   ├── 090_0001.jpg
│   │   ├── 090_0002.jpg
│   │   ├── 090_0003.jpg
│   │   ├── 090_0004.jpg
│   │   ├── 090_0005.jpg
│   │   ├── 090_0006.jpg
│   │   ├── 090_0007.jpg
│   │   ├── 090_0008.jpg
│   │   ├── 090_0009.jpg
│   │   ├── 090_0010.jpg
│   │   ├── 090_0011.jpg
│   │   ├── 090_0012.jpg
│   │   ├── 090_0013.jpg
│   │   ├── 090_0014.jpg
│   │   ├── 090_0015.jpg
│   │   ├── 090_0016.jpg
│   │   ├── 090_0017.jpg
│   │   ├── 090_0018.jpg
│   │   ├── 090_0019.jpg
│   │   ├── 090_0020.jpg
│   │   ├── 090_0021.jpg
│   │   ├── 090_0022.jpg
│   │   ├── 090_0023.jpg
│   │   ├── 090_0024.jpg
│   │   ├── 090_0025.jpg
│   │   ├── 090_0026.jpg
│   │   ├── 090_0027.jpg
│   │   ├── 090_0028.jpg
│   │   ├── 090_0029.jpg
│   │   ├── 090_0030.jpg
│   │   ├── 090_0031.jpg
│   │   ├── 090_0032.jpg
│   │   ├── 090_0033.jpg
│   │   ├── 090_0034.jpg
│   │   ├── 090_0035.jpg
│   │   ├── 090_0036.jpg
│   │   ├── 090_0037.jpg
│   │   ├── 090_0038.jpg
│   │   ├── 090_0039.jpg
│   │   ├── 090_0040.jpg
│   │   ├── 090_0041.jpg
│   │   ├── 090_0042.jpg
│   │   ├── 090_0043.jpg
│   │   ├── 090_0044.jpg
│   │   ├── 090_0045.jpg
│   │   ├── 090_0046.jpg
│   │   ├── 090_0047.jpg
│   │   ├── 090_0048.jpg
│   │   ├── 090_0049.jpg
│   │   ├── 090_0050.jpg
│   │   ├── 090_0051.jpg
│   │   ├── 090_0052.jpg
│   │   ├── 090_0053.jpg
│   │   ├── 090_0054.jpg
│   │   ├── 090_0055.jpg
│   │   ├── 090_0056.jpg
│   │   ├── 090_0057.jpg
│   │   ├── 090_0058.jpg
│   │   ├── 090_0059.jpg
│   │   └── 090_0060.jpg
│   ├── llama
│   │   ├── 134_0001.jpg
│   │   ├── 134_0002.jpg
│   │   ├── 134_0003.jpg
│   │   ├── 134_0004.jpg
│   │   ├── 134_0005.jpg
│   │   ├── 134_0006.jpg
│   │   ├── 134_0007.jpg
│   │   ├── 134_0008.jpg
│   │   ├── 134_0009.jpg
│   │   ├── 134_0010.jpg
│   │   ├── 134_0011.jpg
│   │   ├── 134_0012.jpg
│   │   ├── 134_0013.jpg
│   │   ├── 134_0014.jpg
│   │   ├── 134_0015.jpg
│   │   ├── 134_0016.jpg
│   │   ├── 134_0017.jpg
│   │   ├── 134_0018.jpg
│   │   ├── 134_0019.jpg
│   │   ├── 134_0020.jpg
│   │   ├── 134_0021.jpg
│   │   ├── 134_0022.jpg
│   │   ├── 134_0023.jpg
│   │   ├── 134_0024.jpg
│   │   ├── 134_0025.jpg
│   │   ├── 134_0026.jpg
│   │   ├── 134_0027.jpg
│   │   ├── 134_0028.jpg
│   │   ├── 134_0029.jpg
│   │   ├── 134_0030.jpg
│   │   ├── 134_0031.jpg
│   │   ├── 134_0032.jpg
│   │   ├── 134_0033.jpg
│   │   ├── 134_0034.jpg
│   │   ├── 134_0035.jpg
│   │   ├── 134_0036.jpg
│   │   ├── 134_0037.jpg
│   │   ├── 134_0038.jpg
│   │   ├── 134_0039.jpg
│   │   ├── 134_0040.jpg
│   │   ├── 134_0041.jpg
│   │   ├── 134_0042.jpg
│   │   ├── 134_0043.jpg
│   │   ├── 134_0044.jpg
│   │   ├── 134_0045.jpg
│   │   ├── 134_0046.jpg
│   │   ├── 134_0047.jpg
│   │   ├── 134_0048.jpg
│   │   ├── 134_0049.jpg
│   │   ├── 134_0050.jpg
│   │   ├── 134_0051.jpg
│   │   ├── 134_0052.jpg
│   │   ├── 134_0053.jpg
│   │   ├── 134_0054.jpg
│   │   ├── 134_0055.jpg
│   │   ├── 134_0056.jpg
│   │   ├── 134_0057.jpg
│   │   ├── 134_0058.jpg
│   │   ├── 134_0059.jpg
│   │   └── 134_0060.jpg
│   ├── ostrich
│   │   ├── 151_0001.jpg
│   │   ├── 151_0002.jpg
│   │   ├── 151_0003.jpg
│   │   ├── 151_0004.jpg
│   │   ├── 151_0005.jpg
│   │   ├── 151_0006.jpg
│   │   ├── 151_0007.jpg
│   │   ├── 151_0008.jpg
│   │   ├── 151_0009.jpg
│   │   ├── 151_0010.jpg
│   │   ├── 151_0011.jpg
│   │   ├── 151_0012.jpg
│   │   ├── 151_0013.jpg
│   │   ├── 151_0014.jpg
│   │   ├── 151_0015.jpg
│   │   ├── 151_0016.jpg
│   │   ├── 151_0017.jpg
│   │   ├── 151_0018.jpg
│   │   ├── 151_0019.jpg
│   │   ├── 151_0020.jpg
│   │   ├── 151_0021.jpg
│   │   ├── 151_0022.jpg
│   │   ├── 151_0023.jpg
│   │   ├── 151_0024.jpg
│   │   ├── 151_0025.jpg
│   │   ├── 151_0026.jpg
│   │   ├── 151_0027.jpg
│   │   ├── 151_0028.jpg
│   │   ├── 151_0029.jpg
│   │   ├── 151_0030.jpg
│   │   ├── 151_0031.jpg
│   │   ├── 151_0032.jpg
│   │   ├── 151_0033.jpg
│   │   ├── 151_0034.jpg
│   │   ├── 151_0035.jpg
│   │   ├── 151_0036.jpg
│   │   ├── 151_0037.jpg
│   │   ├── 151_0038.jpg
│   │   ├── 151_0039.jpg
│   │   ├── 151_0040.jpg
│   │   ├── 151_0041.jpg
│   │   ├── 151_0042.jpg
│   │   ├── 151_0043.jpg
│   │   ├── 151_0044.jpg
│   │   ├── 151_0045.jpg
│   │   ├── 151_0046.jpg
│   │   ├── 151_0047.jpg
│   │   ├── 151_0048.jpg
│   │   ├── 151_0049.jpg
│   │   ├── 151_0050.jpg
│   │   ├── 151_0051.jpg
│   │   ├── 151_0052.jpg
│   │   ├── 151_0053.jpg
│   │   ├── 151_0054.jpg
│   │   ├── 151_0055.jpg
│   │   ├── 151_0056.jpg
│   │   ├── 151_0057.jpg
│   │   ├── 151_0058.jpg
│   │   ├── 151_0059.jpg
│   │   └── 151_0060.jpg
│   ├── porcupine
│   │   ├── 164_0001.jpg
│   │   ├── 164_0002.jpg
│   │   ├── 164_0003.jpg
│   │   ├── 164_0004.jpg
│   │   ├── 164_0005.jpg
│   │   ├── 164_0006.jpg
│   │   ├── 164_0007.jpg
│   │   ├── 164_0008.jpg
│   │   ├── 164_0009.jpg
│   │   ├── 164_0010.jpg
│   │   ├── 164_0011.jpg
│   │   ├── 164_0012.jpg
│   │   ├── 164_0013.jpg
│   │   ├── 164_0014.jpg
│   │   ├── 164_0015.jpg
│   │   ├── 164_0016.jpg
│   │   ├── 164_0017.jpg
│   │   ├── 164_0018.jpg
│   │   ├── 164_0019.jpg
│   │   ├── 164_0020.jpg
│   │   ├── 164_0021.jpg
│   │   ├── 164_0022.jpg
│   │   ├── 164_0023.jpg
│   │   ├── 164_0024.jpg
│   │   ├── 164_0025.jpg
│   │   ├── 164_0026.jpg
│   │   ├── 164_0027.jpg
│   │   ├── 164_0028.jpg
│   │   ├── 164_0029.jpg
│   │   ├── 164_0030.jpg
│   │   ├── 164_0031.jpg
│   │   ├── 164_0032.jpg
│   │   ├── 164_0033.jpg
│   │   ├── 164_0034.jpg
│   │   ├── 164_0035.jpg
│   │   ├── 164_0036.jpg
│   │   ├── 164_0037.jpg
│   │   ├── 164_0038.jpg
│   │   ├── 164_0039.jpg
│   │   ├── 164_0040.jpg
│   │   ├── 164_0041.jpg
│   │   ├── 164_0042.jpg
│   │   ├── 164_0043.jpg
│   │   ├── 164_0044.jpg
│   │   ├── 164_0045.jpg
│   │   ├── 164_0046.jpg
│   │   ├── 164_0047.jpg
│   │   ├── 164_0048.jpg
│   │   ├── 164_0049.jpg
│   │   ├── 164_0050.jpg
│   │   ├── 164_0051.jpg
│   │   ├── 164_0052.jpg
│   │   ├── 164_0053.jpg
│   │   ├── 164_0054.jpg
│   │   ├── 164_0055.jpg
│   │   ├── 164_0056.jpg
│   │   ├── 164_0057.jpg
│   │   ├── 164_0058.jpg
│   │   ├── 164_0059.jpg
│   │   └── 164_0060.jpg
│   ├── skunk
│   │   ├── 186_0001.jpg
│   │   ├── 186_0002.jpg
│   │   ├── 186_0003.jpg
│   │   ├── 186_0004.jpg
│   │   ├── 186_0005.jpg
│   │   ├── 186_0006.jpg
│   │   ├── 186_0007.jpg
│   │   ├── 186_0008.jpg
│   │   ├── 186_0009.jpg
│   │   ├── 186_0010.jpg
│   │   ├── 186_0011.jpg
│   │   ├── 186_0012.jpg
│   │   ├── 186_0013.jpg
│   │   ├── 186_0014.jpg
│   │   ├── 186_0015.jpg
│   │   ├── 186_0016.jpg
│   │   ├── 186_0017.jpg
│   │   ├── 186_0018.jpg
│   │   ├── 186_0019.jpg
│   │   ├── 186_0020.jpg
│   │   ├── 186_0021.jpg
│   │   ├── 186_0022.jpg
│   │   ├── 186_0023.jpg
│   │   ├── 186_0024.jpg
│   │   ├── 186_0025.jpg
│   │   ├── 186_0026.jpg
│   │   ├── 186_0027.jpg
│   │   ├── 186_0028.jpg
│   │   ├── 186_0029.jpg
│   │   ├── 186_0030.jpg
│   │   ├── 186_0031.jpg
│   │   ├── 186_0032.jpg
│   │   ├── 186_0033.jpg
│   │   ├── 186_0034.jpg
│   │   ├── 186_0035.jpg
│   │   ├── 186_0036.jpg
│   │   ├── 186_0037.jpg
│   │   ├── 186_0038.jpg
│   │   ├── 186_0039.jpg
│   │   ├── 186_0040.jpg
│   │   ├── 186_0041.jpg
│   │   ├── 186_0042.jpg
│   │   ├── 186_0043.jpg
│   │   ├── 186_0044.jpg
│   │   ├── 186_0045.jpg
│   │   ├── 186_0046.jpg
│   │   ├── 186_0047.jpg
│   │   ├── 186_0048.jpg
│   │   ├── 186_0049.jpg
│   │   ├── 186_0050.jpg
│   │   ├── 186_0051.jpg
│   │   ├── 186_0052.jpg
│   │   ├── 186_0053.jpg
│   │   ├── 186_0054.jpg
│   │   ├── 186_0055.jpg
│   │   ├── 186_0056.jpg
│   │   ├── 186_0057.jpg
│   │   ├── 186_0058.jpg
│   │   ├── 186_0059.jpg
│   │   └── 186_0060.jpg
│   ├── triceratops
│   │   ├── 228_0001.jpg
│   │   ├── 228_0002.jpg
│   │   ├── 228_0003.jpg
│   │   ├── 228_0004.jpg
│   │   ├── 228_0005.jpg
│   │   ├── 228_0006.jpg
│   │   ├── 228_0007.jpg
│   │   ├── 228_0008.jpg
│   │   ├── 228_0009.jpg
│   │   ├── 228_0010.jpg
│   │   ├── 228_0011.jpg
│   │   ├── 228_0012.jpg
│   │   ├── 228_0013.jpg
│   │   ├── 228_0014.jpg
│   │   ├── 228_0015.jpg
│   │   ├── 228_0016.jpg
│   │   ├── 228_0017.jpg
│   │   ├── 228_0018.jpg
│   │   ├── 228_0019.jpg
│   │   ├── 228_0020.jpg
│   │   ├── 228_0021.jpg
│   │   ├── 228_0022.jpg
│   │   ├── 228_0023.jpg
│   │   ├── 228_0024.jpg
│   │   ├── 228_0025.jpg
│   │   ├── 228_0026.jpg
│   │   ├── 228_0027.jpg
│   │   ├── 228_0028.jpg
│   │   ├── 228_0029.jpg
│   │   ├── 228_0030.jpg
│   │   ├── 228_0031.jpg
│   │   ├── 228_0032.jpg
│   │   ├── 228_0033.jpg
│   │   ├── 228_0034.jpg
│   │   ├── 228_0035.jpg
│   │   ├── 228_0036.jpg
│   │   ├── 228_0037.jpg
│   │   ├── 228_0038.jpg
│   │   ├── 228_0039.jpg
│   │   ├── 228_0040.jpg
│   │   ├── 228_0041.jpg
│   │   ├── 228_0042.jpg
│   │   ├── 228_0043.jpg
│   │   ├── 228_0044.jpg
│   │   ├── 228_0045.jpg
│   │   ├── 228_0046.jpg
│   │   ├── 228_0047.jpg
│   │   ├── 228_0048.jpg
│   │   ├── 228_0049.jpg
│   │   ├── 228_0050.jpg
│   │   ├── 228_0051.jpg
│   │   ├── 228_0052.jpg
│   │   ├── 228_0053.jpg
│   │   ├── 228_0054.jpg
│   │   ├── 228_0055.jpg
│   │   ├── 228_0056.jpg
│   │   ├── 228_0057.jpg
│   │   ├── 228_0058.jpg
│   │   ├── 228_0059.jpg
│   │   └── 228_0060.jpg
│   └── zebra
│       ├── 250_0001.jpg
│       ├── 250_0002.jpg
│       ├── 250_0003.jpg
│       ├── 250_0004.jpg
│       ├── 250_0005.jpg
│       ├── 250_0006.jpg
│       ├── 250_0007.jpg
│       ├── 250_0008.jpg
│       ├── 250_0009.jpg
│       ├── 250_0010.jpg
│       ├── 250_0011.jpg
│       ├── 250_0012.jpg
│       ├── 250_0013.jpg
│       ├── 250_0014.jpg
│       ├── 250_0015.jpg
│       ├── 250_0016.jpg
│       ├── 250_0017.jpg
│       ├── 250_0018.jpg
│       ├── 250_0019.jpg
│       ├── 250_0020.jpg
│       ├── 250_0021.jpg
│       ├── 250_0022.jpg
│       ├── 250_0023.jpg
│       ├── 250_0024.jpg
│       ├── 250_0025.jpg
│       ├── 250_0026.jpg
│       ├── 250_0027.jpg
│       ├── 250_0028.jpg
│       ├── 250_0029.jpg
│       ├── 250_0030.jpg
│       ├── 250_0031.jpg
│       ├── 250_0032.jpg
│       ├── 250_0033.jpg
│       ├── 250_0034.jpg
│       ├── 250_0035.jpg
│       ├── 250_0036.jpg
│       ├── 250_0037.jpg
│       ├── 250_0038.jpg
│       ├── 250_0039.jpg
│       ├── 250_0040.jpg
│       ├── 250_0041.jpg
│       ├── 250_0042.jpg
│       ├── 250_0043.jpg
│       ├── 250_0044.jpg
│       ├── 250_0045.jpg
│       ├── 250_0046.jpg
│       ├── 250_0047.jpg
│       ├── 250_0048.jpg
│       ├── 250_0049.jpg
│       ├── 250_0050.jpg
│       ├── 250_0051.jpg
│       ├── 250_0052.jpg
│       ├── 250_0053.jpg
│       ├── 250_0054.jpg
│       ├── 250_0055.jpg
│       ├── 250_0056.jpg
│       ├── 250_0057.jpg
│       ├── 250_0058.jpg
│       ├── 250_0059.jpg
│       └── 250_0060.jpg
└── valid
    ├── bear
    │   ├── 009_0061.jpg
    │   ├── 009_0062.jpg
    │   ├── 009_0063.jpg
    │   ├── 009_0064.jpg
    │   ├── 009_0065.jpg
    │   ├── 009_0066.jpg
    │   ├── 009_0067.jpg
    │   ├── 009_0068.jpg
    │   ├── 009_0069.jpg
    │   └── 009_0070.jpg
    ├── chimp
    │   ├── 038_0061.jpg
    │   ├── 038_0062.jpg
    │   ├── 038_0063.jpg
    │   ├── 038_0064.jpg
    │   ├── 038_0065.jpg
    │   ├── 038_0066.jpg
    │   ├── 038_0067.jpg
    │   ├── 038_0068.jpg
    │   ├── 038_0069.jpg
    │   └── 038_0070.jpg
    ├── giraffe
    │   ├── 084_0061.jpg
    │   ├── 084_0062.jpg
    │   ├── 084_0063.jpg
    │   ├── 084_0064.jpg
    │   ├── 084_0065.jpg
    │   ├── 084_0066.jpg
    │   ├── 084_0067.jpg
    │   ├── 084_0068.jpg
    │   ├── 084_0069.jpg
    │   └── 084_0070.jpg
    ├── gorilla
    │   ├── 090_0061.jpg
    │   ├── 090_0062.jpg
    │   ├── 090_0063.jpg
    │   ├── 090_0064.jpg
    │   ├── 090_0065.jpg
    │   ├── 090_0066.jpg
    │   ├── 090_0067.jpg
    │   ├── 090_0068.jpg
    │   ├── 090_0069.jpg
    │   └── 090_0070.jpg
    ├── llama
    │   ├── 134_0061.jpg
    │   ├── 134_0062.jpg
    │   ├── 134_0063.jpg
    │   ├── 134_0064.jpg
    │   ├── 134_0065.jpg
    │   ├── 134_0066.jpg
    │   ├── 134_0067.jpg
    │   ├── 134_0068.jpg
    │   ├── 134_0069.jpg
    │   └── 134_0070.jpg
    ├── ostrich
    │   ├── 151_0061.jpg
    │   ├── 151_0062.jpg
    │   ├── 151_0063.jpg
    │   ├── 151_0064.jpg
    │   ├── 151_0065.jpg
    │   ├── 151_0066.jpg
    │   ├── 151_0067.jpg
    │   ├── 151_0068.jpg
    │   ├── 151_0069.jpg
    │   └── 151_0070.jpg
    ├── porcupine
    │   ├── 164_0061.jpg
    │   ├── 164_0062.jpg
    │   ├── 164_0063.jpg
    │   ├── 164_0064.jpg
    │   ├── 164_0065.jpg
    │   ├── 164_0066.jpg
    │   ├── 164_0067.jpg
    │   ├── 164_0068.jpg
    │   ├── 164_0069.jpg
    │   └── 164_0070.jpg
    ├── skunk
    │   ├── 186_0061.jpg
    │   ├── 186_0062.jpg
    │   ├── 186_0063.jpg
    │   ├── 186_0064.jpg
    │   ├── 186_0065.jpg
    │   ├── 186_0066.jpg
    │   ├── 186_0067.jpg
    │   ├── 186_0068.jpg
    │   ├── 186_0069.jpg
    │   └── 186_0070.jpg
    ├── triceratops
    │   ├── 228_0061.jpg
    │   ├── 228_0062.jpg
    │   ├── 228_0063.jpg
    │   ├── 228_0064.jpg
    │   ├── 228_0065.jpg
    │   ├── 228_0066.jpg
    │   ├── 228_0067.jpg
    │   ├── 228_0068.jpg
    │   ├── 228_0069.jpg
    │   └── 228_0070.jpg
    └── zebra
        ├── 250_0061.jpg
        ├── 250_0062.jpg
        ├── 250_0063.jpg
        ├── 250_0064.jpg
        ├── 250_0065.jpg
        ├── 250_0066.jpg
        ├── 250_0067.jpg
        ├── 250_0068.jpg
        ├── 250_0069.jpg
        └── 250_0070.jpg

2、數據擴充

可用的訓練集中的圖像可以通過多種方式進行修改,以在訓練過程中包含更多的變種圖像,從而使訓練後的模型更具通用性,也就是防止過擬合,並在不同類型的測試數據上表現良好。此外,輸入數據可以有各種大小。它們需要標準化爲一個固定的大小和格式,然後組織成批量數據一起用於訓練。 每個輸入圖像首先通過若干變換。我們試圖通過在變換中引入一些隨機性來使圖像發生變化。在每個輪次,每個圖像都應用一組變換。當我們重複訓練多個輪次時,模型會看到更多的輸入圖像的變化,每個訓練輪次中,轉換都有一個新的隨機變化。這導致數據增強,然後模型變得更加通用。 下面我們看到三角龍(Triceratops)圖像的轉換版本的例子。

Data Augmentation

讓我們解析一下用於數據增強擴張的轉換。

RandomResizedCrop以隨機大小(在原始大小的0.8到1.0的尺度範圍內,在默認範圍爲0.75到1.33的隨機縱橫比範圍內)對輸入圖像進行裁剪。然後將圖像尺寸調整爲256×256。

RandomRotation以隨機選擇的角度旋轉圖像(在-15到15度之間)。

隨機進行圖像水平翻轉,默認概率爲50%。

CenterCrop從中心扣出一個224×224圖像。

ToTensor 將PIL圖像的值(0-255)轉換爲浮點張量,並通過除以255將它們歸一化爲0-1。

正則化採用3通道張量,並通過通道的輸入均值和標準差對每個通道進行正則化。均值和標準差向量作爲3個元素向量輸入。張量中的每個通道正則化化爲T=(T-均值)/(標準差) 。所有上述轉換都使用Compose鏈接在一起。

import torch, torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import time
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image

# Applying Transforms to the Data
image_transforms = { 
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
}

請注意,對於驗證和測試數據,我們不進行RandomResizedCropRandomRotation 和 RandomHorizontalFlip轉換。我們只需將驗證圖像和測試圖像調整到256×256,並裁剪出中心224×224部分,以便能夠使模型正確使用它們。然後將圖像轉換爲張量,並使用ImageNet中所有圖像的均值和標準差來進行正則化。

3、數據加載

接下來,讓我們看看如何使用上述定義的轉換並加載用於訓練的數據。

# Load the Data

# Set train and valid directory paths

dataset = 'caltech_10'

train_directory = os.path.join(dataset, 'train')
valid_directory = os.path.join(dataset, 'valid')
test_directory = os.path.join(dataset, 'test')

# Batch size
bs = 32

# Number of classes
num_classes = len(os.listdir(valid_directory))  #10
print(num_classes)

# Load Data from folders
data = {
    'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),
    'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid']),
    'test': datasets.ImageFolder(root=test_directory, transform=image_transforms['test'])
}

# Get a mapping of the indices to the class names, in order to see the output classes of the test images.
idx_to_class = {v: k for k, v in data['train'].class_to_idx.items()}
print(idx_to_class)

# Size of Data, to be used for calculating Average Loss and Accuracy
train_data_size = len(data['train'])
valid_data_size = len(data['valid'])
test_data_size = len(data['test'])

# Create iterators for the Data loaded using DataLoader module
train_data_loader = DataLoader(data['train'], batch_size=bs, shuffle=True)
valid_data_loader = DataLoader(data['valid'], batch_size=bs, shuffle=True)
test_data_loader = DataLoader(data['test'], batch_size=bs, shuffle=True)

我們首先設置訓練、驗證和測試數據目錄,以及批處理大小(32)。然後我們使用DataLoader加載它們。請注意,我們使用DataLoader加載圖像時會進行前面指定的圖像轉換。數據的順序也被打亂重新排序。torchvision.transforms包和DataLoader是非常重要的PyTorch功能,使數據增強擴張和加載過程非常容易。

4、遷移學習

收集感興趣領域的圖像並從頭訓練分類器是非常困難和耗時的。因此,我們使用預先訓練的模型作爲我們的基礎,並改變最後幾層,以便我們可以根據想要的類對圖像進行分類。即使使用一個小的數據集,我們也能獲得一個良好的模型。因爲基本的圖像特徵已經在預先訓練的模型被獲取了,而且是從一個更大的數據集中獲取的,如ImageNet數據集。

Transfer Learning for Image Classification

正如我們在上面的圖像中所看到的,預先訓練的模型的內層保持不變,只有最後幾層被改變以適應我們的類數。在本工作中,我們使用預先訓練的ResNet50模型。

# Load pretrained ResNet50 Model
resnet50 = models.resnet50(pretrained=True)
# resnet50 = resnet50.to('cuda:0')

ResNet50是在準確性和推理時間之間有很好的權衡的一種模型。當模型在Py Torch中加載時,其所有參數的“requid_grad”字段默認設置爲true。這意味着參數值的每一個變化都將被存儲在用於訓練的反向傳播圖中。這增加了內存需求。因此,由於我們預先訓練的模型中的大多數參數已經訓練好了,我們將require_grad字段重置爲false。

# Freeze model parameters
for param in resnet50.parameters():
    param.requires_grad = False

然後,我們將ResNet50模型的最後一層替換爲一組小的順序層。對ResNet50的最後一個完全連接層的輸入被饋送到具有256個輸出的線性層,然後這些輸出被饋入ReLU激活函數和Dropout層。然後是一個256×10線性層,它有10個輸出對應於我們的CalTech子集中的10個類。NLLLoss和softmax往往配合使用,參考:https://blog.csdn.net/jasonleesjtu/article/details/89097758

# Change the final layer of ResNet50 Model for Transfer Learning
fc_inputs = resnet50.fc.in_features
 
resnet50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, 10), 
    nn.LogSoftmax(dim=1) # For using NLLLoss()
# Convert model to be used on GPU
# resnet50 = resnet50.to('cuda:0')

如果你的實驗環境支持GPU,可以將代碼中有關gpu的部分去掉,便可以在gpu上進行訓練。

接下來,我們定義了損失函數和用於訓練的優化器。PyTorch提供多種損失函數。我們使用負對數似然函數,因爲它可以用來分類多個類。PyTorch還支持多個優化器。我們用Adam優化器。Adam是最受歡迎的優化器之一,因爲它可以單獨調整每個參數的學習速率。

# Define Optimizer and Loss Function
loss_func = nn.NLLLoss()
optimizer = optim.Adam(resnet50.parameters())

5、訓練

訓練過程按輪次進行,重複進行指定的次數,每張圖像在一個輪次中處理一次。訓練數據加載器批量加載數據,本例中,我們指定批次大小爲32,這意味着每個批次最多可以有32個圖像。通常,一個輪次包含多個批次。 對於每一批,輸入圖像數據通過模型,即向前傳遞,以獲得輸出。然後使用所提供的損失函數使用真實類型和輸出結果類型來計算損失。參考計算出的損失,利用後向函數可訓練參數的梯度。請注意,在轉移學習中,我們只需要計算一小組參數的梯度,這些參數屬於模型末尾的幾個新添加的層。對模型進行彙總,可以揭示實際參數個數和可訓練參數個數,正如我們在下面看到的,我們現在只需要訓練大約十分之一的模型參數。

Parameter count summary

梯度計算是使用自動梯度和反向傳播,在圖中微分使用鏈式規則。PyTorch在向後傳播中積累了所有的梯度。因此,在訓練循環開始時,必須將它們清零。這是使用優化器的zero_grad函數實現的。最後,在向後傳遞中計算梯度後,使用優化器的step函數更新參數。 計算每個批次的總損失和準確度,然後對所有批次進行平均,得到整個輪次的損失和準確度值。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def train_and_validate(model, loss_criterion, optimizer, epochs=25):
    '''
    Function to train and validate
    Parameters
        :param model: Model to train and validate
        :param loss_criterion: Loss Criterion to minimize
        :param optimizer: Optimizer for computing gradients
        :param epochs: Number of epochs (default=25)
  
    Returns
        model: Trained Model with best validation accuracy
        history: (dict object): Having training loss, accuracy and validation loss, accuracy
    '''
    
    start = time.time()
    history = []
    best_acc = 0.0

    for epoch in range(epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch+1, epochs))
        
        # Set to training mode
        model.train()
        
        # Loss and Accuracy within the epoch
        train_loss = 0.0
        train_acc = 0.0
        
        valid_loss = 0.0
        valid_acc = 0.0
        
        for i, (inputs, labels) in enumerate(train_data_loader):

            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Clean existing gradients
            optimizer.zero_grad()
            
            # Forward pass - compute outputs on input data using the model
            outputs = model(inputs)
            
            # Compute loss
            loss = loss_criterion(outputs, labels)
            
            # Backpropagate the gradients
            loss.backward()
            
            # Update the parameters
            optimizer.step()
            
            # Compute the total loss for the batch and add it to train_loss
            train_loss += loss.item() * inputs.size(0)
            
            # Compute the accuracy
            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))
            
            # Convert correct_counts to float and then compute the mean
            acc = torch.mean(correct_counts.type(torch.FloatTensor))
            
            # Compute total accuracy in the whole batch and add to train_acc
            train_acc += acc.item() * inputs.size(0)
            
            #print("Batch number: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}".format(i, loss.item(), acc.item()))

            
        # Validation - No gradient tracking needed
        with torch.no_grad():

            # Set to evaluation mode
            model.eval()

            # Validation loop
            for j, (inputs, labels) in enumerate(valid_data_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass - compute outputs on input data using the model
                outputs = model(inputs)

                # Compute loss
                loss = loss_criterion(outputs, labels)

                # Compute the total loss for the batch and add it to valid_loss
                valid_loss += loss.item() * inputs.size(0)

                # Calculate validation accuracy
                ret, predictions = torch.max(outputs.data, 1)
                correct_counts = predictions.eq(labels.data.view_as(predictions))

                # Convert correct_counts to float and then compute the mean
                acc = torch.mean(correct_counts.type(torch.FloatTensor))

                # Compute total accuracy in the whole batch and add to valid_acc
                valid_acc += acc.item() * inputs.size(0)

                #print("Validation Batch number: {:03d}, Validation: Loss: {:.4f}, Accuracy: {:.4f}".format(j, loss.item(), acc.item()))
            
        # Find average training loss and training accuracy
        avg_train_loss = train_loss/train_data_size 
        avg_train_acc = train_acc/train_data_size

        # Find average training loss and training accuracy
        avg_valid_loss = valid_loss/valid_data_size 
        avg_valid_acc = valid_acc/valid_data_size

        history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])
                
        epoch_end = time.time()
    
        print("Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(epoch, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))
        
        # Save if the model has best accuracy till now
        torch.save(model, dataset+'_model_'+str(epoch)+'.pt')
            
    return model, history
# Print the model to be trained
# summary(resnet50, input_size=(3, 224, 224), batch_size=bs, device='cuda')

# Train the model for 30 epochs
num_epochs = 30
trained_model, history = train_and_validate(resnet50, loss_func, optimizer, num_epochs)

torch.save(history, dataset+'_history.pt')

train_loss += loss.item() * inputs.size(0)中,iputs.size(0)返回32,即每批次有32張圖片,最後一批次可能不到32張。

ret, predictions = torch.max(outputs.data, 1)
correct_counts = predictions.eq(labels.data.view_as(predictions)) 中,predictions是一維列表,順序包含每張圖片的預測類型標號,correct_counts爲一維列表,只包含0或者1,0代表預測正確,1代表預測錯誤。

history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc]),history爲二維數組,包含每輪次的訓練統計結果,爲畫圖標準備數據。

本例的實驗環境爲ubuntu18.04,使用了i5中的4核,沒有使用gpu。輸出結果如下:

Epoch: 1/30
Epoch : 000, Training: Loss: 1.5882, Accuracy: 48.5000%, 
		Validation : Loss : 0.6234, Accuracy: 86.0000%, Time: 496.7489s
Epoch: 2/30
Epoch : 001, Training: Loss: 0.6274, Accuracy: 84.8333%, 
		Validation : Loss : 0.2899, Accuracy: 96.0000%, Time: 143.8269s
Epoch: 3/30
Epoch : 002, Training: Loss: 0.3289, Accuracy: 93.6667%, 
		Validation : Loss : 0.2095, Accuracy: 97.0000%, Time: 114.8437s
Epoch: 4/30
Epoch : 003, Training: Loss: 0.2767, Accuracy: 92.3333%, 
		Validation : Loss : 0.2141, Accuracy: 94.0000%, Time: 115.5789s
Epoch: 5/30
Epoch : 004, Training: Loss: 0.2121, Accuracy: 94.0000%, 
		Validation : Loss : 0.1396, Accuracy: 96.0000%, Time: 104.1761s
Epoch: 6/30
Epoch : 005, Training: Loss: 0.2227, Accuracy: 93.6667%, 
		Validation : Loss : 0.1570, Accuracy: 96.0000%, Time: 104.7612s
Epoch: 7/30
Epoch : 006, Training: Loss: 0.1922, Accuracy: 93.8333%, 
		Validation : Loss : 0.1750, Accuracy: 96.0000%, Time: 116.5069s
Epoch: 8/30
Epoch : 007, Training: Loss: 0.1541, Accuracy: 95.8333%, 
		Validation : Loss : 0.1555, Accuracy: 96.0000%, Time: 103.5355s
Epoch: 9/30
Epoch : 008, Training: Loss: 0.1387, Accuracy: 95.8333%, 
		Validation : Loss : 0.1756, Accuracy: 94.0000%, Time: 106.4436s
Epoch: 10/30
Epoch : 009, Training: Loss: 0.1400, Accuracy: 96.6667%, 
		Validation : Loss : 0.1248, Accuracy: 97.0000%, Time: 102.8343s
Epoch: 11/30
Epoch : 010, Training: Loss: 0.1331, Accuracy: 96.1667%, 
		Validation : Loss : 0.1316, Accuracy: 97.0000%, Time: 102.7859s
Epoch: 12/30
Epoch : 011, Training: Loss: 0.0964, Accuracy: 97.1667%, 
		Validation : Loss : 0.1461, Accuracy: 96.0000%, Time: 101.7347s
Epoch: 13/30
Epoch : 012, Training: Loss: 0.0850, Accuracy: 98.0000%, 
		Validation : Loss : 0.1296, Accuracy: 97.0000%, Time: 103.7292s
Epoch: 14/30
Epoch : 013, Training: Loss: 0.1283, Accuracy: 95.1667%, 
		Validation : Loss : 0.1328, Accuracy: 96.0000%, Time: 101.8993s
Epoch: 15/30
Epoch : 014, Training: Loss: 0.1033, Accuracy: 96.8333%, 
		Validation : Loss : 0.2060, Accuracy: 93.0000%, Time: 101.6368s
Epoch: 16/30
Epoch : 015, Training: Loss: 0.0890, Accuracy: 97.6667%, 
		Validation : Loss : 0.1072, Accuracy: 98.0000%, Time: 101.3270s
Epoch: 17/30
Epoch : 016, Training: Loss: 0.0949, Accuracy: 97.3333%, 
		Validation : Loss : 0.1981, Accuracy: 92.0000%, Time: 102.3606s
Epoch: 18/30
Epoch : 017, Training: Loss: 0.1004, Accuracy: 96.1667%, 
		Validation : Loss : 0.1445, Accuracy: 93.0000%, Time: 102.5671s
Epoch: 19/30
Epoch : 018, Training: Loss: 0.1029, Accuracy: 96.8333%, 
		Validation : Loss : 0.1777, Accuracy: 93.0000%, Time: 102.3143s
Epoch: 20/30
Epoch : 019, Training: Loss: 0.0813, Accuracy: 97.1667%, 
		Validation : Loss : 0.1403, Accuracy: 95.0000%, Time: 102.7968s
Epoch: 21/30
Epoch : 020, Training: Loss: 0.0581, Accuracy: 98.5000%, 
		Validation : Loss : 0.1375, Accuracy: 95.0000%, Time: 101.5044s
Epoch: 22/30
Epoch : 021, Training: Loss: 0.0636, Accuracy: 98.3333%, 
		Validation : Loss : 0.1043, Accuracy: 96.0000%, Time: 101.3938s
Epoch: 23/30
Epoch : 022, Training: Loss: 0.0822, Accuracy: 97.5000%, 
		Validation : Loss : 0.1587, Accuracy: 95.0000%, Time: 102.4387s
Epoch: 24/30
Epoch : 023, Training: Loss: 0.0650, Accuracy: 97.6667%, 
		Validation : Loss : 0.1398, Accuracy: 95.0000%, Time: 100.7330s
Epoch: 25/30
Epoch : 024, Training: Loss: 0.0852, Accuracy: 97.0000%, 
		Validation : Loss : 0.1054, Accuracy: 96.0000%, Time: 101.2396s
Epoch: 26/30
Epoch : 025, Training: Loss: 0.0894, Accuracy: 97.3333%, 
		Validation : Loss : 0.1208, Accuracy: 97.0000%, Time: 101.9070s
Epoch: 27/30
Epoch : 026, Training: Loss: 0.0640, Accuracy: 98.0000%, 
		Validation : Loss : 0.1435, Accuracy: 96.0000%, Time: 101.8337s
Epoch: 28/30
Epoch : 027, Training: Loss: 0.0690, Accuracy: 97.3333%, 
		Validation : Loss : 0.1145, Accuracy: 97.0000%, Time: 101.6800s
Epoch: 29/30
Epoch : 028, Training: Loss: 0.0428, Accuracy: 98.8333%, 
		Validation : Loss : 0.1623, Accuracy: 94.0000%, Time: 111.7447s
Epoch: 30/30
Epoch : 029, Training: Loss: 0.0361, Accuracy: 99.3333%, 
		Validation : Loss : 0.1487, Accuracy: 96.0000%, Time: 102.9441s

6、 驗證

隨着訓練的進行,該模型往往過度擬合數據,導致其在新的測試數據上的性能差。維護一個單獨的驗證集是很重要的,這樣我們就可以在正確的點停止訓練,並防止過度擬合。在每個輪次中,驗證在訓練循環之後立即進行。由於我們在驗證過程中不需要任何梯度計算,所以它是在torch.no_grad()塊中完成的。 對於每個驗證批次,輸入和標籤被轉移到GPU(如果Cuda可用,否則爲CPU,本例中使用cpu,實際不發生轉移)。輸入經過前向傳播,然後對批處理的損失和精度進行計算,循環結束時爲整個輪次計算損失和精度。關於以下代碼中用到的列表的分片和matplotlib繪圖操作,請自行百度。

history = np.array(history)
plt.plot(history[:,0:2])
plt.legend(['Tr Loss', 'Val Loss'])
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.ylim(0,1)
plt.savefig(dataset+'_loss_curve.png')
plt.show()
plt.plot(history[:,2:4])
plt.legend(['Tr Accuracy', 'Val Accuracy'])
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy')
plt.ylim(0,1)
plt.savefig(dataset+'_accuracy_curve.png')
plt.show()

Loss curve for training and validation

訓練和驗證的損失曲線

Accuracy curve for training and validation

訓練和驗證的精度曲線

正如我們在上面的圖表中所看到的,對於這個數據集,驗證和訓練損失都很快就降低了。精度提高到0.9的水平也非常快。隨着訓練輪次的增加,訓練損失進一步減少,但驗證結果沒有很大提高,出現了過度擬合。因此,我們從具有較高精度和較低損耗的時刻選擇了模型。如果我們早點停下來,也更好,可以防止過度擬合訓練數據。在我們的例子中,我們選擇了具有96%的驗證精度的第8批次結果。 提取停止訓練也可以自動化。一旦損失低於給定的閾值,並且經過某些輪次的訓練,驗證精度沒有提高,我們就可以停止。

7、推理

一旦我們有了模型,我們就可以對單個測試圖像進行推理,或者在整個測試數據集上進行推理,以獲得測試精度。測試集精度計算與驗證代碼相似,但在測試數據集上進行。讓我們看看下面如何找到給定測試圖像的輸出類。 輸入圖像首先經過用於驗證/測試數據的所有轉換。然後將得到的張量轉換爲4維張量送給模型,輸出不同類的對數概率。模型輸出經進一步指數運算爲我們提供了每個類的概率,然後我們選擇概率最高的類作爲我們的輸出類。

def computeTestSetAccuracy(model, loss_criterion):
    '''
    Function to compute the accuracy on the test set
    Parameters
        :param model: Model to test
        :param loss_criterion: Loss Criterion to minimize
    '''

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    test_acc = 0.0
    test_loss = 0.0

    # Validation - No gradient tracking needed
    with torch.no_grad():

        # Set to evaluation mode
        model.eval()

        # Validation loop
        for j, (inputs, labels) in enumerate(test_data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass - compute outputs on input data using the model
            outputs = model(inputs)

            # Compute loss
            loss = loss_criterion(outputs, labels)

            # Compute the total loss for the batch and add it to valid_loss
            test_loss += loss.item() * inputs.size(0)

            # Calculate validation accuracy
            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))

            # Convert correct_counts to float and then compute the mean
            acc = torch.mean(correct_counts.type(torch.FloatTensor))

            # Compute total accuracy in the whole batch and add to valid_acc
            test_acc += acc.item() * inputs.size(0)

            print("Test Batch number: {:03d}, Test: Loss: {:.4f}, Accuracy: {:.4f}".format(j, loss.item(), acc.item()))

    # Find average test loss and test accuracy
    avg_test_loss = test_loss/test_data_size 
    avg_test_acc = test_acc/test_data_size

    print("Test accuracy : " + str(avg_test_acc))

結果如下:

Test Batch number: 000, Test: Loss: 0.1056, Accuracy: 0.9688
Test Batch number: 001, Test: Loss: 0.1419, Accuracy: 0.9375
Test Batch number: 002, Test: Loss: 0.1524, Accuracy: 0.9688
Test Batch number: 003, Test: Loss: 0.0064, Accuracy: 1.0000
Test Batch number: 004, Test: Loss: 0.1434, Accuracy: 0.9375
Test Batch number: 005, Test: Loss: 0.0590, Accuracy: 1.0000
Test Batch number: 006, Test: Loss: 0.3355, Accuracy: 0.8750
Test Batch number: 007, Test: Loss: 0.2506, Accuracy: 0.9375
Test Batch number: 008, Test: Loss: 0.2860, Accuracy: 0.9062
Test Batch number: 009, Test: Loss: 0.1077, Accuracy: 0.9375
Test Batch number: 010, Test: Loss: 0.1335, Accuracy: 0.9375
Test Batch number: 011, Test: Loss: 0.2154, Accuracy: 0.9688
Test Batch number: 012, Test: Loss: 0.3630, Accuracy: 0.8800
Test accuracy : 0.9437652808821289Test loss : 0.17376016186035642

 在409幅圖像的測試集上達到了94.3%的精度。

接下來對每一張圖片單獨進行推理。

def predict(model, test_image_name):
    '''
    Function to predict the class of a single test image
    Parameters
        :param model: Model to test
        :param test_image_name: Test image

    '''
    
    transform = image_transforms['test']

    test_image = Image.open(test_image_name)
    plt.imshow(test_image)
    
    test_image_tensor = transform(test_image)

    if torch.cuda.is_available():
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224).cuda()
    else:
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224)
    
    with torch.no_grad():
        model.eval()
        # Model outputs log probabilities
        start = time.time()
        out = model(test_image_tensor)
        stop = time.time()
		print('cost time', stop-start)
        ps = torch.exp(out)
        topk, topclass = ps.topk(3, dim=1)
        for i in range(3):
            print("Predcition", i+1, ":", idx_to_class[topclass.cpu().numpy()[0][i]], ", Score: ", topk.cpu().numpy()[0][i])
model = torch.load('caltech_10_model_8.pt')
predict(model, 'caltech_10/test/zebra/250_0091.jpg')

結果如下:

>>> predict(model, 'caltech_10/test/zebra/250_0091.jpg')
cost time 0.5529406070709229
Predcition 1 : zebra , Score:  0.9980318
Predcition 2 : triceratops , Score:  0.0012739539
Predcition 3 : giraffe , Score:  0.00028341665

你可以多做一些實驗,你會發現,概率最高的類通常是正確的。還要注意的是,概率第二高的類在外觀上往往是最接近實際類的動物。

本篇實驗參考連接:https://www.learnopencv.com/image-classification-using-transfer-learning-in-pytorch/

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章