Learning Notes | PyTorch Tutorial 32 (A Glimpse of Image Segmentation)

These learning notes are mainly summarized from the "深度之眼" course, for quick reference.
PyTorch version 1.2 is used.

  • What is image segmentation?
  • How does a model segment an image?
  • A brief introduction to deep learning image segmentation models
  • Training UNet for portrait matting

1. What Is Image Segmentation?

Image segmentation: classifying every pixel of an image.
1. Superpixel segmentation: replaces a large number of pixels with a small number of superpixels; often used for image preprocessing
2. Semantic segmentation: per-pixel classification; cannot distinguish individual instances
3. Instance segmentation: segments individual objects; pixel-level object detection
4. Panoptic segmentation: semantic segmentation combined with instance segmentation

2. How Does a Model Segment an Image?

The demo below runs a pretrained DeepLabV3-ResNet101 model from torch.hub on a few sample images:

import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":

    path_img = os.path.join(BASE_DIR, "demo_img1.png")
    # path_img = os.path.join(BASE_DIR, "demo_img2.png")
    # path_img = os.path.join(BASE_DIR, "demo_img3.png")

    # config
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # 1. load data & model
    input_image = Image.open(path_img).convert("RGB")
    model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
    model.eval()

    # 2. preprocess
    input_tensor = preprocess(input_image)
    input_bchw = input_tensor.unsqueeze(0)

    # 3. to device
    if torch.cuda.is_available():
        input_bchw = input_bchw.to(device)
        model.to(device)

    # 4. forward
    with torch.no_grad():
        tic = time.time()
        print("input img tensor shape:{}".format(input_bchw.shape))
        output_4d = model(input_bchw)['out']
        output = output_4d[0]
        print("pass: {:.3f}s use: {}".format(time.time() - tic, device))
        print("output img tensor shape:{}".format(output.shape))
    output_predictions = output.argmax(0)  # per-pixel class index: argmax over the 21 class channels

    # 5. visualization
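    # each class index i is multiplied by a fixed palette vector and taken
    # mod 255, producing a distinct pseudo-random RGB color per class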
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")

    # plot the semantic segmentation predictions of the 21 classes, each in a different color
    r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
    r.putpalette(colors)
    plt.subplot(121).imshow(r)
    plt.subplot(122).imshow(input_image)
    plt.show()

    # appendix
    classes = ['__background__',
               'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair',
               'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant',
               'sheep', 'sofa', 'train', 'tvmonitor']

Output:

input img tensor shape:torch.Size([1, 3, 433, 649])
pass: 21.773s use: cpu
output img tensor shape:torch.Size([21, 433, 649])

The 21 means the model can segment 21 classes, one of which is the background class.
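As a small aside (a sketch using output_predictions and the classes list from the script above), you can count how many pixels were assigned to each class:

import torch

pred = output_predictions.cpu()
for cls_idx in torch.unique(pred):
    idx = int(cls_idx)
    print("{:>15s}: {:d} pixels".format(classes[idx], int((pred == idx).sum())))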
Next, try the second image: path_img = os.path.join(BASE_DIR, "demo_img2.png")
Output:

input img tensor shape:torch.Size([1, 3, 433, 649])
pass: 20.287s use: cpu
output img tensor shape:torch.Size([21, 433, 649])

And the third image: path_img = os.path.join(BASE_DIR, "demo_img3.png")
Output:

input img tensor shape:torch.Size([1, 3, 730, 574])
pass: 24.351s use: cpu
output img tensor shape:torch.Size([21, 730, 574])


3. A Brief Introduction to Deep Learning Image Segmentation Models

How does a model accomplish image segmentation?

  • Answer: image segmentation is accomplished by the model and humans working together
  • Model: maps the data to features
  • Humans: define the physical meaning of the features and solve the actual problem

PyTorch-Hub: the PyTorch model repository, offering a large number of models for developers to call.

1. torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
model = torch.hub.load(github, model, *args, **kwargs)
Purpose: load a model
Main parameters:
  • github: str, project name, e.g. pytorch/vision <repo_owner/repo_name[:tag_name]>
  • model: str, model name

2. torch.hub.list(github, force_reload=False)
Purpose: list the model entrypoints available in a repository
3. torch.hub.help(github, model, force_reload=False)
Purpose: show the docstring of a model entrypoint
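A quick sketch of how these three Hub functions fit together (network access is needed on the first call; the entrypoint name comes from the pytorch/vision repo):

import torch

# list all model entrypoints exposed by the repo's hubconf.py
print(torch.hub.list('pytorch/vision'))
# show the docstring of one entrypoint
print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
# load it (downloads weights on first use)
model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)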

Thoughts on image segmentation
(Figures: blue marks the kitten, green marks the puppy.)

Image Segmentation Models in Deep Learning

Fully Convolutional Networks for Semantic Segmentation (FCN)
Main contribution:

  • Uses fully convolutional layers to achieve pixel-wise prediction

U-Net: Convolutional Networks for Biomedical Image Segmentation
Main contribution:

  • Established the basic structure of the UNet family of segmentation models: feature fusion between the encoder and the decoder (see the sketch below)
  • https://github.com/shawnbit/unet-family
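A minimal sketch of this encoder-decoder feature fusion (shapes chosen purely for illustration): decoder features are upsampled with a transposed convolution, then concatenated with the matching encoder features along the channel dimension.

import torch
import torch.nn as nn

encoder_feat = torch.randn(1, 64, 56, 56)      # skip connection from the encoder
bottleneck_feat = torch.randn(1, 128, 28, 28)  # deeper, lower-resolution features

up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
dec = up(bottleneck_feat)                      # -> (1, 64, 56, 56)
dec = torch.cat((dec, encoder_feat), dim=1)    # -> (1, 128, 56, 56)
print(dec.shape)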

DeepLabv1: Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs
DeepLab series — V1
Main features:

  • Atrous (dilated) convolution: enlarges the receptive field (a minimal sketch follows this list)
  • CRF: a conditional random field is used for mask post-processing
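As an aside (not from the course), here is a minimal sketch of atrous convolution in PyTorch: with dilation=2, a 3x3 kernel samples a 5x5 region, enlarging the receptive field without extra parameters or loss of resolution.

import torch
import torch.nn as nn

x = torch.randn(1, 3, 64, 64)
conv_plain = nn.Conv2d(3, 8, kernel_size=3, padding=1)               # 3x3 receptive field
conv_atrous = nn.Conv2d(3, 8, kernel_size=3, padding=2, dilation=2)  # samples a 5x5 region
print(conv_plain(x).shape)   # torch.Size([1, 8, 64, 64])
print(conv_atrous(x).shape)  # torch.Size([1, 8, 64, 64]) -- resolution preserved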


DeepLabv2: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs
DeepLab series — V2
Main feature:

  • ASPP (Atrous Spatial Pyramid Pooling): addresses the multi-scale problem (a rough sketch follows)
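A rough illustrative sketch of the ASPP idea (simplified; not the exact DeepLab implementation): parallel atrous convolutions with different rates see the input at different scales, and their outputs are fused with a 1x1 convolution.

import torch
import torch.nn as nn

class SimpleASPP(nn.Module):
    """Parallel atrous convolutions at several rates, fused by a 1x1 conv."""
    def __init__(self, in_ch, out_ch, rates=(1, 6, 12, 18)):
        super(SimpleASPP, self).__init__()
        self.branches = nn.ModuleList([
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=r, dilation=r)
            for r in rates
        ])
        self.project = nn.Conv2d(out_ch * len(rates), out_ch, kernel_size=1)

    def forward(self, x):
        # each branch keeps the spatial size but has a different receptive field
        feats = [branch(x) for branch in self.branches]
        return self.project(torch.cat(feats, dim=1))

x = torch.randn(1, 256, 32, 32)
print(SimpleASPP(256, 64)(x).shape)  # torch.Size([1, 64, 32, 32])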


DeepLabv3: Rethinking Atrous Convolution for Semantic Image Segmentation
DeepLab series — V3
Main features:

  • 1. Atrous convolutions applied in cascade (serial)
  • 2. ASPP applied in parallel

DeepLabv3+: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation
DeepLab series — V3+
Main feature:

  • Adds the encoder-decoder idea on top of DeepLabv3

Deep Semantic Segmentation of Natural and Medical Images: A Review (2019)
Image segmentation resources:
https://github.com/shawnbit/unet-family
https://github.com/yassouali/pytorch_segmentation

4. Training UNet for Portrait Matting


  • Data source: https://github.com/PetroWu/AutoPortraitMatting

Training code:

# -*- coding: utf-8 -*-
"""
# @file name  : unet_portrait_matting.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2019-11-25
# @brief      : train unet
"""

import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
#from tools.common_tools import set_seed
from tools.my_dataset import PortraitDataset
from tools.unet import UNet
import random

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

set_seed()  # set the random seed


def compute_dice(y_pred, y_true):
    """
    :param y_pred: 4-d tensor, value = [0,1]
    :param y_true: 4-d tensor, value = [0,1]
    :return:
    """
    y_pred, y_true = np.array(y_pred), np.array(y_true)
    y_pred, y_true = np.round(y_pred).astype(int), np.round(y_true).astype(int)
    return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true))
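# A tiny worked example (illustrative values, not from the course): the
# prediction and label below overlap on 2 of 3 foreground pixels, so
# Dice = 2 * 2 / (3 + 3) ≈ 0.667:
#   compute_dice(np.array([[1, 1, 0], [0, 1, 0]]),
#                np.array([[1, 1, 1], [0, 0, 0]]))  # -> 0.6667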


if __name__ == "__main__":

    # config
    LR = 0.01
    BATCH_SIZE = 8
    max_epoch = 1   # 400
    start_epoch = 0
    lr_step = 150
    val_interval = 3
    checkpoint_interval = 20
    vis_num = 10
    mask_thres = 0.5

    train_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "train")
    valid_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")

    # step 1
    train_set = PortraitDataset(train_dir)
    valid_set = PortraitDataset(valid_dir)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

    # step 2
    net = UNet(in_channels=3, out_channels=1, init_features=64)   # init_features is 64 in the standard UNet
    net.to(device)

    # step 3
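    # MSE between the predicted mask (sigmoid output in [0, 1]) and the
    # ground-truth matte; BCELoss would be another common choice for binary masks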
    loss_fn = nn.MSELoss()
    # step 4
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)

    # step 5
    train_curve = list()
    valid_curve = list()
    train_dice_curve = list()
    valid_dice_curve = list()
    for epoch in range(start_epoch, max_epoch):

        train_loss_total = 0.
        train_dice_total = 0.

        net.train()
        for iter, (inputs, labels) in enumerate(train_loader):

            if torch.cuda.is_available():
                inputs, labels = inputs.to(device), labels.to(device)

            # forward
            outputs = net(inputs)

            # backward
            optimizer.zero_grad()
            loss = loss_fn(outputs, labels)
            loss.backward()

            optimizer.step()

            # print
            train_dice = compute_dice(outputs.ge(mask_thres).cpu().data.numpy(), labels.cpu())
            train_dice_curve.append(train_dice)
            train_curve.append(loss.item())

            train_loss_total += loss.item()

            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] running_loss: {:.4f}, mean_loss: {:.4f} "
                  "running_dice: {:.4f} lr:{}".format(epoch, max_epoch, iter + 1, len(train_loader), loss.item(),
                                    train_loss_total/(iter+1), train_dice, scheduler.get_lr()))

        scheduler.step()

        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)

        # validate the model
        if (epoch+1) % val_interval == 0:

            net.eval()
            valid_loss_total = 0.
            valid_dice_total = 0.

            with torch.no_grad():
                for j, (inputs, labels) in enumerate(valid_loader):
                    if torch.cuda.is_available():
                        inputs, labels = inputs.to(device), labels.to(device)

                    outputs = net(inputs)
                    loss = loss_fn(outputs, labels)

                    valid_loss_total += loss.item()

                    valid_dice = compute_dice(outputs.ge(mask_thres).cpu().data, labels.cpu())
                    valid_dice_total += valid_dice

                valid_loss_mean = valid_loss_total/len(valid_loader)
                valid_dice_mean = valid_dice_total/len(valid_loader)
                valid_curve.append(valid_loss_mean)
                valid_dice_curve.append(valid_dice_mean)

                print("Valid:\t Epoch[{:0>3}/{:0>3}] mean_loss: {:.4f} dice_mean: {:.4f}".format(
                    epoch, max_epoch, valid_loss_mean, valid_dice_mean))

    # visualization
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(valid_loader):
            if idx > vis_num:
                break
            if torch.cuda.is_available():
                inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            pred = outputs.ge(mask_thres)
            mask_pred = pred.cpu().data.numpy().astype("uint8")

            img_hwc = inputs.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
            plt.subplot(121).imshow(img_hwc)
            mask_pred_gray = mask_pred.squeeze() * 255
            plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
            plt.show()
            plt.pause(0.5)
            plt.close()

    # plot curve
    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(train_loader)
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval  # valid records per-epoch loss, so convert the record points to iteration counts
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()

    # dice curve
    train_x = range(len(train_dice_curve))
    train_y = train_dice_curve

    train_iters = len(train_loader)
    valid_x = np.arange(1, len(valid_dice_curve) + 1) * train_iters * val_interval  # valid records per-epoch dice, so convert the record points to iteration counts
    valid_y = valid_dice_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('dice value')
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()
    torch.cuda.empty_cache()

Testing with a single epoch, the output is:

Training:Epoch[000/001] Iteration[001/212] running_loss: 0.2455, mean_loss: 0.2455 running_dice: 0.6275 lr:[0.01]
Training:Epoch[000/001] Iteration[002/212] running_loss: 0.2436, mean_loss: 0.2445 running_dice: 0.6337 lr:[0.01]
......
Training:Epoch[000/001] Iteration[210/212] running_loss: 0.0816, mean_loss: 0.1595 running_dice: 0.9295 lr:[0.01]
Training:Epoch[000/001] Iteration[211/212] running_loss: 0.1406, mean_loss: 0.1594 running_dice: 0.8416 lr:[0.01]
Training:Epoch[000/001] Iteration[212/212] running_loss: 0.1624, mean_loss: 0.1594 running_dice: 0.8296 lr:[0.01]

Let's look at the UNet structure. It is simple but classic.

from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)

        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )
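
A quick sanity check of the class above (a sketch; note that the input height and width must be divisible by 16 because of the four pooling stages):

import torch

net = UNet(in_channels=3, out_channels=1, init_features=32)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 1, 224, 224]); values lie in (0, 1) thanks to the sigmoid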

Now let's run inference with weights trained for 400 epochs:
(Note: init_features=32 is used here, not the 64 from the training script above.)

import os
import time
import random
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
#from tools.common_tools import set_seed
from tools.my_dataset import PortraitDataset
from tools.unet import UNet

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

set_seed()  # set the random seed


def compute_dice(y_pred, y_true):
    """
    :param y_pred: 4-d tensor, value = [0,1]
    :param y_true: 4-d tensor, value = [0,1]
    :return:
    """
    y_pred, y_true = np.array(y_pred), np.array(y_true)
    y_pred, y_true = np.round(y_pred).astype(int), np.round(y_true).astype(int)
    return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true))


def get_img_name(img_dir, format="jpg"):
    """
    Get the names of all files with the given format in a folder
    :param img_dir: str
    :param format: str
    :return: list
    """
    file_names = os.listdir(img_dir)
    img_names = list(filter(lambda x: x.endswith(format), file_names))
    img_names = list(filter(lambda x: not x.endswith("matte.png"), img_names))

    if len(img_names) < 1:
        raise ValueError("no {} files found under {}".format(format, img_dir))
    return img_names


def get_model(m_path):

    unet = UNet(in_channels=3, out_channels=1, init_features=32)
    checkpoint = torch.load(m_path, map_location="cpu")

    # strip the 'module.' prefix added by nn.DataParallel
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in checkpoint['model_state_dict'].items():
        namekey = k[7:] if k.startswith('module.') else k
        new_state_dict[namekey] = v

    unet.load_state_dict(new_state_dict)

    return unet


if __name__ == "__main__":

    img_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")
    model_path = "checkpoint_399_epoch.pkl"
    time_total = 0
    num_infer = 5
    mask_thres = .5

    # 1. data
    img_names = get_img_name(img_dir, format="png")
    random.shuffle(img_names)
    num_img = len(img_names)

    # 2. model
    unet = get_model(model_path)
    unet.to(device)
    unet.eval()

    for idx, img_name in enumerate(img_names):
        if idx > num_infer:
            break

        path_img = os.path.join(img_dir, img_name)
        # path_img = "C:\\Users\\Administrator\\Desktop\\Andrew-wu.png"
        #
        # step 1/4 : path --> img_chw
        img_hwc = Image.open(path_img).convert('RGB')
        img_hwc = img_hwc.resize((224, 224))
        img_arr = np.array(img_hwc)
        img_chw = img_arr.transpose((2, 0, 1))

        # step 2/4 : img --> tensor
        img_tensor = torch.tensor(img_chw).to(torch.float)
        img_tensor.unsqueeze_(0)
        img_tensor = img_tensor.to(device)

        # step 3/4 : tensor --> features
        time_tic = time.time()
        outputs = unet(img_tensor)
        time_toc = time.time()

        # step 4/4 : visualization
        pred = outputs.ge(mask_thres)
        mask_pred = pred.cpu().data.numpy().astype("uint8")

        img_hwc = img_tensor.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
        plt.subplot(121).imshow(img_hwc)
        mask_pred_gray = mask_pred.squeeze() * 255
        plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
        plt.show()
        # plt.pause(0.5)
        plt.close()

        time_s = time_toc - time_tic
        time_total += time_s

        print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

Output:
(Figures: input portraits alongside their predicted masks.)
