Tensorflow-mnist 手寫數字識別

1. 加載 MNIST 數據。按照 TensorFlow 官網的寫法:

import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
總是報錯。查找後發現,安裝 TensorFlow 後,input_data.py 位於 tensorflow/examples/tutorials/mnist/ 目錄下,因此應按如下方式導入:

from tensorflow.examples.tutorials.mnist import input_data
由於在線下載 MNIST 時總是提示超時,建議直接從 http://yann.lecun.com/exdb/mnist/ 手動下載以下四個 gz 格式的數據文件(四個都需要):

train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz:  test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz:  test set labels (4542 bytes)

然後查看 input_data.py 中 read_data_sets 函數的源代碼:

# CVDF mirror of http://yann.lecun.com/exdb/mnist/
DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'


def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None,
                   source_url=DEFAULT_SOURCE_URL):
  """Load MNIST into train / validation / test DataSet objects.

  Each of the four MNIST archives is obtained via ``base.maybe_download``
  into ``train_dir`` and decoded with ``extract_images`` /
  ``extract_labels``. The first ``validation_size`` training examples are
  split off to form the validation set.

  Args:
    train_dir: Directory holding (or receiving) the four MNIST ``.gz`` files.
    fake_data: If True, return three empty fake DataSets without any I/O.
    one_hot: Forwarded to ``extract_labels`` and to ``DataSet``.
    dtype: Forwarded to ``DataSet``.
    reshape: Forwarded to ``DataSet``.
    validation_size: Number of leading training examples moved into the
      validation set; must lie in ``[0, len(train_images)]``.
    seed: Forwarded to ``DataSet``.
    source_url: Base URL that each archive name is appended to. A falsy
      value ('' / None / False) falls back to ``DEFAULT_SOURCE_URL``; it
      does NOT disable downloading.

  Returns:
    ``base.Datasets(train=..., validation=..., test=...)``.

  Raises:
    ValueError: If ``validation_size`` is outside ``[0, len(train_images)]``.
  """
  if fake_data:

    def fake():
      return DataSet(
          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  # Falsy values (including False) are replaced by the default mirror, so
  # they do not switch off the download step.
  if not source_url:  # empty string check
    source_url = DEFAULT_SOURCE_URL

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  def _load(filename, extract, **extract_kwargs):
    # Fetch one archive into train_dir and decode it with `extract`.
    # NOTE(review): maybe_download presumably reuses a file already present
    # in train_dir instead of downloading it again — confirm in base.py.
    local_file = base.maybe_download(filename, train_dir,
                                     source_url + filename)
    with gfile.Open(local_file, 'rb') as f:
      return extract(f, **extract_kwargs)

  # Same order as before: train images, train labels, test images, test labels.
  train_images = _load(TRAIN_IMAGES, extract_images)
  train_labels = _load(TRAIN_LABELS, extract_labels, one_hot=one_hot)
  test_images = _load(TEST_IMAGES, extract_images)
  test_labels = _load(TEST_LABELS, extract_labels, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError(
        'Validation size should be between 0 and {}. Received: {}.'
        .format(len(train_images), validation_size))

  # The validation split is carved off the front of the training data.
  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)

  train = DataSet(train_images, train_labels, **options)
  validation = DataSet(validation_images, validation_labels, **options)
  test = DataSet(test_images, test_labels, **options)

  return base.Datasets(train=train, validation=validation, test=test)

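注意:把 source_url 設為 False 並不能「關閉」在線下載。從上面源代碼中的 `if not source_url: source_url = DEFAULT_SOURCE_URL` 可見,source_url 為空字符串、None 或 False 等假值時,都會被重置為 DEFAULT_SOURCE_URL(即 https://storage.googleapis.com/cvdf-datasets/mnist/,作者這裡總是打不開)。真正讓本地加載生效的是 base.maybe_download:若 MNIST_data/ 文件夾中已存在同名的 gz 文件,就直接使用本地文件,不再聯網下載。所以關鍵是事先把下載好的四個 gz 文件放進 MNIST_data/ 文件夾;下面寫法中的 source_url = False 可加可不加,效果相同: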

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True,source_url = False)

這樣就加載完成了。

2. 下面是建立簡單的 softmax regression 模型的源代碼:

# coding: utf-8
#
# Softmax-regression classifier for MNIST (TensorFlow 1.x graph API).
# Expects the four MNIST .gz archives in MNIST_data/; trains for 1000
# mini-batch steps and prints the accuracy on the test set.

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

import numpy  # NOTE(review): not used below; kept from the original notebook.

# read_data_sets replaces a falsy source_url with its default mirror, so
# passing source_url=False does not disable downloading. Archives that are
# already present in MNIST_data/ are reused by base.maybe_download instead.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print(mnist.train.images.shape)

# Model: y = softmax(x W + b)
x = tf.placeholder(tf.float32, [None, 784])   # each image flattened to 784 values
W = tf.Variable(tf.zeros([784, 10]))          # weights
b = tf.Variable(tf.zeros([10]))               # bias
logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)
y_ = tf.placeholder(tf.float32, [None, 10])   # one-hot ground-truth labels

# Cross-entropy -sum(y_ * log(softmax(logits))), computed from the logits so
# the log is evaluated stably: a softmax output that underflows to 0 no longer
# yields log(0) = -inf and NaN losses. Summing (not averaging) over the batch
# keeps the original loss scale, hence the original effective learning rate.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# Prediction is correct when the argmax of the output matches the label's.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# The context manager closes the session when training/evaluation is done.
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 1000 training steps, 100 examples per mini-batch.
  for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
  print(sess.run(accuracy,
                 feed_dict={x: mnist.test.images, y_: mnist.test.labels}))




發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章