Python 深度學習：GOPRO 數據集的隨機裁剪

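原始數據集結構

（原文此處為圖片，已遺失。）每個場景一個子文件夾，其下的 `blur/` 存放模糊幀、`sharp/` 存放對應的清晰幀，均為 `.png` 格式，例如 `train/<場景名>/blur/*.png` 與 `train/<場景名>/sharp/*.png`。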

生成的數據集結構

（原文此處為圖片，已遺失。）輸出分為 `x/`（模糊圖像塊）與 `y/`（清晰圖像塊）兩個文件夾，編號相同的文件（如 `0.png`、`1.png`……）是從同一對幀的同一位置裁出的 256×256 圖像塊。

代碼

import numpy as np
import matplotlib
from matplotlib import pyplot as plt

import cv2
import PIL
import json, os
import sys
from PIL import Image
import labelme
import labelme.utils as utils
import glob
import itertools


import numpy as np
import matplotlib
from matplotlib import pyplot as plt

import cv2
import PIL
import json, os
import sys
from PIL import Image
import labelme
import labelme.utils as utils
import glob
import itertools

# All data lives under one download root; every path below is built from it.
_ROOT = r'F:\BaiduNetdiskDownload'

# Input splits handed to cropping(): one sub-folder per scene, each holding
# blur/ and sharp/ folders of .png frames.
trainDataPath = _ROOT + r'\DATA\train'
testDataPath = _ROOT + r'\DATA\test'

# Not referenced anywhere in the visible code; both names point at the same folder.
origin_path = _ROOT + r'\U_net_dataset\train'
label_path = origin_path

# Output folders: x\ receives the blurred crops, y\ the matching sharp crops.
_OUT = _ROOT + r'\deblugData'
x_trainpath = _OUT + r'\train\x'
y_trainpath = _OUT + r'\train\y'
x_testpath = _OUT + r'\test\x'
y_testpath = _OUT + r'\test\y'

def cropping(path, Xsavepath, Ysavepath,
             cropping_height=256, cropping_width=256, imgNumber=10):
    """Cut aligned random patches from paired blur/sharp frames.

    ``path`` is scanned one level deep. Every sub-directory is treated as a
    scene that should contain a ``blur`` and a ``sharp`` folder of ``.png``
    frames; scenes lacking either folder are skipped with a message. Within a
    scene, frames are paired by position after sorting each folder's file
    list. This assumes both folders hold the same file names (TODO confirm
    for your copy of the dataset).

    For every pair, ``imgNumber`` patches of ``cropping_height`` x
    ``cropping_width`` pixels are cut at the same random offset from both
    frames. They are written as ``<n>.png`` into ``Xsavepath`` (blur) and
    ``Ysavepath`` (sharp), where ``n`` is a counter starting at 0 for each
    call. Offsets come from NumPy's global RNG, so output differs between
    runs unless the caller seeds ``np.random`` first.

    Args:
        path: Root folder of one dataset split (e.g. train or test).
        Xsavepath: Output folder for blurred crops; created if missing.
        Ysavepath: Output folder for sharp crops; created if missing.
        cropping_height: Patch height in pixels.
        cropping_width: Patch width in pixels.
        imgNumber: Number of patches cut from each frame pair.

    Raises:
        OSError: if ``cv2.imwrite`` reports that a crop could not be written.
    """
    # cv2.imwrite does not create directories, so make sure they exist.
    os.makedirs(Xsavepath, exist_ok=True)
    os.makedirs(Ysavepath, exist_ok=True)

    number = 0  # running index shared by the x/ and y/ file names
    # sorted() makes the numbering reproducible; os.listdir order is arbitrary.
    for scene in sorted(os.listdir(path)):
        scene_dir = os.path.join(path, scene)
        if not os.path.isdir(scene_dir):
            continue  # stray files next to the scene folders are ignored

        blur_dir = os.path.join(scene_dir, 'blur')
        sharp_dir = os.path.join(scene_dir, 'sharp')
        if not (os.path.isdir(blur_dir) and os.path.isdir(sharp_dir)):
            print('skip scene without blur/sharp folders:', scene_dir)
            continue

        origin_list = sorted(glob.glob(os.path.join(blur_dir, '*.png')))
        label_list = sorted(glob.glob(os.path.join(sharp_dir, '*.png')))
        if len(origin_list) != len(label_list):
            # Plain zip() below only pairs the common prefix; it never wraps
            # around the way itertools.cycle did, which mispaired frames.
            print('frame count mismatch in', scene_dir,
                  len(origin_list), len(label_list))

        for origin_position, label_position in zip(origin_list, label_list):
            # -1 == cv2.IMREAD_UNCHANGED: keep channels/bit depth as stored.
            origin_img = cv2.imread(origin_position, -1)
            label_img = cv2.imread(label_position, -1)
            if origin_img is None or label_img is None:
                print('skip unreadable pair:', origin_position, label_position)
                continue
            if origin_img.shape[:2] != label_img.shape[:2]:
                print('skip size mismatch:', origin_position, label_position)
                continue

            height, width = origin_img.shape[:2]
            if height < cropping_height or width < cropping_width:
                print('skip frame smaller than crop:', origin_position)
                continue

            for _ in range(imgNumber):
                # randint's upper bound is exclusive; +1 allows the patch to
                # touch the bottom/right edge and permits frames exactly the
                # crop size (which previously raised ValueError).
                y = np.random.randint(0, height - cropping_height + 1)
                x = np.random.randint(0, width - cropping_width + 1)

                blur_crop = origin_img[y:y + cropping_height, x:x + cropping_width]
                sharp_crop = label_img[y:y + cropping_height, x:x + cropping_width]

                name = str(number) + '.png'
                ok_x = cv2.imwrite(os.path.join(Xsavepath, name), blur_crop)
                ok_y = cv2.imwrite(os.path.join(Ysavepath, name), sharp_crop)
                if not (ok_x and ok_y):
                    raise OSError('failed to write crop ' + name)

                number += 1
                print('number', number)
if __name__ == '__main__':
    # Process the training split first, then the test split, each into its
    # own x/ (blur) and y/ (sharp) output folders.
    for src_dir, x_dir, y_dir in ((trainDataPath, x_trainpath, y_trainpath),
                                  (testDataPath, x_testpath, y_testpath)):
        cropping(src_dir, x_dir, y_dir)
發表評論
所有評論
還沒有人評論，想成為第一個評論的人嗎？請在上方評論欄輸入並點擊發布。
相關文章