關於tensorflow 中 placeholder 與 reshape的一點坑

轉自:https://blog.csdn.net/sky_asher/article/details/79717620

在搭LeNet-5 模型時,在卷積層的輸出到全連接層時,使用了reshape將四維的矩陣轉化維2維矩陣時,發生了錯誤:

 
這裏寫圖片描述
起初以爲時類型轉換髮生了錯誤,然後演算過後發現並沒有錯誤。然後改了下 訓練數據的輸入格式

 

    # 定義輸入輸出placeholder, **修改前**
    x = tf.placeholder(tf.float32,
                       [None,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.NUM_CHANNELS],
                       name='x-input')
 # 定義輸入輸出placeholder。**修改後**
    x = tf.placeholder(tf.float32,
                       [BATCH_SIZE,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.NUM_CHANNELS],
                       name='x-input')
然後錯誤沒有了,寫了一個簡單的驗證程序,來驗針下placeholder這裏出現的問題。
import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, [None,2,2,2],name='x-input')
x_shape = x.get_shape().as_list()
len = x_shape[1] * x_shape[2] * x_shape[3]
x_reshaped = tf.reshape(x, [x_shape[0],len])

y = x_reshaped + 1
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    data = np.arange(2 * 2 * 2 * 2).reshape([2, 2, 2, 2]).astype('float32')
    out = sess.run(y, feed_dict={x:data})
    print(out)



然後發生瞭如下錯誤,通過黃色標註的字體可以發現時發生在了reshape是發生了錯誤,reshape()函數無法識別None這裏發生的轉換,所以報錯。 

所以問題出現在reshape函數這塊reshap() 函數無法識別轉化列表中的None是多少,這時可以使用python中的自動推導,也就是

# x_reshaped = tf.reshape(x, [x[0], len])
x_reshaped = tf.reshape(x, [-1,len])

這樣完美解決問題,placeholder的shape參數可以爲

x = tf.placeholder(tf.float32,
                       [None,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.IMAGE_SIZE,
                        mnist_inference.NUM_CHANNELS],
                       name='x-input')

這樣就可以訓練和測試一起執行,根據需要輸入Bach的大小

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章