import tensorflow as tf
import os

print(tf.__version__)

# Let TensorFlow grow GPU memory on demand instead of reserving it up front.
# Kept at the top of the script so it runs before any GPU work happens.
physical_gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in physical_gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
# Define the network: an AlexNet-style CNN. No input shape is declared, so the
# weights are created on the first call (this script's preprocessing feeds
# 224x224 single-channel images).
net = tf.keras.models.Sequential([
    # Feature extractor: a large strided kernel first, then smaller kernels,
    # with overlapping 3x3 / stride-2 max pooling in between.
    tf.keras.layers.Conv2D(filters=96, kernel_size=11, strides=4, activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
    tf.keras.layers.Conv2D(filters=256, kernel_size=5, padding='same', activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
    tf.keras.layers.Conv2D(filters=384, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.Conv2D(filters=384, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.Conv2D(filters=256, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
    # Classifier head with dropout regularization.
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(4096, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(4096, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    # Softmax, not sigmoid: the 10 class outputs must form a probability
    # distribution for sparse categorical cross-entropy on probabilities
    # (from_logits=False). Independent sigmoids do not sum to 1.
    tf.keras.layers.Dense(10, activation='softmax'),
])
# Load the data
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import fashion_mnist

# load_data() returns a (train, test) pair of (images, labels) tuples.
train_split, test_split = fashion_mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
# Get familiar with tf.image.resize_with_pad on a single image.
def resize(x, target_height=100, target_width=100):
    """Pad-and-resize one 2-D grayscale image.

    Args:
        x: a 2-D image of shape (height, width), e.g. one Fashion-MNIST sample.
        target_height: output height in pixels (default 100).
        target_width: output width in pixels (default 100).

    Returns:
        A float32 tensor of shape (target_height, target_width). Pixel values
        are cast to float but not rescaled.
    """
    x = tf.cast(x, tf.float32)
    # resize_with_pad needs an explicit channel axis: (H, W) -> (H, W, 1).
    x = tf.reshape(x, (x.shape[0], x.shape[1], 1))
    x = tf.image.resize_with_pad(
        image=x,
        target_height=target_height,
        target_width=target_width,
        method='bilinear',
    )
    # Drop only the channel axis; a bare squeeze would also remove a spatial
    # dimension of size 1 (e.g. target_height=1).
    return tf.squeeze(x, axis=-1)
# Show the first training image next to its padded/resized version.
img = x_train[0]
img2 = resize(img)

# Create each figure *before* drawing into it. Calling imshow first would draw
# into a default-sized figure and leave an extra, empty 5x5 figure on screen.
plt.figure(figsize=(5, 5))
plt.imshow(img)
plt.show()

plt.figure(figsize=(5, 5))
plt.imshow(img2)
plt.show()
# Data preprocessing
def data_scale(x, y):
    """tf.data map function: normalize and pad-resize one (image, label) pair.

    The image is cast to float32 and divided by 255.0. It then gets a
    trailing channel axis and is pad-resized to 224x224. The label ``y`` is
    returned unchanged.
    """
    rows, cols = x.shape[0], x.shape[1]
    image = tf.cast(x, tf.float32) / 255.0
    image = tf.reshape(image, (rows, cols, 1))
    image = tf.image.resize_with_pad(image=image, target_height=224, target_width=224)
    return image, y
# Input pipelines. The training set is shuffled over its full length (a
# buffer of 20 barely perturbs the stored order). Shuffling happens on the raw
# uint8 images, before the expensive resize. Scaling/resizing runs in parallel
# and batches are prefetched so preprocessing overlaps training.
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_db = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=len(x_train))
    .map(data_scale, num_parallel_calls=AUTOTUNE)
    .batch(256)
    .prefetch(AUTOTUNE)
)
# Evaluation order does not affect loss/accuracy, so the test set is not shuffled.
test_db = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(data_scale, num_parallel_calls=AUTOTUNE)
    .batch(256)
    .prefetch(AUTOTUNE)
)
# Define the optimizer and loss function.
# `learning_rate` replaces the deprecated `lr` alias, which newer Keras
# versions no longer accept.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-1)
# Integer class labels vs. per-class probabilities (from_logits=False default).
loss = tf.keras.losses.sparse_categorical_crossentropy
net.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

# Model.fit / Model.evaluate accept a tf.data.Dataset directly;
# fit_generator / evaluate_generator are deprecated.
net.fit(train_db, epochs=2, validation_data=test_db)
net.summary(line_length=100)
# TODO: persist the trained model, e.g. net.save(<path>); save() needs a filepath.
net.evaluate(test_db)