轉自:https://blog.csdn.net/sky_asher/article/details/79717620
在搭LeNet-5 模型時,在卷積層的輸出到全連接層時,使用了reshape將四維的矩陣轉化維2維矩陣時,發生了錯誤:
起初以爲時類型轉換髮生了錯誤,然後演算過後發現並沒有錯誤。然後改了下 訓練數據的輸入格式
# 定義輸入輸出placeholder, **修改前**
x = tf.placeholder(tf.float32,
[None,
mnist_inference.IMAGE_SIZE,
mnist_inference.IMAGE_SIZE,
mnist_inference.NUM_CHANNELS],
name='x-input')
# 定義輸入輸出placeholder。**修改後**
x = tf.placeholder(tf.float32,
[BATCH_SIZE,
mnist_inference.IMAGE_SIZE,
mnist_inference.IMAGE_SIZE,
mnist_inference.NUM_CHANNELS],
name='x-input')
然後錯誤沒有了,寫了一個簡單的驗證程序,來驗針下placeholder這裏出現的問題。
import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, [None,2,2,2],name='x-input')
x_shape = x.get_shape().as_list()
len = x_shape[1] * x_shape[2] * x_shape[3]
x_reshaped = tf.reshape(x, [x_shape[0],len])
y = x_reshaped + 1
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
data = np.arange(2 * 2 * 2 * 2).reshape([2, 2, 2, 2]).astype('float32')
out = sess.run(y, feed_dict={x:data})
print(out)
然後發生瞭如下錯誤,通過黃色標註的字體可以發現時發生在了reshape是發生了錯誤,reshape()函數無法識別None這裏發生的轉換,所以報錯。
所以問題出現在reshape函數這塊,reshap() 函數無法識別轉化列表中的None是多少,這時可以使用python中的自動推導,也就是
# x_reshaped = tf.reshape(x, [x[0], len])
x_reshaped = tf.reshape(x, [-1,len])
這樣完美解決問題,placeholder的shape參數可以爲
x = tf.placeholder(tf.float32,
[None,
mnist_inference.IMAGE_SIZE,
mnist_inference.IMAGE_SIZE,
mnist_inference.NUM_CHANNELS],
name='x-input')
這樣就可以訓練和測試一起執行,根據需要輸入Bach的大小