Implementing Algorithms by Hand with TensorFlow, Part 3: Logistic Regression

1. The Logistic Regression Algorithm

Logistic regression is one of the most frequently used algorithms in day-to-day work. Although it is simple and has been around for a long time, it is easy to implement, highly interpretable, and usually performs reasonably well, and it has a major performance advantage when dealing with massive datasets. For these reasons, logistic regression is commonly used as one of the baseline versions of an online algorithm.

Earlier articles in this logistic regression series:
Why use the logistic function
The loss function (cost function) in detail
Training with gradient descent
(A short recap sketch of these pieces is given below.)
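As a quick recap of those pieces (the sigmoid function, the cross-entropy loss, and gradient descent), here is a minimal NumPy sketch of binary logistic regression. The toy data, learning rate, and iteration count are made up purely for illustration and are not from the earlier articles.

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Toy data: 4 samples, 2 features; the label is 1 only when both features are 1.
X = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = np.array([0., 0., 0., 1.])

w = np.zeros(2)
b = 0.0
lr = 0.5

for step in range(2000):
    p = sigmoid(X.dot(w) + b)         # predicted probability of class 1
    grad_w = X.T.dot(p - y) / len(y)  # gradient of the cross-entropy loss w.r.t. w
    grad_b = np.mean(p - y)           # gradient w.r.t. b
    w -= lr * grad_w                  # gradient descent update
    b -= lr * grad_b

print(sigmoid(X.dot(w) + b))          # highest probability for the last sample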

2. TensorFlow Implementation

With the theory above in place, let's implement logistic regression with TensorFlow, using the MNIST dataset.

import tensorflow as tf


def logistic_regression():
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("data/", one_hot=True)

    # tf graph input
    X = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784
    Y = tf.placeholder(tf.float32, [None, 10])

    # Weights
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))

    # model
    pred = tf.nn.softmax(tf.matmul(X, W) + b)

    # loss: cross-entropy loss function
    loss = tf.reduce_mean(- tf.reduce_sum(Y * tf.log(pred), reduction_indices=1))

    # optimizer: gradient descent
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.02).minimize(loss)

    init = tf.global_variables_initializer()

    batch_size = 100

    with tf.Session() as sess:
        sess.run(init)

        for epoch in range(50):
            avg_loss = 0.0
            total_batch = int(mnist.train.num_examples / batch_size)
            for i in range(total_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                _, l = sess.run([optimizer, loss], feed_dict={X: batch_xs, Y: batch_ys})
                avg_loss += l / total_batch

            print("epoch %d loss is: %f" %(epoch, avg_loss))

        print('\n\n')
        print("W is: ", W.eval()[300:320, :])
        print("b is: ", b.eval())
        print("W shape is: ", W.eval().shape)
        print("b shape is: ", b.eval().shape)

        # Test model
        correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))

        # Calculate accuracy
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        print("Accuracy: ", accuracy.eval({X: mnist.test.images, Y: mnist.test.labels}))


logistic_regression()

Output:

epoch 0 loss is: 0.925885
epoch 1 loss is: 0.526170
epoch 2 loss is: 0.454383
epoch 3 loss is: 0.419288
epoch 4 loss is: 0.397402
epoch 5 loss is: 0.381946
epoch 6 loss is: 0.370278
epoch 7 loss is: 0.361055
epoch 8 loss is: 0.353631
epoch 9 loss is: 0.347324
epoch 10 loss is: 0.341956
epoch 11 loss is: 0.337216
epoch 12 loss is: 0.333048
epoch 13 loss is: 0.329546
epoch 14 loss is: 0.326188
epoch 15 loss is: 0.323263
epoch 16 loss is: 0.320566
epoch 17 loss is: 0.318108
epoch 18 loss is: 0.315884
epoch 19 loss is: 0.313779
epoch 20 loss is: 0.311729
epoch 21 loss is: 0.309981
epoch 22 loss is: 0.308338
epoch 23 loss is: 0.306725
epoch 24 loss is: 0.305222
epoch 25 loss is: 0.303862
epoch 26 loss is: 0.302433
epoch 27 loss is: 0.301285
epoch 28 loss is: 0.300108
epoch 29 loss is: 0.298947
epoch 30 loss is: 0.297839
epoch 31 loss is: 0.296805
epoch 32 loss is: 0.295821
epoch 33 loss is: 0.294917
epoch 34 loss is: 0.293935
epoch 35 loss is: 0.293092
epoch 36 loss is: 0.292298
epoch 37 loss is: 0.291464
epoch 38 loss is: 0.290767
epoch 39 loss is: 0.289950
epoch 40 loss is: 0.289225
epoch 41 loss is: 0.288532
epoch 42 loss is: 0.287839
epoch 43 loss is: 0.287262
epoch 44 loss is: 0.286554
epoch 45 loss is: 0.285947
epoch 46 loss is: 0.285375
epoch 47 loss is: 0.284844
epoch 48 loss is: 0.284207
epoch 49 loss is: 0.283784



W is:  [[ 1.30058274e-01 -2.43340924e-01 -2.88325530e-02  2.33772218e-01
  -9.86490175e-02 -9.76923853e-02 -2.35121310e-01  2.35980958e-01
   5.27126603e-02  5.11115305e-02]
 [ 9.32730213e-02 -2.27573335e-01 -4.64867800e-02  1.07791178e-01
  -9.04569626e-02 -7.11493136e-04 -1.82566136e-01  2.22163513e-01
   1.07116044e-01  1.74529087e-02]
 [ 1.88447177e-01 -2.02589765e-01 -4.72486354e-02 -3.71847395e-03
  -1.13012344e-01  7.77632222e-02 -1.59778014e-01  1.04445107e-01
   1.82390571e-01 -2.66976375e-02]
 [ 2.84826338e-01 -1.67852297e-01 -9.68634710e-02 -8.53038132e-02
  -1.96322858e-01  3.31335187e-01 -8.84726346e-02 -7.75983706e-02
   1.82672381e-01 -8.64163712e-02]
 [ 1.11412644e-01 -1.11834183e-01 -9.90691558e-02 -7.97616988e-02
  -1.61673650e-01  6.03716910e-01 -1.09354012e-01 -7.34354481e-02
   3.73539254e-02 -1.17341466e-01]
 [-8.59155580e-02 -3.89751643e-02 -2.86357161e-02 -1.60242077e-02
  -9.84951109e-02  4.94616568e-01 -1.11441180e-01 -2.78114546e-02
  -3.23890485e-02 -5.49230687e-02]
 [-3.36165093e-02 -8.39466415e-03 -1.98007881e-04 -2.61049788e-03
  -2.68098358e-02  1.26275659e-01 -3.57160829e-02 -5.30548953e-03
   3.29373917e-03 -1.69179160e-02]
 [-4.63828037e-04  3.53039541e-05 -3.86474276e-04 -6.74947805e-05
  -7.93245155e-04  4.53931652e-03 -5.11363195e-03  9.23962216e-04
   3.78094311e-03 -2.45485152e-03]
 [-2.32933871e-05 -5.23060453e-06 -1.34871498e-05 -5.77266692e-05
  -6.71111120e-05 -4.80900053e-05 -9.40263817e-06  7.19755713e-04
  -2.79757987e-05 -4.67438367e-04]
 [ 8.65108101e-04  9.09597729e-05 -3.11443349e-04 -1.47864106e-03
  -6.83900435e-03 -2.77064624e-03 -3.86913482e-04  2.58669052e-02
  -8.54679325e-04 -1.41816465e-02]
 [-2.63929088e-03 -2.64492322e-04 -1.53854780e-03  4.31185850e-04
  -3.15029547e-02 -7.15911528e-03 -1.23515935e-03  9.76308361e-02
  -5.08135464e-03 -4.86406647e-02]
 [-1.19008878e-02 -5.28532686e-03 -5.90232015e-03  4.48378287e-02
  -3.87149863e-02 -3.90309207e-02 -2.01594979e-02  1.49421439e-01
  -1.14256218e-02 -6.18363917e-02]
 [-6.87927529e-02 -8.77870712e-03 -4.94896695e-02  2.03535855e-02
  -7.20102340e-02 -3.36355865e-02 -3.02698240e-02  2.09295705e-01
   3.57626490e-02 -2.43041757e-03]
 [-6.52488619e-02 -3.42066772e-02 -1.01321273e-01 -1.07673272e-01
  -6.53655455e-02  4.46031569e-03 -2.43143365e-02  1.28288701e-01
   1.22475803e-01  1.42906830e-01]
 [-1.80723500e-02 -8.89688134e-02 -2.10183084e-01 -2.18472376e-01
  -2.55523417e-02  1.34961814e-01 -6.41731219e-03  5.13334572e-02
   1.82656676e-01  1.98715582e-01]
 [ 8.76079965e-03 -1.15071632e-01 -3.15628499e-01 -3.39576840e-01
   3.33518460e-02  1.49892882e-01  8.10965970e-02  6.31594658e-02
   2.69269615e-01  1.64745867e-01]
 [-8.59805103e-03 -1.37789100e-01 -3.33113641e-01 -3.76729548e-01
   1.30424783e-01  1.63710088e-01  4.69603762e-02 -8.52634199e-03
   2.47313410e-01  2.76347369e-01]
 [-5.59019931e-02 -1.81571841e-01 -3.58628988e-01 -3.46353084e-01
   2.19564497e-01  1.55591249e-01  5.56244776e-02  8.47302899e-02
   2.68048823e-01  1.58896521e-01]
 [-4.31540869e-02 -1.74465016e-01 -4.56743926e-01 -2.85489947e-01
   2.57946610e-01  1.54958680e-01  4.15352397e-02  1.08570904e-01
   1.84646279e-01  2.12190300e-01]
 [-2.23886557e-02 -1.72276974e-01 -4.75257486e-01 -2.02534094e-01
   2.45876372e-01  2.05934808e-01  9.71724372e-03  7.21609220e-02
   1.30991027e-01  2.07777485e-01]]
b is:  [-0.3816464   0.36356112  0.0971376  -0.27018493  0.00226471  1.328009
 -0.104314    0.64620227 -1.4425566  -0.23847117]
W shape is:  (784, 10)
b shape is:  (10,)
Accuracy:  0.9217

3. Algorithm Analysis

3.1 Dataset

To make the algorithm easier to analyze, let's first write a small helper to see what the data looks like.

def read_data():
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("data/", one_hot=True)

    print(mnist.train.images[0:10])
    print(mnist.train.labels[0:10])
    print(mnist.train.images.shape)
    print(mnist.train.labels.shape)

read_data()

The output is:

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]]
(55000, 784)
(55000, 10)

The training set contains 55000 samples, each with 28*28=784 dimensions, so mnist.train.images is a 55000 * 784 matrix.
Each image shows a digit between 0 and 9, so there are 10 classes in total. After one-hot encoding, each label is a 10-dimensional vector with exactly one entry equal to 1 and the other nine equal to 0; the index of the 1 tells you which digit it is. mnist.train.labels is therefore a 55000 * 10 matrix.
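To make the one-hot structure concrete, here is a tiny sketch (not from the original article) that reshapes a flattened 784-dimensional row back into a 28*28 image and recovers the digit with np.argmax:

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("data/", one_hot=True)

first_image = mnist.train.images[0].reshape(28, 28)  # 784 -> 28x28 pixel grid
first_label = np.argmax(mnist.train.labels[0])       # index of the single 1 in the one-hot vector

print(first_image.shape)  # (28, 28)
print(first_label)        # the digit 0-9 this sample represents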

3.2 Parameters

W: shape 784 * 10.
b: shape (10,).
pred = tf.nn.softmax(tf.matmul(X, W) + b) applies softmax to obtain the predicted class probabilities. tf.matmul(X, W) has shape [batch_size, 10] (55000 * 10 if the whole training set is fed in at once); when b is added, it is broadcast so that the shapes match.
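Here is a small NumPy sketch of what tf.matmul(X, W) + b followed by softmax computes, including how b is broadcast across rows; the batch of 3 random rows and the random weights are made up for illustration:

import numpy as np

X = np.random.rand(3, 784)          # a batch of 3 flattened images (random, for illustration)
W = np.random.rand(784, 10) * 0.01  # small random weights
b = np.arange(10, dtype=np.float64)

logits = X.dot(W) + b               # (3, 784) @ (784, 10) -> (3, 10); b of shape (10,) broadcasts over rows
exp = np.exp(logits - logits.max(axis=1, keepdims=True))  # shift for numerical stability
pred = exp / exp.sum(axis=1, keepdims=True)               # softmax: each row sums to 1

print(pred.shape)          # (3, 10)
print(pred.sum(axis=1))    # [1. 1. 1.]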

3.3 Cross-Entropy Loss

Let's take a closer look at the loss function.

loss = tf.reduce_mean(- tf.reduce_sum(Y * tf.log(pred), reduction_indices=1))

Y * tf.log(pred) is the element-wise term from the cross-entropy definition. Y and pred both have shape [batch_size, 10] (55000 * 10 if the whole training set is fed in), so the result of this step has the same shape.

-tf.reduce_sum(Y * tf.log(pred), reduction_indices=1) sums along the reduction_indices=1 axis. Treating reduction_indices like the axis argument, this collapses the inner dimension, turning the [batch_size, 10] matrix into a vector of length batch_size, i.e. it sums each row.

tf.reduce_mean then averages these per-example losses into a single scalar loss.
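The same two reductions written out in NumPy, on a made-up toy batch of 2 examples and 3 classes. (As an aside, tf.log(pred) can hit log(0) when a predicted probability collapses to zero; TensorFlow's built-in tf.nn.softmax_cross_entropy_with_logits is the numerically safer option, but the hand-written version above matches the spirit of this series.)

import numpy as np

# Toy batch: 2 examples, 3 classes (made up for illustration).
Y = np.array([[1., 0., 0.],
              [0., 0., 1.]])
pred = np.array([[0.7, 0.2, 0.1],
                 [0.1, 0.3, 0.6]])

per_example = -np.sum(Y * np.log(pred), axis=1)  # reduce_sum over axis 1 -> shape (2,)
loss = np.mean(per_example)                      # reduce_mean -> a single scalar

print(per_example)  # approximately [0.357 0.511]
print(loss)         # approximately 0.434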

4. Optimization

The remaining steps all deal with solving the optimization problem.
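Concretely, GradientDescentOptimizer(learning_rate=0.02).minimize(loss) computes the gradients of loss with respect to W and b and subtracts learning_rate times those gradients from the variables. A hand-rolled equivalent, written as a sketch against the W, b, and loss defined above (not the code actually used in this article), could be built from tf.gradients and tf.assign_sub:

# Sketch of a manual gradient-descent step equivalent to the optimizer above.
grad_W, grad_b = tf.gradients(loss, [W, b])   # dL/dW and dL/db
train_step = tf.group(
    tf.assign_sub(W, 0.02 * grad_W),          # W <- W - lr * dL/dW
    tf.assign_sub(b, 0.02 * grad_b))          # b <- b - lr * dL/db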

5. Prediction

correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
tf.argmax(pred, 1) gives the index of the highest predicted probability, i.e. which digit the model predicts.
tf.argmax(Y, 1) gives the index of the 1 in the true one-hot label (every other position is 0, so the position of the 1 is the maximum).
tf.equal turns the comparison into a boolean array.

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
then computes the final accuracy.
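The same check written out in NumPy on a made-up batch of 3 examples, to show how the argmax comparison turns into an accuracy number:

import numpy as np

# Toy predictions and one-hot labels for 3 examples (made up for illustration).
pred = np.array([[0.1, 0.8, 0.1],
                 [0.6, 0.3, 0.1],
                 [0.2, 0.2, 0.6]])
Y = np.array([[0., 1., 0.],
              [0., 0., 1.],
              [0., 0., 1.]])

correct = np.argmax(pred, axis=1) == np.argmax(Y, axis=1)  # [True, False, True]
accuracy = np.mean(correct.astype(np.float32))             # 2 out of 3 correct -> about 0.667

print(correct, accuracy)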
