導入需要的模塊
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
定義需要的參數類
class Config:
def __init__(self):
self.save_path = './model_sin/sin_cos' # 訓練結果保存的路徑
self.lr = 0.001 # 學習步長
self.epoches = 1000 # 迭代的次數
self.samples = 200 # 樣本的數量
self.hidden_units = 200 # 隱藏層神經元的數量
定義樣本類
class Sample:
def __init__(self,samples):
self.xs = np.random.uniform(-np.pi, np.pi, samples) # 使用Numpy隨機生成點
self.xs = sorted(self.xs) # 對生成的點進行排序(否則會亂)
self.ys = np.sin(self.xs), np.cos(self.xs), np.square(self.xs) # 定義需要擬合的三個函數
定義張量類
class Tensors:
def __init__(self,config):
self.x = tf.placeholder(tf.float32,[None],'x') # 定義輸入樣本點x的值
self.y = tf.placeholder(tf.float32,[3,None],'y') # 定義輸入樣本點y的值
x = tf.reshape(self.x,[-1,1]) # 將所輸入的標量轉換爲向量
x = tf.layers.dense(x, config.hidden_units, tf.nn.relu) # 使用全連接進行函數
self.y_predict = tf.layers.dense(x,3) # 全連接獲取預測的y值
y = tf.transpose(self.y) # 對輸入的y值進行轉置,方便進行後續操作
self.loss = tf.reduce_mean(tf.square(self.y_predict - y)) # 獲取損失函數的值
self.lr = tf.placeholder(tf.float32, [], 'lr') # 定義學習步長
opt = tf.train.AdamOptimizer(config.lr) # 梯度下降優化器,進行擬合
self.train_op = opt.minimize(self.loss) # 獲取dertx的值
self.loss = tf.sqrt(self.loss) # 對獲取的損失值進行開平方操作,使誤差減少
定義應用類
應用類應該包含:
- 樣本的訓練方法
- 樣本的預測方法
class SinApp:
def __init__(self,config):
self.config = config
self.ts = Tensors(config)
self.session = tf.Session()
self.saver = tf.train.Saver()
try:
self.saver.restore(self.session, config.save_path)
print(f"restore model from {config.save_path} successfully")
except:
print(f"fail to restore the model from {config.save_path}")
self.session.run(tf.global_variables_initializer())
def train(self): # 訓練方法
samples = Sample(self.config.samples)
cfg = self.config
ts = self.ts
for _ in range(cfg.epoches):
self.session.run(ts.train_op,{ts.x:samples.xs, ts.y : samples.ys, ts.lr:cfg.lr})
self.save()
return samples.xs,samples.ys
def predict(self): # 預測方法
samples = Sample(400)
ys = self.session.run(self.ts.y_predict,{self.ts.x:samples.xs})
return samples.xs, ys
def save(self):
self.saver.save(self.session, self.config.save_path)
def close(self):
self.session.close()
定義主方法打印擬合的圖像
if __name__ == "__main__":
app = SinApp(Config())
xs_train, ys_train = app.train()
xs_predict, ys_presict = app.predict()
ys_train = np.transpose(ys_train)
plt.plot(xs_train, ys_train)
plt.plot(xs_predict, ys_presict)
plt.legend(['sin', 'cos', 'square', 'predict_sin', 'predict_cos', 'predict_square'])
plt.show()
運行結果爲
樣本的訓練結果