注意事項
1.fit_generator中用的都是整型數字
2.構建checkpoing
3.模型保存的時候注意,如果有自定義層,容易出錯
代碼
#定義模型檢查點
checkpoint = keras.callbacks.ModelCheckpoint(self.save_path, monitor='val_metric_precision', verbose=1,
save_best_only=True, mode='max')
callbacks_list = [checkpoint]
#模型保存
model.fit_generator(G_train,steps_per_epoch=int(self.total_number/self.batch_size),validation_data=G_eval,#不設置steps_per_epoch=
validation_steps=40,epochs=self.epochs,callbacks=callbacks_list)
model.save_weights(self.save_path)#保存模型