接上篇:PyTorch框架實戰系列(1)——CNN圖像分類器
對PyTorch教程圖像分類器進行優化:(不涉及GPU訓練,所以沒寫可GPU訓練的代碼)
1、CNN(卷積神經網絡)增加了網絡深度,卷積層逐層對特徵進行提取,從微小特徵總結爲較大特徵,最後由全連接層進行仿射變換。以下是模型架構:
Model(
(block): Sequential(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): ReLU()
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU()
(8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(9): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU()
(12): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classer): Sequential(
(0): Dropout(p=0.5, inplace=False)
(1): Linear(in_features=512, out_features=10, bias=True)
)
)
2、對訓練過程進行優化,增加訓練次數,並且每在一定數量樣本訓練後,將模型在測試集上進行驗證,當測試集損失值不再收斂時自動停止訓練。
3、增加模型測試評估報告,包括準確率、混淆矩陣和各類別的精確度、召回率及F1評分。
CNN.py
import torch.nn as nn
class Model(nn.Module):
    """CNN classifier for 32x32 RGB images (e.g. CIFAR-10).

    Four conv/pool stages shrink the spatial size 32 -> 2 while growing the
    channel count 3 -> 128; a dropout + linear head maps the 512 flattened
    features to 10 class scores.
    """

    def __init__(self):
        super(Model, self).__init__()
        # Feature extractor: each stage is conv(3x3, pad 1) -> ReLU -> 2x2 max-pool;
        # the last stage adds batch normalisation before the activation.
        self.block = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Classification head: dropout for regularisation, then an affine map
        # from the 128*2*2 = 512 flattened features to the 10 class logits.
        self.classer = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features=512, out_features=10),
        )

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, 10) class logits."""
        features = self.block(x)                    # (N, 128, 2, 2)
        flat = features.view(features.size(0), -1)  # (N, 512)
        return self.classer(flat)                   # (N, 10)
if __name__ == '__main__':
    # Smoke test: push one random CIFAR-sized batch through the model and
    # print the architecture plus the raw logits.
    import torch

    # torch.autograd.Variable has been deprecated since PyTorch 0.4 —
    # plain tensors carry autograd state themselves, so no wrapper is needed.
    x = torch.rand(1, 3, 32, 32)
    model = Model()
    print(model)
    y = model(x)
    print(y)
train.py
# -*- coding: utf-8 -*-
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from CNN import Model
import torch
import numpy as np
from sklearn import metrics
def train(save_path, model, trainloader, testloader):
    """Train *model* with Adam (lr=0.001), early-stopping on validation loss.

    Every 100 batches the model is evaluated on *testloader*; whenever the
    validation loss reaches a new minimum the weights are checkpointed to
    *save_path*.  Training runs for at most 20 epochs and stops early once
    500 batches pass without a validation-loss improvement.  Finally the
    best checkpoint is scored via model_test().
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    total_batch = 0               # number of mini-batches processed so far
    dev_best_loss = float('inf')  # lowest validation loss seen
    last_improve = 0              # batch index of the last improvement
    flag = False                  # becomes True when early stopping fires

    for epoch in range(20):
        print('Epoch [{}/{}]'.format(epoch + 1, 20))
        for trains, labels in trainloader:
            outputs = model(trains)
            model.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Periodically check progress on the validation set.
            if total_batch % 100 == 0:
                # Accuracy on the current training batch.
                true = labels.data
                predic = torch.max(outputs.data, 1)[1]
                train_acc = metrics.accuracy_score(true, predic)
                # Validate; checkpoint whenever the loss improves.
                dev_acc, dev_loss = evaluate(model, testloader)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%} {5}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, improve))
                # evaluate() switched the model to eval mode; switch back.
                model.train()

            total_batch += 1
            # Early stop: no validation improvement for more than 500 batches.
            if total_batch - last_improve > 500:
                print("Finished Training...")
                flag = True
                break
        if flag:
            break
    # Final evaluation of the best checkpoint on the test set.
    model_test(save_path, model, testloader)
def evaluate(model, dataloader, test=False):
    """Score *model* over *dataloader*.

    Returns (accuracy, mean_batch_loss); with test=True additionally returns
    a per-class classification report and the confusion matrix, using the
    CIFAR-10 class names.  The model is left in eval mode.
    """
    class_list = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    model.eval()
    loss_total = 0.0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    loss_func = nn.CrossEntropyLoss()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for texts, labels in dataloader:
            outputs = model(texts)
            loss = loss_func(outputs, labels)
            # .item() extracts a Python float, so we accumulate numbers rather
            # than 0-dim tensors and return a plain float loss (robust for the
            # callers' comparisons and string formatting).
            loss_total += loss.item()
            labels = labels.data.numpy()
            predic = torch.max(outputs.data, 1)[1].numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)
    # Overall accuracy across every sample seen.
    acc = metrics.accuracy_score(labels_all, predict_all)
    mean_loss = loss_total / len(dataloader)
    if test:
        # Full evaluation report: per-class precision/recall/F1 + confusion matrix.
        report = metrics.classification_report(labels_all, predict_all, target_names=class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, mean_loss, report, confusion
    else:
        return acc, mean_loss
def model_test(save_path, model, testloader):
    """Load the checkpoint at *save_path* and print the final test metrics,
    confusion matrix and per-class evaluation report."""
    # Restore the best weights saved during training.
    model.load_state_dict(torch.load(save_path))
    model.eval()
    # Full evaluation on the test set.
    test_acc, test_loss, test_report, test_confusion = evaluate(model, testloader, test=True)
    print('Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'.format(test_loss, test_acc))
    print("混淆矩陣...")
    print(test_confusion)
    print("評估報告...")
    print(test_report)
if __name__ == '__main__':
    data_path = './data'           # dataset root directory
    save_path = './cnn_model.pth'  # checkpoint file path
    # Convert images to tensors and normalise each channel to [-1, 1]:
    # input[channel] = (input[channel] - mean[channel]) / std[channel]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # Download CIFAR-10 (a no-op when already cached under data_path).
    trainset = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform)
    testset = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform)
    # Report dataset sizes.
    print('trainset', len(trainset))
    print('testset', len(testset))
    batch_size = 100  # mini-batch size
    # Build the data iterators.  Only the training set is shuffled: the test
    # set is purely evaluated, so shuffling it has no effect on the metrics.
    trainloader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(dataset=testset, batch_size=batch_size, shuffle=False)
    # Show the tensor shapes of one mini-batch.
    for sample, label in trainloader:
        print(sample.size(), label.size())
        break
    # Instantiate and train the model.
    model = Model()
    train(save_path, model, trainloader, testloader)
輸出結果如下:
Files already downloaded and verified
Files already downloaded and verified
trainset 50000
testset 10000
torch.Size([100, 3, 32, 32]) torch.Size([100])
Epoch [1/20]
Iter: 0, Train Loss: 2.5, Train Acc: 14.00%, Val Loss: 2.3, Val Acc: 13.20% *
Iter: 100, Train Loss: 1.5, Train Acc: 42.00%, Val Loss: 1.5, Val Acc: 44.82% *
Iter: 200, Train Loss: 1.4, Train Acc: 47.00%, Val Loss: 1.3, Val Acc: 52.96% *
Iter: 300, Train Loss: 1.2, Train Acc: 57.00%, Val Loss: 1.3, Val Acc: 53.33%
Iter: 400, Train Loss: 1.1, Train Acc: 61.00%, Val Loss: 1.1, Val Acc: 60.16% *
Epoch [2/20]
Iter: 500, Train Loss: 0.95, Train Acc: 60.00%, Val Loss: 1.0, Val Acc: 63.07% *
Iter: 600, Train Loss: 1.0, Train Acc: 63.00%, Val Loss: 1.0, Val Acc: 62.91%
Iter: 700, Train Loss: 0.77, Train Acc: 69.00%, Val Loss: 1.0, Val Acc: 64.62% *
Iter: 800, Train Loss: 0.94, Train Acc: 63.00%, Val Loss: 1.1, Val Acc: 62.55%
Iter: 900, Train Loss: 0.93, Train Acc: 63.00%, Val Loss: 0.96, Val Acc: 65.69% *
Epoch [3/20]
Iter: 1000, Train Loss: 0.83, Train Acc: 68.00%, Val Loss: 0.97, Val Acc: 65.84%
Iter: 1100, Train Loss: 0.77, Train Acc: 71.00%, Val Loss: 0.94, Val Acc: 67.48% *
Iter: 1200, Train Loss: 1.0, Train Acc: 65.00%, Val Loss: 0.86, Val Acc: 69.97% *
Iter: 1300, Train Loss: 0.77, Train Acc: 74.00%, Val Loss: 0.86, Val Acc: 70.56%
Iter: 1400, Train Loss: 0.91, Train Acc: 69.00%, Val Loss: 0.87, Val Acc: 69.61%
Epoch [4/20]
Iter: 1500, Train Loss: 0.56, Train Acc: 79.00%, Val Loss: 0.85, Val Acc: 70.40% *
Iter: 1600, Train Loss: 0.62, Train Acc: 74.00%, Val Loss: 0.76, Val Acc: 73.41% *
Iter: 1700, Train Loss: 0.7, Train Acc: 78.00%, Val Loss: 0.78, Val Acc: 73.04%
Iter: 1800, Train Loss: 0.75, Train Acc: 80.00%, Val Loss: 0.78, Val Acc: 72.59%
Iter: 1900, Train Loss: 0.89, Train Acc: 63.00%, Val Loss: 0.77, Val Acc: 72.92%
Epoch [5/20]
Iter: 2000, Train Loss: 0.58, Train Acc: 81.00%, Val Loss: 0.75, Val Acc: 73.93% *
Iter: 2100, Train Loss: 0.7, Train Acc: 75.00%, Val Loss: 0.77, Val Acc: 73.61%
Iter: 2200, Train Loss: 0.85, Train Acc: 66.00%, Val Loss: 0.8, Val Acc: 72.10%
Iter: 2300, Train Loss: 0.67, Train Acc: 78.00%, Val Loss: 0.74, Val Acc: 74.53% *
Iter: 2400, Train Loss: 0.86, Train Acc: 76.00%, Val Loss: 0.75, Val Acc: 73.82%
Epoch [6/20]
Iter: 2500, Train Loss: 0.78, Train Acc: 72.00%, Val Loss: 0.8, Val Acc: 72.36%
Iter: 2600, Train Loss: 0.65, Train Acc: 76.00%, Val Loss: 0.75, Val Acc: 74.33%
Iter: 2700, Train Loss: 0.66, Train Acc: 81.00%, Val Loss: 0.74, Val Acc: 74.53%
Iter: 2800, Train Loss: 0.66, Train Acc: 75.00%, Val Loss: 0.72, Val Acc: 75.29% *
Iter: 2900, Train Loss: 0.75, Train Acc: 74.00%, Val Loss: 0.87, Val Acc: 70.47%
Epoch [7/20]
Iter: 3000, Train Loss: 0.6, Train Acc: 77.00%, Val Loss: 0.69, Val Acc: 75.98% *
Iter: 3100, Train Loss: 0.57, Train Acc: 83.00%, Val Loss: 0.7, Val Acc: 76.36%
Iter: 3200, Train Loss: 0.54, Train Acc: 78.00%, Val Loss: 0.72, Val Acc: 75.94%
Iter: 3300, Train Loss: 0.55, Train Acc: 81.00%, Val Loss: 0.67, Val Acc: 77.06% *
Iter: 3400, Train Loss: 0.58, Train Acc: 76.00%, Val Loss: 0.7, Val Acc: 75.97%
Epoch [8/20]
Iter: 3500, Train Loss: 0.47, Train Acc: 85.00%, Val Loss: 0.69, Val Acc: 76.08%
Iter: 3600, Train Loss: 0.69, Train Acc: 78.00%, Val Loss: 0.74, Val Acc: 74.58%
Iter: 3700, Train Loss: 0.68, Train Acc: 83.00%, Val Loss: 0.73, Val Acc: 75.18%
Iter: 3800, Train Loss: 0.88, Train Acc: 70.00%, Val Loss: 0.69, Val Acc: 76.80%
Finished Training...
Test Loss: 0.67, Test Acc: 77.06%
混淆矩陣...
[[794 31 43 10 9 5 8 15 49 36]
[ 11 918 2 4 3 3 1 0 10 48]
[ 65 6 706 39 71 46 35 18 5 9]
[ 21 10 83 533 60 175 60 32 11 15]
[ 21 2 81 38 747 34 19 51 6 1]
[ 13 4 54 103 43 717 18 38 2 8]
[ 5 5 77 47 38 24 790 6 6 2]
[ 7 3 41 22 50 54 2 809 2 10]
[ 56 41 14 12 5 1 2 2 837 30]
[ 18 85 6 6 4 3 2 8 13 855]]
評估報告...
precision recall f1-score support
plane 0.7854 0.7940 0.7897 1000
car 0.8308 0.9180 0.8722 1000
bird 0.6378 0.7060 0.6701 1000
cat 0.6548 0.5330 0.5877 1000
deer 0.7252 0.7470 0.7360 1000
dog 0.6751 0.7170 0.6954 1000
frog 0.8431 0.7900 0.8157 1000
horse 0.8264 0.8090 0.8176 1000
ship 0.8895 0.8370 0.8624 1000
truck 0.8432 0.8550 0.8491 1000
micro avg 0.7706 0.7706 0.7706 10000
macro avg 0.7711 0.7706 0.7696 10000
weighted avg 0.7711 0.7706 0.7696 10000
可以看到,準確率從55%提升到了77%,優化效果明顯。
還可以使用學習率衰減的方法,應該能夠讓模型更快收斂。
對比項目實踐中圖像分類的準確率,本例的準確率還是比較低的。主要因爲這個數據集32*32的分辨率還是太小了,很多細微特徵無法提取。