學習筆記|Pytorch使用教程32
本學習筆記主要摘自“深度之眼”,做一個總結,方便查閱。
使用Pytorch版本爲1.2
- 圖像分割是什麼?
- 模型是如何將圖像分割的?
- 深度學習圖像分割模型簡介
- 訓練Unet完成人像摳圖
一.圖像分割是什麼?
圖像分割:將圖像每一個像素分類
1.超像素分割:少量超像素代替大量像素,常用於圖像預處理
2. 語義分割:逐像素分類,無法區分個體
3. 實例分割:對個體目標進行分割,像素級目標檢測
4. 全景分割:語義分割結合實例分割
二.模型是如何將圖像分割的?
import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
# Directory containing this script; the demo images are looked up next to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    # Pick one of the demo images by switching which line is active.
    path_img = os.path.join(BASE_DIR, "demo_img1.png")
    # path_img = os.path.join(BASE_DIR, "demo_img2.png")
    # path_img = os.path.join(BASE_DIR, "demo_img3.png")

    # config: convert the PIL image to a tensor, then normalise each channel
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    preprocess = transforms.Compose([transforms.ToTensor(), normalize])

    # 1. load data & model
    input_image = Image.open(path_img).convert("RGB")
    model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
    model.eval()

    # 2. preprocess: add a leading batch dimension -> (1, C, H, W)
    input_bchw = preprocess(input_image).unsqueeze(0)

    # 3. to device
    if torch.cuda.is_available():
        input_bchw = input_bchw.to(device)
        model.to(device)

    # 4. forward (no gradients needed for inference)
    with torch.no_grad():
        tic = time.time()
        print("input img tensor shape:{}".format(input_bchw.shape))
        # The model returns a dict; 'out' holds the per-class score maps.
        output = model(input_bchw)['out'][0]
        elapsed = time.time() - tic
        print("pass: {:.3f}s use: {}".format(elapsed, device))
        print("output img tensor shape:{}".format(output.shape))
    # Most likely class index for every pixel.
    output_predictions = output.argmax(0)

    # 5. visualization
    # Build one RGB colour per class index (0..20) from three multipliers.
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    class_ids = torch.arange(21).unsqueeze(1)
    colors = ((class_ids * palette) % 255).numpy().astype("uint8")

    # plot the semantic segmentation predictions of 21 classes in each color
    class_map = output_predictions.byte().cpu().numpy()
    r = Image.fromarray(class_map).resize(input_image.size)
    r.putpalette(colors)
    plt.subplot(121).imshow(r)
    plt.subplot(122).imshow(input_image)
    plt.show()

    # appendix: class names, indexed by the predicted class id
    classes = [
        '__background__',
        'aeroplane', 'bicycle', 'bird', 'boat',
        'bottle', 'bus', 'car', 'cat', 'chair',
        'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'pottedplant',
        'sheep', 'sofa', 'train', 'tvmonitor',
    ]
輸出:
input img tensor shape:torch.Size([1, 3, 433, 649])
pass: 21.773s use: cpu
output img tensor shape:torch.Size([21, 433, 649])
輸出張量第一維的21表示模型可以分割21個類別:20個目標類別加1個背景類(見代碼末尾的classes列表)。
查看下一個類別:path_img = os.path.join(BASE_DIR, "demo_img2.png")
輸出:
input img tensor shape:torch.Size([1, 3, 433, 649])
pass: 20.287s use: cpu
output img tensor shape:torch.Size([21, 433, 649])
查看第三張圖片:path_img = os.path.join(BASE_DIR, "demo_img3.png")
輸出:
input img tensor shape:torch.Size([1, 3, 730, 574])
pass: 24.351s use: cpu
output img tensor shape:torch.Size([21, 730, 574])
三.深度學習圖像分割模型簡介
模型如何完成圖像分割?
- 答:圖像分割由模型與人類配合完成
- 模型:將數據映射到特徵
- 人類:定義特徵的物理意義,解決實際問題
PyTorch-Hub——PyTorch模型庫,有大量模型供開發者調用
1.torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
model = torch.hub.load(github, model, *args, **kwargs)
功能:加載模型
主要參數:
- github: str, 項目名,格式爲 <repo_owner/repo_name[:tag_name]>,eg:pytorch/vision
- model: str, 模型名
2.torch.hub.list(github, force_reload=False)
功能:列出該項目中可以通過torch.hub.load加載的所有模型入口
3.torch.hub.help(github, model, force_reload=False)
功能:查看指定模型入口的說明文檔(docstring)
圖像分割的思考
Ps:藍色爲小貓,綠色爲小狗
深度學習中的圖像分割模型
《Fully Convolutional Networks for Semantic Segmentation 》
最主要貢獻:
- 利用全卷積完成pixelwise prediction
《U-Net: Convolutional Networks for Biomedical Image Segmentation》
最主要貢獻:
- 奠定Unet系列分割模型的基本結構——編碼器與解碼器的特徵融合
- https://github.com/shawnbit/unet-family
《DeepLabv1 Semantic image segmentation with deep convolutional nets and fully connected CRFs》
DeepLab系列——V1
主要特點:
- 孔洞卷積:藉助孔洞卷積,增大感受野
- CRF:採用CRF進行mask後處理
《DeepLab- Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs》
DeepLab系列——V2
主要特點:
- ASPP(Atrous spatial pyramid pooling ):解決多尺度問題
《DeepLabv3- Rethinking Atrous Convolution for Semantic Image Segmentation》
DeepLab系列——V3
主要特點:
- 1.孔洞卷積的串行
- 2.ASPP的並行
《Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation》
DeepLab系列——V3+
主要特點:
- deeplabv3基礎上加上Encoder-Decoder思想
《Deep Semantic Segmentation of Natural and Medical Images: A Review》2019
圖像分割資源:
https://github.com/shawnbit/unet-family
https://github.com/yassouali/pytorch_segmentation
四.訓練Unet完成人像摳圖
- 數據來源:https://github.com/PetroWu/AutoPortraitMatting
訓練代碼(這裏max_epoch設爲1,只跑一個epoch做測試):
# -*- coding: utf-8 -*-
"""
# @file name : unet_portrait_matting.py
# @author : TingsongYu https://github.com/TingsongYu
# @date : 2019-11-25
# @brief : train unet
"""
import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
#from tools.common_tools import set_seed
from tools.my_dataset import PortraitDataset
from tools.unet import UNet
import random
# Directory containing this script; dataset paths are resolved relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_seed(seed=1):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (through both
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``).

    :param seed: integer seed handed to each generator
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)


set_seed()  # fix the random seed
def compute_dice(y_pred, y_true):
    """Compute the Dice coefficient between two binary masks.

    Both inputs are converted to arrays and rounded to the nearest integer,
    then Dice = 2 * sum(pred where true == 1) / (sum(pred) + sum(true)).

    :param y_pred: array-like predicted mask, values in [0, 1]
    :param y_true: array-like ground-truth mask, values in [0, 1], same shape as y_pred
    :return: Dice score as a float; 1.0 when both masks are empty
    """
    # Cast to float first so boolean masks (e.g. a thresholded output) round cleanly.
    y_pred = np.round(np.asarray(y_pred, dtype=np.float64)).astype(int)
    y_true = np.round(np.asarray(y_true, dtype=np.float64)).astype(int)

    denominator = np.sum(y_pred) + np.sum(y_true)
    if denominator == 0:
        # Neither mask has any foreground: they agree exactly, and dividing
        # would yield NaN and poison the running averages.
        return 1.0
    intersection = np.sum(y_pred[y_true == 1])
    return intersection * 2.0 / denominator
if __name__ == "__main__":
    # config
    LR = 0.01
    BATCH_SIZE = 8
    max_epoch = 1  # 400
    start_epoch = 0
    lr_step = 150  # epochs between learning-rate decays
    val_interval = 3  # run validation every N epochs
    checkpoint_interval = 20  # save a checkpoint every N epochs
    vis_num = 10  # number of validation samples to visualise after training
    mask_thres = 0.5  # threshold that turns the network output into a binary mask
    train_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "train")
    valid_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")

    # step 1: data
    train_set = PortraitDataset(train_dir)
    valid_set = PortraitDataset(valid_dir)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

    # step 2: model
    net = UNet(in_channels=3, out_channels=1, init_features=64)  # init_features is 64 in the standard U-Net
    net.to(device)

    # step 3: loss
    loss_fn = nn.MSELoss()

    # step 4: optimizer and learning-rate schedule
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)

    # step 5: training loop
    train_curve = []  # loss per training iteration
    valid_curve = []  # mean loss per validation run
    train_dice_curve = []  # dice per training iteration
    valid_dice_curve = []  # mean dice per validation run

    for epoch in range(start_epoch, max_epoch):
        train_loss_total = 0.
        net.train()
        for step, (inputs, labels) in enumerate(train_loader):
            # .to(device) is a no-op when the tensors already live on that device
            inputs, labels = inputs.to(device), labels.to(device)

            # forward
            outputs = net(inputs)

            # backward
            optimizer.zero_grad()
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # print
            train_dice = compute_dice(outputs.ge(mask_thres).cpu().data.numpy(), labels.cpu())
            train_dice_curve.append(train_dice)
            train_curve.append(loss.item())
            train_loss_total += loss.item()

            # Read the learning rate straight from the optimizer so the value
            # does not depend on which scheduler query method is available.
            current_lr = [group["lr"] for group in optimizer.param_groups]
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] running_loss: {:.4f}, mean_loss: {:.4f} "
                  "running_dice: {:.4f} lr:{}".format(epoch, max_epoch, step + 1, len(train_loader), loss.item(),
                                                      train_loss_total / (step + 1), train_dice, current_lr))

        scheduler.step()

        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)

        # validate the model
        if (epoch + 1) % val_interval == 0:
            net.eval()
            valid_loss_total = 0.
            valid_dice_total = 0.

            with torch.no_grad():
                for inputs, labels in valid_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = net(inputs)
                    loss = loss_fn(outputs, labels)
                    valid_loss_total += loss.item()
                    valid_dice_total += compute_dice(outputs.ge(mask_thres).cpu().data, labels.cpu())

            valid_loss_mean = valid_loss_total / len(valid_loader)
            valid_dice_mean = valid_dice_total / len(valid_loader)
            valid_curve.append(valid_loss_mean)
            valid_dice_curve.append(valid_dice_mean)

            print("Valid:\t Epoch[{:0>3}/{:0>3}] mean_loss: {:.4f} dice_mean: {:.4f}".format(
                epoch, max_epoch, valid_loss_mean, valid_dice_mean))

    # visualisation of predictions on the validation set
    # Switch to eval mode explicitly: the last epoch may not have run validation,
    # which would leave the network in training mode for these batch-size-1 inputs.
    net.eval()
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(valid_loader):
            if idx >= vis_num:  # show exactly vis_num samples
                break
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            mask_pred = outputs.ge(mask_thres).cpu().data.numpy().astype("uint8")
            # CHW -> HWC for matplotlib; assumes the dataset yields 0-255 pixel values
            img_hwc = inputs.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
            plt.subplot(121).imshow(img_hwc)
            mask_pred_gray = mask_pred.squeeze() * 255
            plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
            plt.show()
            plt.pause(0.5)
            plt.close()

    # plot the loss curve, then the dice curve
    train_iters = len(train_loader)
    for train_y, valid_y, y_label in (
        (train_curve, valid_curve, 'loss value'),
        (train_dice_curve, valid_dice_curve, 'dice value'),
    ):
        train_x = range(len(train_y))
        # Validation values are recorded once every val_interval epochs, so
        # convert their positions to iteration counts to share the x axis.
        valid_x = np.arange(1, len(valid_y) + 1) * train_iters * val_interval

        plt.plot(train_x, train_y, label='Train')
        plt.plot(valid_x, valid_y, label='Valid')
        plt.legend(loc='upper right')
        plt.ylabel(y_label)
        plt.xlabel('Iteration')
        plt.title("Plot in {} epochs".format(max_epoch))
        plt.show()

    torch.cuda.empty_cache()
測試一個epoch,輸出:
Training:Epoch[000/001] Iteration[001/212] running_loss: 0.2455, mean_loss: 0.2455 running_dice: 0.6275 lr:[0.01]
Training:Epoch[000/001] Iteration[002/212] running_loss: 0.2436, mean_loss: 0.2445 running_dice: 0.6337 lr:[0.01]
......
Training:Epoch[000/001] Iteration[210/212] running_loss: 0.0816, mean_loss: 0.1595 running_dice: 0.9295 lr:[0.01]
Training:Epoch[000/001] Iteration[211/212] running_loss: 0.1406, mean_loss: 0.1594 running_dice: 0.8416 lr:[0.01]
Training:Epoch[000/001] Iteration[212/212] running_loss: 0.1624, mean_loss: 0.1594 running_dice: 0.8296 lr:[0.01]
查看Unet結構。雖然簡單,但很經典。
from collections import OrderedDict
import torch
import torch.nn as nn
class UNet(nn.Module):
    """U-Net: four encoder stages, a bottleneck and four decoder stages.

    Each encoder stage is followed by 2x2 max pooling; each decoder stage
    upsamples with a stride-2 transposed convolution and concatenates the
    matching encoder output along the channel axis before its conv block.
    The final 1x1 convolution is passed through a sigmoid.

    :param in_channels: number of channels of the input image
    :param out_channels: number of channels of the predicted map
    :param init_features: channel width of the first stage; it doubles at
        every deeper stage
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()

        # Channel widths of the five resolution levels.
        w1 = init_features
        w2, w4, w8, w16 = w1 * 2, w1 * 4, w1 * 8, w1 * 16

        # Contracting path (modules are created in the same order as before so
        # seeded initialisation and state_dict keys are unaffected).
        self.encoder1 = UNet._block(in_channels, w1, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(w1, w2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(w2, w4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(w4, w8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(w8, w16, name="bottleneck")

        # Expanding path: decoder blocks take twice the channels because the
        # upsampled features are concatenated with the encoder features.
        self.upconv4 = nn.ConvTranspose2d(w16, w8, kernel_size=2, stride=2)
        self.decoder4 = UNet._block(w8 * 2, w8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(w8, w4, kernel_size=2, stride=2)
        self.decoder3 = UNet._block(w4 * 2, w4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(w4, w2, kernel_size=2, stride=2)
        self.decoder2 = UNet._block(w2 * 2, w2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.decoder1 = UNet._block(w1 * 2, w1, name="dec1")

        # 1x1 convolution mapping the last feature map to the output channels.
        self.conv = nn.Conv2d(in_channels=w1, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on a batch and return the sigmoid-activated map."""
        # Encoder: keep every stage's output for the skip connections.
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        out = self.bottleneck(self.pool4(skip4))

        # Decoder: upsample, concatenate the skip features on dim 1, convolve.
        out = self.decoder4(torch.cat((self.upconv4(out), skip4), dim=1))
        out = self.decoder3(torch.cat((self.upconv3(out), skip3), dim=1))
        out = self.decoder2(torch.cat((self.upconv2(out), skip2), dim=1))
        out = self.decoder1(torch.cat((self.upconv1(out), skip1), dim=1))

        return torch.sigmoid(self.conv(out))

    @staticmethod
    def _block(in_channels, features, name):
        """Build a (conv3x3 -> batch norm -> ReLU) x 2 block.

        :param in_channels: channels entering the first convolution
        :param features: channels produced by both convolutions
        :param name: prefix used for the layer names inside the Sequential
        :return: nn.Sequential whose layers are named ``<name>conv1`` etc.
        """

        def conv3x3(channels_in):
            # bias is disabled because a batch-norm layer follows directly
            return nn.Conv2d(
                in_channels=channels_in,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )

        layers = OrderedDict()
        layers[name + "conv1"] = conv3x3(in_channels)
        layers[name + "norm1"] = nn.BatchNorm2d(num_features=features)
        layers[name + "relu1"] = nn.ReLU(inplace=True)
        layers[name + "conv2"] = conv3x3(features)
        layers[name + "norm2"] = nn.BatchNorm2d(num_features=features)
        layers[name + "relu2"] = nn.ReLU(inplace=True)
        return nn.Sequential(layers)
現在使用訓練過400個epoch的權重進行測試:
(注意:這裏構建UNet時使用的init_features=32,必須與訓練該權重時的設置一致,否則權重形狀不匹配、無法加載)
import os
import time
import random
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
#from tools.common_tools import set_seed
from tools.my_dataset import PortraitDataset
from tools.unet import UNet
# Directory containing this script; the image folder is resolved relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_seed(seed=1):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (through both
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``).

    :param seed: integer seed handed to each generator
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)


set_seed()  # fix the random seed
def compute_dice(y_pred, y_true):
    """Compute the Dice coefficient between two binary masks.

    Both inputs are converted to arrays and rounded to the nearest integer,
    then Dice = 2 * sum(pred where true == 1) / (sum(pred) + sum(true)).

    :param y_pred: array-like predicted mask, values in [0, 1]
    :param y_true: array-like ground-truth mask, values in [0, 1], same shape as y_pred
    :return: Dice score as a float; 1.0 when both masks are empty
    """
    # Cast to float first so boolean masks (e.g. a thresholded output) round cleanly.
    y_pred = np.round(np.asarray(y_pred, dtype=np.float64)).astype(int)
    y_true = np.round(np.asarray(y_true, dtype=np.float64)).astype(int)

    denominator = np.sum(y_pred) + np.sum(y_true)
    if denominator == 0:
        # Neither mask has any foreground: they agree exactly, and dividing
        # would yield NaN.
        return 1.0
    intersection = np.sum(y_pred[y_true == 1])
    return intersection * 2.0 / denominator
def get_img_name(img_dir, format="jpg"):
    """List the file names in a directory that end with the given suffix.

    Names ending in ``matte.png`` are always left out, whatever the suffix.

    :param img_dir: str, directory to scan (not recursive)
    :param format: str, suffix a file name must end with
    :return: list of matching file names, in the order os.listdir yields them
    :raises ValueError: when no file name matches
    """
    img_names = [
        fname
        for fname in os.listdir(img_dir)
        if fname.endswith(format) and not fname.endswith("matte.png")
    ]
    if not img_names:
        raise ValueError("{}下找不到{}格式數據".format(img_dir, format))
    return img_names
def get_model(m_path, init_features=32):
    """Build a UNet and load its weights from a saved checkpoint.

    :param m_path: path to a checkpoint written with torch.save; its
        "model_state_dict" entry must hold the network weights
    :param init_features: channel width of the first UNet stage; it has to
        match the value used when the checkpoint was trained (default 32)
    :return: UNet with the checkpoint weights loaded, mapped to the CPU
    """
    from collections import OrderedDict

    unet = UNet(in_channels=3, out_channels=1, init_features=init_features)
    # NOTE: torch.load unpickles the file, so only open checkpoints from a trusted source.
    checkpoint = torch.load(m_path, map_location="cpu")

    # Strip a leading "module." from every key (presumably left by a wrapped
    # model such as nn.DataParallel) so the names match a plain UNet.
    prefix = "module."
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint['model_state_dict'].items()
    )
    unet.load_state_dict(new_state_dict)
    return unet
if __name__ == "__main__":
    img_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")
    model_path = "checkpoint_399_epoch.pkl"
    time_total = 0  # accumulated forward-pass time in seconds
    num_infer = 5  # number of images to run inference on
    mask_thres = .5  # threshold that turns the network output into a binary mask

    # 1. data
    img_names = get_img_name(img_dir, format="png")
    random.shuffle(img_names)
    num_img = len(img_names)

    # 2. model
    unet = get_model(model_path)
    unet.to(device)
    unet.eval()

    for idx, img_name in enumerate(img_names):
        if idx >= num_infer:  # process exactly num_infer images
            break

        path_img = os.path.join(img_dir, img_name)

        # step 1/4 : path --> img_chw
        img_hwc = Image.open(path_img).convert('RGB')
        img_hwc = img_hwc.resize((224, 224))
        img_arr = np.array(img_hwc)
        img_chw = img_arr.transpose((2, 0, 1))

        # step 2/4 : img --> tensor with a leading batch dimension
        img_tensor = torch.tensor(img_chw).to(torch.float)
        img_tensor.unsqueeze_(0)
        img_tensor = img_tensor.to(device)

        # step 3/4 : tensor --> features
        time_tic = time.time()
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = unet(img_tensor)
        time_toc = time.time()

        # step 4/4 : visualization
        mask_pred = outputs.ge(mask_thres).cpu().data.numpy().astype("uint8")
        # Back from (1, C, H, W) float to an H x W x C uint8 image for matplotlib.
        img_hwc = img_tensor.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
        plt.subplot(121).imshow(img_hwc)
        mask_pred_gray = mask_pred.squeeze() * 255
        plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
        plt.show()
        plt.close()

        time_s = time_toc - time_tic
        time_total += time_s

        print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))
輸出: