在使用keras 進行圖像分割時,當數據量不大時,我們需要進行圖像增加。在keras 中有專門的函數可以進行增強。這裏進行簡單的介紹一下。
keras 中圖像增強的函數是ImageDataGenerator 類
keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-06,
rotation_range=0,
width_shift_range=0.0,
height_shift_range=0.0,
brightness_range=None,
shear_range=0.0,
zoom_range=0.0,
channel_shift_range=0.0,
fill_mode='nearest',
cval=0.0,
horizontal_flip=False,
vertical_flip=False,
rescale=None,
preprocessing_function=None,
data_format=None,
validation_split=0.0,
dtype=None)
通過實時數據增強生成張量圖像數據批次。數據將不斷循環(按批次)。這裏如果是其他的問題我們可以直接使用這個函數來進行增強很方便。但是在進行圖像分割時需要一個問題,那就是有data 和 label 兩類。你需要同時對這個數據進行處理。比如旋轉什麼的。因此在這裏簡單的介紹一下。
IMAGE_LIB = '../input/2d_images/'
MASK_LIB = '../input/2d_masks/'
SEED=42
all_images = [x for x in sorted(os.listdir(IMAGE_LIB)) if x[-4:] == '.tif']
x_data = np.empty((len(all_images), IMG_HEIGHT, IMG_WIDTH), dtype='float32')
for i, name in enumerate(all_images):
im = cv2.imread(IMAGE_LIB + name, cv2.IMREAD_UNCHANGED).astype("int16").astype('float32')
#im = cv2.resize(im, dsize=(IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_LANCZOS4)
im = (im - np.min(im)) / (np.max(im) - np.min(im))
x_data[i] = im
y_data = np.empty((len(all_images), IMG_HEIGHT, IMG_WIDTH), dtype='float32')
for i, name in enumerate(all_images):
im = cv2.imread(MASK_LIB + name, cv2.IMREAD_UNCHANGED).astype('float32')/255.
#im = cv2.resize(im, dsize=(IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
y_data[i] = im
這樣就把數據讀入進去了也進行簡單的處理。當然你可以看一下這個數據是什麼樣。
fig, ax = plt.subplots(1,2, figsize = (8,4))
ax[0].imshow(x_data[0], cmap='gray')
ax[1].imshow(y_data[0], cmap='gray')
plt.show()
下面進行數據的增強
def my_generator(x_train, y_train, batch_size):
data_generator = ImageDataGenerator(
width_shift_range=0.1,
height_shift_range=0.1,
rotation_range=10,
zoom_range=0.1).flow(x_train, x_train, batch_size, seed=SEED)
mask_generator = ImageDataGenerator(
width_shift_range=0.1,
height_shift_range=0.1,
rotation_range=10,
zoom_range=0.1).flow(y_train, y_train, batch_size, seed=SEED)
while True:
x_batch, _ = data_generator.next()
y_batch, _ = mask_generator.next()
yield x_batch, y_batch
當然這裏具體的增強方法,你可以參考上面的函數,具體的添加一下方法。
當然你也可以具體看一下是什麼樣。
image_batch, mask_batch = next(my_generator(x_train, y_train, 8))
fix, ax = plt.subplots(8,2, figsize=(8,20))
for i in range(8):
ax[i,0].imshow(image_batch[i,:,:,0])
ax[i,1].imshow(mask_batch[i,:,:,0])
plt.show()
最後在keras中呢整個網絡設計好了以後。可以直接調用
hist = model.fit_generator(my_generator(x_train, y_train, 8),
steps_per_epoch = 200,
validation_data = (x_val, y_val),
epochs=10, verbose=2,
callbacks = [weight_saver, annealer])
這樣就可以調用了。這裏的一些具體的參數。你可以按照自己的要求進行設置。