基於Unet+opencv實現天空對象的分割、替換和美化

     原文地址:https://www.cnblogs.com/jsxyhelu/p/16995892.html
  
     傳統圖像處理算法進行“天空分割”存在精度問題且調參複雜,無法很好地應對雲霧、陰霾等情況;本篇文章分享的“基於Unet+opencv實現天空對象的分割、替換和美化”,較好地解決了該問題,包括以下內容:
1、基於Unet語義分割的基本原理、環境構建、參數調節等
2、一種有效的天空分割數據集準備方法,並且獲得數據集
3、基於OpenCV的Pytorch模型部署方法
4、融合效果極好的 SeamlessClone 技術
5、飽和度調整、顏色域等基礎圖像處理知識和編碼技術
    本文適合具備 OpenCV 和Pytorch相關基礎,對“天空替換”感興趣的人士。學完本文,可以獲得基於Pytorch和OpenCV進行語義分割、解決實際問題的具體方法,提高環境構建、數據集準備、參數調節和運行部署等方面綜合能力。
 一、傳統方法和語義分割基礎
1.1傳統方法主要通過“顏色域”來進行分割

比如,我們要找的是藍天,那麼在HSV域,就可以通過查表的方法找出藍色區域。 

在這張表中,藍色的HSV的上下門限已經標註出來,我們編碼實現。

    cvtColor(matSrc,temp,COLOR_BGR2HSV);
    split(temp,planes);
    equalizeHist(planes[2],planes[2]);//對v通道進行equalizeHist
    merge(planes,temp);
    inRange(temp,Scalar(100,43,46),Scalar(124,255,255),temp);
    erode(temp,temp,Mat());//形態學變換,填補內部空洞
    dilate(temp,temp,Mat());
    imshow("原始圖",matSrc);

在這段代碼中,有兩個小技巧,一個是對模板(MASK)進行了形態學變化,這個不展開說;一個是我們首先對HSV圖進行了3通道分解,並且直方圖增強V通道,而後將3通道合併回去。通過這種方法能夠增強原圖對比度,讓藍天更藍、青山更青……大家可以自己調試看一下。 顯示處理後識別爲天空的結果(在OpenCV中,白色代表1也就是由數據,黑色代表0也就是沒數據) 

對於天壇這幅圖來說,效果不錯。雖然在右上角錯誤,而塔中間的一個很小的空洞,這些後期都是可以規避掉的錯誤。 

但是對於陰霾圖片來說,由於天空中沒有藍色,識別起來就很錯誤很多。

1.2 語義分割基礎

圖像語義分割(semantic segmentation),從字面意思上理解就是讓計算機根據圖像的語義來進行分割,例如讓計算機在輸入下面左圖的情況下,能夠輸出右圖。語義在語音識別中指的是語音的意思,在圖像領域,語義指的是圖像的內容,對圖片意思的理解,比如左圖的語義就是三個人騎着三輛自行車;分割的意思是從像素的角度分割出圖片中的不同對象,對原圖中的每個像素都進行標註,比如右圖中粉紅色代表人,綠色代表自行車。

那麼對於天空分割問題來說,主要目標就是找到像素級別的天空對象,使用語義分割模型就是有效的。

二、Unet基本情況和環境構建
Unet 發表於 2015 年,屬於 FCN 的一種變體,Unet 的初衷是爲了解決生物醫學圖像方面的問題,由於效果確實很好後來也被廣泛的應用在語義分割的各個方向,比如衛星圖像分割,工業瑕疵檢測等。它也有很多變體,但是對於天空分割問題來看,Unet的能力已經夠了。
Unet 跟 FCN 都是 Encoder-Decoder 結構,結構簡單但很有效。Encoder 負責特徵提取,你可以將自己熟悉的各種特徵提取網絡放在這個位置。由於在醫學方面,樣本收集較爲困難,作者爲了解決這個問題,應用了圖像增強的方法,在數據集有限的情況下獲得了不錯的精度。
 

 

 

如上圖,Unet 網絡結構是對稱的,形似英文字母 U 所以被稱爲 Unet。整張圖都是由藍/白色框與各種顏色的箭頭組成,其中,藍/白色框表示 feature map;藍色箭頭表示 3x3 卷積,用於特徵提取;灰色箭頭表示 skip-connection,用於特徵融合;紅色箭頭表示池化 pooling,用於降低維度;綠色箭頭表示上採樣 upsample,用於恢復維度;青色箭頭表示 1x1 卷積,用於輸出結果。
在環境構建這塊,我建議一定要結合自己的實際情況,構建專用的代碼庫,這樣才能夠通過不斷迭代,在總體正確的前提下形成自己風格。
在我的庫中,基於現有的Unet代碼進行了修改
其中checkpoints、data保持數據;unet是模型的具體實現,未來可以擴充爲多模型;utils是常用函數;alibaba.py和oss2helper.py是阿里雲的輔助函數;export_unet.py是輸出函數;eveluate.py和train.py用於訓練;predict.py用於本地測試;main.py是主要函數。
三、數據集準備和增強
3.1 數據集準備這塊,我採取了增強的方法。由於個人習慣問題,採用的是OpenCV本地變換的方法
   
 getFiles("e:/template/Data_sky/data", fileNames);
    string saveFile = "e:/template/Data_sky/dataEX3/";
    for (int index = 0; index < fileNames.size(); index++)
    {
        Mat src = imread(fileNames[index]);
        Mat dst;
        string fileName;
        getFileName(fileNames[index], fileName);
        resize(src, dst, cv::Size(512, 512));
        imwrite(saveFile + fileName + "_512.jpg", dst);
        resize(src, dst, cv::Size(256, 256));
        imwrite(saveFile + fileName + "_256.jpg", dst);
        resize(src, dst, cv::Size(128, 128));
        imwrite(saveFile + fileName + "_128.jpg", dst);
        cout << fileName << endl;
    }
    fileNames.clear();
    getFiles("e:/template/Data_sky/mask", fileNames);
    saveFile = "e:/template/Data_sky/maskEX3/";
    for (int index = 0; index < fileNames.size(); index++)
    {
        Mat src = imread(fileNames[index], 0);
        Mat dst;
        string fileName;
        getFileName(fileNames[index], fileName);
        fileName = fileName.substr(0, fileName.size() - 3);
        resize(src, dst, cv::Size(512, 512));
        imwrite(saveFile + fileName + "_512_gt.jpg", dst);
        resize(src, dst, cv::Size(256, 256));
        imwrite(saveFile + fileName + "_256_gt.jpg", dst);
        resize(src, dst, cv::Size(128, 128));
        imwrite(saveFile + fileName + "_128_gt.jpg", dst);
        cout << fileName << endl;
    }

 

 

從而獲得不同分辨率的目標數據,但是如何獲得標註數據?我推薦一種方法。
3.2、通過對“阿里視覺智能開放平臺”的研究,調用它的成果來進行訓練。簡單來說,它提供了天空分割的功能,但是要求數據的輸入輸出都保存在oss中,所以需要通過python來編寫腳本。我對這段python代碼進行了一些註釋,放在這裏。
# -*- coding: utf8 -*-
from aliyunsdkcore.client import AcsClient
from aliyunsdkimageseg.request.v20191230 import SegmentSkyRequest
from aliyunsdkimageseg.request.v20191230.SegmentHDSkyRequest import SegmentHDSkyRequest
import oss2
import os
import json
import urllib


# 創建 AcsClient 實例
client = AcsClient("LTAI5tQCCmMyKSfifwsFHLpC", "JyzNfHsCnUaVTeS6Xg3ylMjQFC8C6L", "cn-shanghai")
request = SegmentSkyRequest.SegmentSkyRequest()
endpoint = "https://oss-cn-shanghai.aliyuncs.com"
accesskey_id = "LTAI5tQCCmMyKSfifwsFHLpC"
accesskey_secret = "JyzNfHsCnUaVTeS6Xg3ylMjQFC8C6L"
bucket_name = "datasky2"
bucket_name2 = "viapi-cn-shanghai-dha-segmenter"

#本地文件保存路徑前綴
download_local_save_prefix = "/home/helu/GOPytorchHelper/data/dataOss/"

'''
列舉prefix全部文件
'''
def prefix_all_list(bucket,prefix):
    print("開始列舉"+prefix+"全部文件");
    oss_file_size = 0;
    for obj in oss2.ObjectIterator(bucket, prefix ='%s/'%prefix):
         print(' key : ' + obj.key)
         oss_file_size = oss_file_size + 1;
         download_to_local(bucket, obj.key, obj.key);
    print(prefix +" file size " + str(oss_file_size));


'''
列舉全部的根目錄文件夾、文件
'''
def root_directory_list(bucket):
    # 設置Delimiter參數爲正斜線(/)。
    for obj in oss2.ObjectIterator(bucket, delimiter='/'):
        # 通過is_prefix方法判斷obj是否爲文件夾。
        if obj.is_prefix():  # 文件夾
            print('directory: ' + obj.key);
            prefix_all_list(bucket,str(obj.key).strip("/")); #去除/
        else:  # 文件
            print('file: ' +obj.key)
            # 填寫Object完整路徑,例如exampledir/exampleobject.txt。Object完整路徑中不能包含Bucket名稱。
            object_name = obj.key
            # 生成下載文件的簽名URL,有效時間爲60秒。
            # 生成簽名URL時,OSS默認會對Object完整路徑中的正斜線(/)進行轉義,從而導致生成的簽名URL無法直接使用。
            # 設置slash_safe爲True,OSS不會對Object完整路徑中的正斜線(/)進行轉義,此時生成的簽名URL可以直接使用。
            url = bucket.sign_url('GET', object_name, 60, slash_safe=True)     
            print('簽名url的地址爲:', url)
            ## 如下url替換爲自有的上海region的oss文件地址
            request.set_ImageURL(url)
            response = client.do_action_with_exception(request)
            print('response地址爲:', response)
            user_dict = json.loads(response)
            for name in user_dict.keys():
                if(name.title() == "Data"):
                    inner_dict = user_dict[name]
                    for innerName in inner_dict.keys():
                        if(innerName == "ImageURL"):
                            finalName = inner_dict[innerName]
                            print('finalName地址爲:',str(finalName))
                            urllib.request.urlretrieve(str(finalName), download_local_save_prefix+obj.key)
'''
下載文件到本地
'''
def download_to_local(bucket,object_name,local_file):
    url = download_local_save_prefix + local_file;
    #文件名稱
    file_name = url[url.rindex("/")+1:]
    file_path_prefix = url.replace(file_name, "")
    if False == os.path.exists(file_path_prefix):
        os.makedirs(file_path_prefix);
        print("directory don't not makedirs "+  file_path_prefix);
    # 下載OSS文件到本地文件。如果指定的本地文件存在會覆蓋,不存在則新建。
    bucket.get_object_to_file(object_name, download_local_save_prefix+local_file);


if __name__ == '__main__':
    print("start \n");
    # 阿里雲主賬號AccessKey擁有所有API的訪問權限,風險很高。強烈建議您創建並使用RAM賬號進行API訪問或日常運維,請登錄 https://ram.console.aliyun.com 創建RAM賬號。
    auth = oss2.Auth(accesskey_id,accesskey_secret)
    # Endpoint以杭州爲例,其它Region請按實際情況填寫。
    bucket = oss2.Bucket(auth,endpoint , bucket_name)
    bucket2= oss2.Bucket(auth,endpoint , bucket_name2)
    #單個文件夾下載
    root_directory_list(bucket);
    print("end \n");
四、模型訓練概要
將數據集放入項目中,運行u2net_train.py即可。
4.1讀懂訓練部分代碼,其中在step5的地方,我添加了一段處理,用於float和int類型之間轉換
 # 5. Begin training
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']

                assert images.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)
                ######
                one = torch.ones_like(true_masks)
                zero = torch.zeros_like(true_masks)
                true_masks = torch.where(true_masks>0,one,zero)
                #####
    
                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    loss = criterion(masks_pred, true_masks) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                division_step = (n_train // (10 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in net.named_parameters():
                            tag = tag.replace('/', '.')
                           
                        val_score = evaluate(net, val_loader, device)
                        scheduler.step(val_score)

                        logging.info('Validation Dice score: {}'.format(val_score))

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
            logging.info(f'Checkpoint {epoch + 1} saved!')
 
4.2 推薦適當投資,採購了autodl進行在線訓練

 

 

通過predict生成模板結果,在Photoshop中進行比對發現邊界已經比較貼合,最終在增強的數據集上,實現了DICE90%的目標。
五、基於OpenCV的Pytorch模型部署方法
 
這裏爲了進行總結,我對分別對目前使用Python和C++下的幾種可行可用的推斷方法進行彙總,並進一步比對。
5.1 (python)使用onnxruntime方法進行推斷
session = onnxruntime.InferenceSession("轉換的onnx文件")
input_name = session.get_inputs()[0].name
label_name = session.get_outputs()[0].name

img_name_list = ['需要處理的圖片']
image = Image.open(img_name_list[0])
w, h = image.size
dataset = SalObjDataset(
    img_name_list=img_name_list,
    lbl_name_list=[],
    transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
)
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1
)
im = list(data_loader)[0]['image']
inputs_test = im
inputs_test = inputs_test.type(torch.FloatTensor)
with torch.no_grad():
    inputs_test = Variable(inputs_test)
res = session.run([label_name], {input_name: inputs_test.numpy().astype(np.float32)})
result = torch.from_numpy(res[0])
pred = result[:, 0, :, :]
pred = normPRED(pred)
pred = pred.squeeze()
predict_np = pred.cpu().data.numpy()
im = Image.fromarray(predict_np * 255).convert('RGB')
im = im.resize((w, h), resample=Image.BILINEAR)
im.show()
5.2 (python) 使用opencv方法
import os
import argparse

from skimage import io, transform
import numpy as np
from PIL import Image
import cv2 as cv

parser = argparse.ArgumentParser(description='Demo: U2Net Inference Using OpenCV')
parser.add_argument('--input', '-i')
parser.add_argument('--model', '-m', default='u2net_human_seg.onnx')
args = parser.parse_args()

def normPred(d):
    ma = np.amax(d)
    mi = np.amin(d)
    return (d - mi)/(ma - mi)

def save_output(image_name, predict):
    img = cv.imread(image_name)
    h, w, _ = img.shape
    predict = np.squeeze(predict, axis=0)
    img_p = (predict * 255).astype(np.uint8)
    img_p = cv.resize(img_p, (w, h))
    print('{}-result-opencv_dnn.png-------------------------------------'.format(image_name))
    cv.imwrite('{}-result-opencv_dnn.png'.format(image_name), img_p)

def main():
    # load net
    net = cv.dnn.readNet('saved_models/sky_split.onnx')
    input_size = 320 # fixed
    # build blob using OpenCV
    img = cv.imread('test_imgs/sky1.jpg')
    blob = cv.dnn.blobFromImage(img, scalefactor=(1.0/255.0), size=(input_size, input_size), swapRB=True)
    # Inference
    net.setInput(blob)
    d0 = net.forward('output')
    # Norm
    pred = normPred(d0[:, 0, :, :])
    # Save
    save_output('test_imgs/sky1.jpg', pred)

if __name__ == '__main__':
    main()
5.3 (c++)使用libtorch方法

//    std::string strModelPath = "E:/template/u2net_train.pt";
void  bgr_u2net(cv::Mat& image_src, cv::Mat& result, torch::jit::Module& model)
{
    //1.模型已經導入
    auto device = torch::Device("cpu");
    //2.輸入圖片,變換到320
    cv::Mat  image_src1 = image_src.clone();
    cv::resize(image_src1, image_src1, cv::Size(320, 320));
    cv::cvtColor(image_src1, image_src1, cv::COLOR_BGR2RGB);
    // 3.圖像轉換爲Tensor
    torch::Tensor tensor_image_src = torch::from_blob(image_src1.data, { image_src1.rows, image_src1.cols, 3 }, torch::kByte);
    tensor_image_src = tensor_image_src.permute({ 2,0,1 }); // RGB -> BGR互換
    tensor_image_src = tensor_image_src.toType(torch::kFloat);
    tensor_image_src = tensor_image_src.div(255);
    tensor_image_src = tensor_image_src.unsqueeze(0); // 拿掉第一個維度  [3, 320, 320]
    //4.網絡前向計算
    auto src = tensor_image_src.to(device);
    auto pred = model.forward({ src }).toTuple()->elements()[0].toTensor();         //模型返回多個結果,用toTuple,其中elements()[i-1]獲取第i個返回值                                                                                //d1,d2,d3,d4,d5,d6,d7= net(inputs_test) //pred = d1[:,0,:,:]
    auto res_tensor = (pred * torch::ones_like(src));
    res_tensor = normPRED(res_tensor);
    //是否就是Tensor轉換爲圖像
    res_tensor = res_tensor.squeeze(0).detach();
    res_tensor = res_tensor.mul(255).clamp(0, 255).to(torch::kU8); //mul函數,表示張量中每個元素乘與一個數,clamp表示夾緊,限制在一個範圍內輸出
    res_tensor = res_tensor.to(torch::kCPU);
    //5.輸出最終結果
    cv::Mat resultImg(res_tensor.size(1), res_tensor.size(2), CV_8UC3);
    std::memcpy((void*)resultImg.data, res_tensor.data_ptr(), sizeof(torch::kU8) * res_tensor.numel());
    cv::resize(resultImg, resultImg, cv::Size(image_src.cols, image_src.rows), cv::INTER_LINEAR);
    result = resultImg.clone();
}
 
5.4 (c++)使用opencv方法
#include "opencv2/dnn.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
 
#include <iostream>
 
#include "opencv2/objdetect.hpp"
 
using namespace cv;
using namespace std;
using namespace cv::dnn;
 
int main(int argc, char ** argv)
{
    Net net = readNetFromONNX("E:/template/sky_split.onnx");
 
    if (net.empty()) {
        printf("read  model data failure...\n");
        return -1;
    }
 
   // load image data
    Mat frame = imread("e:/template/sky14.jpg");
    Mat blob;
    blobFromImage(frame, blob, 1.0 / 255.0, Size(320, 320), cv::Scalar(), true);
    net.setInput(blob);
    Mat prob = net.forward("output");  
    Mat slice(cv::Size(prob.size[2], prob.size[3]), CV_32FC1, prob.ptr<float>(0, 0));
    normalize(slice, slice, 0, 255, NORM_MINMAX, CV_8U);
    resize(slice, slice, frame.size());
 
    return 0;
}

 

綜合考慮後,選擇opencv onnx的部署方式
import os
import torch
from unet import UNet  


def main():
    net = UNet(n_channels=3, n_classes=2, bilinear=True)

    net.load_state_dict(torch.load("checkpoints/skyseg0113.pth", map_location=torch.device('cpu')))
    net.eval()

    # --------- model 序列化 ---------
    example = torch.zeros(1, 3, 320, 320) #這裏經過實驗,最大是 example = torch.zeros(1, 3, 411, 411)
    
    torch_script_module = torch.jit.trace(net, example)
    #torch_script_module.save('unet_empty.pt')
    torch.onnx.export(net, example, 'checkpoints/skyseg0113.onnx', opset_version=11)
    print('over')


if __name__ == "__main__":
    main()
 
int main()
{
    //參數和常量準備
    Net net = readNetFromONNX("E:/template/skyseg0113.onnx");
    if (net.empty()) {
        printf("read  model data failure...\n");
        return -1;
    }
    // load image data
    Mat frame = imread("E:\\sandbox/sky4.jpg");
    pyrDown(frame, frame);
    Mat blob;
    blobFromImage(frame, blob, 1.0 / 255.0, Size(320, 320), cv::Scalar(), true);
    net.setInput(blob);
    Mat prob = net.forward("473");//???對於Unet來說,example最大爲(411,411),原理上來說,值越大越有利於分割
    Mat slice(cv::Size(prob.size[2], prob.size[3]), CV_32FC1, prob.ptr<float>(0, 0));
    threshold(slice, slice, 0.1, 1, cv::THRESH_BINARY_INV);
    normalize(slice, slice, 0, 255, NORM_MINMAX, CV_8U);
    
    Mat mask;
    resize(slice, mask, frame.size());//製作mask
}
通過這種方法,就能夠獲得模型推斷的模板對象,其中“473”是模型訓練過程的層名,由於我們在訓練的過程中沒有指定,所以按照系統自己的名字給出。

 

 

我們可以通過netron的方式查看獲得這裏的名稱。
 
六、結合SeamlessClone等圖像處理方法,實現最終效果
 
int main()
{
    //參數和常量準備
    Net net = readNetFromONNX("E:/template/skyseg0113.onnx");
    if (net.empty()) {
        printf("read  model data failure...\n");
        return -1;
    }
    // load image data
    Mat frame = imread("E:\\sandbox/sky4.jpg");
    pyrDown(frame, frame);
    Mat blob;
    blobFromImage(frame, blob, 1.0 / 255.0, Size(320, 320), cv::Scalar(), true);
    net.setInput(blob);
    Mat prob = net.forward("473");
    Mat slice(cv::Size(prob.size[2], prob.size[3]), CV_32FC1, prob.ptr<float>(0, 0));
    threshold(slice, slice, 0.1, 1, cv::THRESH_BINARY_INV);
    normalize(slice, slice, 0, 255, NORM_MINMAX, CV_8U);
    
    Mat mask;
    resize(slice, mask, frame.size());//製作mask
    Mat matSrc = frame.clone();
    VP maxCountour = FindBigestContour(mask);
    Rect maxRect = boundingRect(maxCountour);
    if (maxRect.height == 0 || maxRect.width == 0)
        maxRect = Rect(0, 0, mask.cols, mask.rows);//特殊情況
    ////天空替換
    Mat matCloud = imread("E:/template/cloud/cloud1.jpg");
    resize(matCloud, matCloud, frame.size());
    //直接拷貝
    matCloud.copyTo(matSrc, mask);
    imshow("matSrc", matSrc);
    //seamless clone
    matSrc = frame.clone();
    Point center = Point((maxRect.x + maxRect.width) / 2, (maxRect.y + maxRect.height) / 2);//中間位置爲藍天的背景位置
    Mat normal_clone;
    Mat mixed_clone;
    Mat monochrome_clone;
    seamlessClone(matCloud, matSrc, mask, center, normal_clone, NORMAL_CLONE);
    seamlessClone(matCloud, matSrc, mask, center, mixed_clone, MIXED_CLONE);
    seamlessClone(matCloud, matSrc, mask, center, monochrome_clone, MONOCHROME_TRANSFER);
    imshow("normal_clone", normal_clone);
    imshow("mixed_clone", mixed_clone);
    imshow("monochrome_clone", monochrome_clone);
    waitKey();
    return 0;
}
在調用seamlessClone()的時候報錯:
報錯原因:可以看seamlessClone源碼(opencv/modules/photo/src/seamless_cloning.cpp),在執行seamlessClone的時候,會先求mask內物體的boundingRect,然後會把這個最小框矩形複製到dst上,矩形中心對齊center
這個過程中可能矩形會超出dst的邊界範圍,就會報上面的roi邊界錯誤。
這裏錯誤的根源應該還是OpenCV 這塊的代碼有問題,其中roi_s不應該適用BoundingRect進行處理。除了進行修改重新編譯,或者直接進行PR解決之外,我們可以採取一些補救的。這裏我採取了2手方法來避免異常:一個是在模板製作的過程中,除了獲得的最大區域之外,主動地將其他區域塗黑,從而保證BoundingRect能夠準確地框選天空區域;二個是在seamlessClone之前,對模板進行異常判斷,對可能出現的情況進程處置。
通過添加opencv代碼,進行系統聯調:

 

 

修改後的代碼爲:
int main()
{
    //參數和常量準備
    Net net = readNetFromONNX("E:/template/skyseg0113.onnx");
    if (net.empty()) {
        printf("read  model data failure...\n");
        return -1;
    }
    vector<string> vecFilePaths;
    getFiles("e:/template/sky", vecFilePaths);
    string strSavePath = "e:/template/sky_change_result";
    for (int index = 0;index<vecFilePaths.size();index++)
    {
        try{
            string strFilePath = vecFilePaths[index];
            string strFileName;
            getFileName(strFilePath, strFileName);
            Mat frame = imread(strFilePath);
            pyrDown(frame, frame);
            Mat blob;
            blobFromImage(frame, blob, 1.0 / 255.0, Size(320, 320), cv::Scalar(), true);
            net.setInput(blob);
            Mat prob = net.forward("473");
            Mat slice(cv::Size(prob.size[2], prob.size[3]), CV_32FC1, prob.ptr<float>(0, 0));
            threshold(slice, slice, 0.1, 1, cv::THRESH_BINARY_INV);
            normalize(slice, slice, 0, 255, NORM_MINMAX, CV_8U);
            Mat mask; 
            resize(slice, mask, frame.size());//製作mask
            Mat matSrc = frame.clone();
            VP maxCountour = FindBigestContour(mask);
            Rect maxRect = boundingRect(maxCountour);
            if (maxRect.height == 0 || maxRect.width == 0)
                maxRect = Rect(0, 0, mask.cols, mask.rows);//特殊情況
            Mat maskRedux(mask.size(), mask.type(), Scalar::all(0));
            Mat roi1 = mask(maxRect);
            Mat roi2 = maskRedux(maxRect);
            roi1.copyTo(roi2);
            ////天空替換
            Mat matCloud = imread("E:/template/cloud/cloud2.jpg");
            resize(matCloud, matCloud, frame.size());
            //直接拷貝
            matCloud.copyTo(matSrc, maskRedux);
            matSrc = frame.clone();
            cv::Point center = Point((maxRect.x + maxRect.width) / 2, (maxRect.y + maxRect.height) / 2);//中間位置爲藍天的背景位置
            Rect roi_s = maxRect;
            Rect roi_d(center.x - roi_s.width / 2, center.y - roi_s.height / 2, roi_s.width, roi_s.height);
            if(! (0 <= roi_d.x && 0 <= roi_d.width && roi_d.x + roi_d.width <= matSrc.cols && 0 <= roi_d.y && 0 <= roi_d.height && roi_d.y + roi_d.height <= matSrc.rows))
                center = Point(matSrc.cols / 2, matSrc.rows / 2);//這裏錯誤的根源應該還是OpenCV 這塊的代碼有問題,其中roi_s不應該適用BoundingRect進行處理.所以採取補救的方法
            Mat mixed_clone;
            seamlessClone(matCloud, matSrc, maskRedux, center, mixed_clone, MIXED_CLONE);
            string saveFileName = strSavePath + "/" + strFileName + "_cloud2.jpg";
            imwrite(saveFileName, mixed_clone);
        }
        catch (Exception * e)
        {
            continue;
        }
    }
2022 0312 更新代碼

int main()
{
    Mat src = imread("e:/template/tiantan.jpg");
    Mat matCloud = imread("E:/template/cloud/cloud2.jpg");
    Mat mask = imread("e:/template/tiantanmask2.jpg", 0);
    resize(matCloud, matCloud, src.size());
    resize(mask, mask, src.size());
    Mat matSrc = src.clone();
    Mat board = mask.clone();
    cvtColor(board, board, COLOR_GRAY2BGR);
    //尋找模板最大輪廓
    VP maxCountour = FindBigestContour(mask);
    Rect maxRect = boundingRect(maxCountour);
    //異常處理
    Mat maskCopy = mask.clone();
    copyMakeBorder(maskCopy, maskCopy, 1, 1, 1, 1, BORDER_ISOLATED | BORDER_CONSTANT, Scalar(0));
    Rect roi_s = boundingRect(maskCopy);
    if (roi_s.empty()) return -1;
    cv::Point center = Point((maxRect.x + maxRect.width) / 2, (maxRect.y + maxRect.height) / 2);
    Rect roi_d(center.x - roi_s.width / 2, center.y - roi_s.height / 2, roi_s.width, roi_s.height);
    if (!(0 <= roi_d.x && 0 <= roi_d.width && roi_d.x + roi_d.width <= matSrc.cols && 0 <= roi_d.y && 0 <= roi_d.height && roi_d.y + roi_d.height <= matSrc.rows))
        center = Point(matSrc.cols / 2, matSrc.rows / 2);
    //融合
    Mat normal_clone, mixed_clone, monochrome_clone;
    seamlessClone(matCloud, matSrc, mask, center, normal_clone, NORMAL_CLONE);
    seamlessClone(matCloud, matSrc, mask, center, mixed_clone, MIXED_CLONE);
    seamlessClone(matCloud, matSrc, mask, center, monochrome_clone, MONOCHROME_TRANSFER);
    waitKey();
    return 0;
}

 

七、結果對比和小結
效果是相當不錯的,但是在部署過程中也可能會遇到一些問題;特別是如果用於手機端部署,必然有工具鏈的問題。

 

 

 

 

 

 

 

 

我在hugginface上也實現了可以在線測試的效果。分別是skgseg和skgchange
https://huggingface.co/spaces/jsxyhelu/skyseg

 

 

 
最後,“天空替換”整個問題,只是語義分割的一種應用,結果是美化的圖片。這是價值比較有限的,必須要轉換爲量化的結果,用於定量計數,才能夠推動生產實踐。
此外,關於算法運行效率,也是部署應用的重要環節,在部署實現的時候也需要重點考慮。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章