tensorflow keras數據集的讀取 fit_generator的使用,以及模型編譯保存

一、數據集的樣式以及讀取函數

數據集以x,y的形式分別保存檢測圖像和標籤,其中X存放png和jpg格式的圖像
讀取的時候用model.fit_generator函數載入數據集,關鍵點則在於生成器的構造

二、步驟

1.製作一個數據生成器代碼
2.使用yield返回值
3.接受值並給予model.fit_generator函數

三、代碼(類)

import numpy as np
import matplotlib
from matplotlib import pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses
from tensorflow.keras.callbacks import EarlyStopping

import cv2
import PIL
import json, os
import sys
from PIL import Image
import labelme
import labelme.utils as utils
import glob

以下是整個模型類的一部分 ,方便理解就基本都貼出來了

class Net(): 
	  def __init__(self):#存儲列表
      self.input_width=input_width
      self.input_height=input_height
      self.num_classes=num_classes
      self.train_images=train_images
      self.train_instances=train_instances
      self.val_images=val_images
      self.epochs=epochs
      self.lr=lr
      self.lr_decay=lr_decay
      self.batch_size=batch_size
      self.save_path=save_path
    def build_model():#定義模型的生成樣式
    pass
    
    def train(self):#############訓練的方法
    G_train = self.dataGenerator(mode='training')
    G_eval  = self.dataGenerator(mode='validation')
    model =self.build_mode()#構建模型的方法 具體可以看keras官方文檔
    model.summary()
    model.compile(#模型的編譯
      optimizers=keras.optimizers.Adam(self.lr,self.lr_decay),
      loss = 'categorical_crossentropy',
      metrics=['categorical_accuracy','recall','AUC']
    )
    #使用model.fit_generator載入函數,必須有個數據生成器不斷讀取函數
   model.fit_generator(G_train,5,validation_data=G_eval,validation_steps=5,epochs=self.epochs)
#保存模型
    model.save(self.save_path)



    #數據生成器函數
  def dataGenerator(self,mode):
    if mode =='training':#訓練集
   #讀取文件
      images = glob.glob(self.train_images+'*.jpg')#讀取列表
      images.sort()#排序
      instances= glob.glob(self.train_instances +'*.png')
      instances.sort()
      zipped = inertools.cycle(zip(images,instances))#用zip包裝,cycle循環
      while True :
        x_train=[]#必須定義個空集,使張量量的維度增加一維
        y_train=[]
        for _ in range(self.batch_size):
          img,seg = next(zipped)
          img = cv2. resize(cv2.imread(img,1),(self.input_width,self.input_height))
          seg = keras.utils.to_categorical(cv2.imread(seg,0),num_classes=self.num_classes)
          x_train.append(img)
          y_train.append(seg)
        yield np.array(x_train),np.array(y_train)#使用yield返回值
    if mode == 'validation':#測試集同上,讀取的地方不一樣
      images = glob.glob(self.train_images + '*.jpg')
      images.sort()
      instances = glob.glob(self.train_instances + '*.png')
      instances.sort()
      zipped = inertools.cycle(zip(images, instances))
      while True:
        x_eval = []
        y_eval = []
        for _ in range(self.batch_size):
          img, seg = next(zipped)
          img = cv2.resize(cv2.imread(img, 1), (self.input_width, self.input_height))
          seg = keras.utils.to_categorical(cv2.imread(seg, 0), num_classes=self.num_classes)
          x_eval.append(img)
          y_eval.append(seg)
        yield np.array(x_eval), np.array(y_eval)
       
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章