深入理解梯度下降

背景:梯度下降於機器學習尤緊要,然已理解如蜻蜓點水,近決心重拾,乃於其刨根究底。

目錄:

  1. 梯度下降作用
  2. 梯度下降公式推導
  3. 梯度下降代碼實現 Python

梯度下降的作用:

機器學習能夠動態求解出一個函數,用這個函數能夠預測出新的結果。我們方程的複雜度大致分成簡單方程,中等方程,複雜方程。

簡單方程:中學知識可以解決了,回想我們中學解一個方程,往往通過消元來解決。

中等難度方程:10-100個未知數,線性代數可以解決。學線性代數時不知其用處。

複雜方程:成千上萬個未知數,那麼梯度下降,計算機來解決。

至此,梯度下降與工程應用,乃利器。

梯度下降的推導:

用常識來解決這個問題,那就是沿着最陡峭的地方下降的最快,假設一個極限,山坡是直立的,那你走一小步,就墜落懸崖了,然後你根本不用邁出第二步。用更加學術的概念說就是梯度下降法。

藍後,梯度下降法又是什麼東西呢?且聽老衲娓娓道來(猥瑣臉)。

簡言之,一個公式解決施主的所有疑惑
\theta_i = \theta_i - \alpha\frac{\partial}{ \partial\theta_i}J(\theta)
說明一下:上面的公式是一個位置更新公式,說白了,就是你每走一步,就記錄一下你現在的位置,也就是等號左邊的 \theta_i ,那這一步之前你在的位置就是等號右邊的 \theta_i ,那你一步走多遠呢?答案是 \alpha ,那你是要朝哪個方向走呢?估計已經猜到了,就是 J(\theta)關於\theta_i的偏導數

現在還有一點小疑惑。J(\theta) 是什麼鬼?現在你可以把它假想爲你在的位置的高度。

現在大概清楚了吧,既有前進的方向,又有前進的距離,很容易聯想到學過的向量。這些向量首尾相連,這個軌跡就是這個方程的曲線圖。畫在圖上大概是這個樣子:

梯度下降法圖解說明

且慢,施主不要走,你只學會了老衲的一成功力,還不足以出師

(呵呵呵)其實,這個公式雖然你能理解了,但是計算機無論如何也想不通,這樣,就算電腦思考到死機也不會產生答案。。。。

現在我要把九陽真經傳授於你:讓計算機也能夠像你一樣去思考這個問題的答案。下面我們把這個公式給通俗化,把它展開成一個可以用計算機語言描述的柿子。

是否還記得上面的假想 J(\theta),現在告訴你,這個假想是錯誤的,因爲它的真實含義不是高度,而是一個關於 \theta 方差的表達式。

它是這樣定義的:J(\theta) = \min\limits_{\theta}\frac{1}{2}\sum\limits_{i=1}^{m}(h_\theta(x^{(i)}-y^{(i)})^2

我來描述一下這個柿子:

首先給定一個 m*n 的矩陣
\begin{matrix} \\x_{11} & ... & x_{1n} \\ . & & . \\ . & . & . \\ . & & . \\ x_{m1} & ... & x_{mn} \end{matrix}

釋義:

\theta:表示需要求解的待定係數

x^{(i)}:表示第 i 行所有的 x

h_\theta(x^{(i)}):表示第 i 行所有的 x 乘以 \theta 後的取值,即 h_\theta(x^{(i)})= \theta_0 + \theta_{1n}x_{1n} + \theta_{2n}x_{2n} + ... + \theta_{in}x_{in} ,表示根據假設的模型計算的 y

y^{(i)}:表示第 i 行對應的真實的 y

J(\theta):表示令方差最小的函數(關於 \theta

=================================

答疑區

  1. 如何理解J(\theta)這個函數

可以簡單的這樣理解,我們要假設的模型最終要和現實世界的模型最好的吻合,這也是我們的初衷,如何來衡量吻合的效果呢?我們用方差來表示吻合的效果,這個其實也叫做損失函數,當我們把損失降低到最小的時候,吻合的效果是最好的。這個和我們一開始提出的下山路徑規劃是一個思路,所以就可以用同一種方法來求解了。其實這個方法就是用來求解最小值問題的。

  1. 那麼爲什麼要走最快的路徑呢?走其他路徑不是也可以到達最低點嗎?

答案是可以,通過其他的路徑也可以到達最低點,在生活中確實也是這樣的,但是根據我們從高中就建立起來的數學觀念,貌似我們只學過兩種求極值的方法,其一是根據曲線的特性,其二是求導。很明顯,這個問題沒有給定的曲線,所以我們只能用第二種方式來求解最值了。

當然如果你發現了一個新的求解極值的方式,也許你就是那個可以改變世界的人。期待你的進一步研究。

  1. 越接近最優解的時候發現圖中的步長越小?

首先,你的發現是正確的。事實是這樣的,這個向量等於 \alpha_i與偏導數 的乘積,雖然我們選擇的 \alpha_i 始終是一個定值,但是越接近最值的時候,這個坡度就會越緩,從而導數的值就越小,也就是乘積變小了,這就是看到步長變小的緣故。

=================================

推導過程

現在大致瞭解了計算機的工作流程。在下面就是公式的推導了。

推導過程

數據量很大如何解決呢?

對於數量級很小的數據集我們可以用上面的方法來進行求解,但是通常情況給出的數據集並不小,我們考慮到計算機的性能,需要換一種解決方案。但是慶幸的是,用到的原理並沒有發生變化。

對於數據集較大的,我們可以從原始數據集中每次訓練時隨機的選擇一部分來進行對真實情況的模擬,雖然會產生一定的誤差,但是這是在準確度和效率之間權衡之後選擇的一個方式。俗話說,魚與熊掌不可兼得。

下面介紹的解決方法是:隨機梯度下降法,用僞代碼來解釋一下:

Repeat{
    for j=1 to m{
        theta_i = theta_i - alpha * J’(theta) # 這個就是上面寫的更新公式
    }
}

這裏隨機選擇的數據集的大小是 m 行。也就是 batch size。

在推導過程中需要用到的概念和公式:

在線性代數中,一個n×n矩陣A的主對角線(從左上方至右下方的對角線)上各個元素的總和被稱爲矩陣A的跡(或跡數),一般記作tr(A)。

  • 公式
  1. tr(AB) = tr(BA) #A、B、C均爲n*n的矩陣
  2. tr(ABC) = tr (CAB) = tr(BCA)
  3. \nabla_Atr(AB)=B^T
  4. a\in\mathbb{R},則 tr(a)=a
  5. \nabla_Atr(ABA^TC)=CAB+C^TAB^T

 

待定係數現在已經不是一個未知數了,根據我們的數據,可以直接對其進行求解了。在使用的時候千萬不要說你還不懂原理,老衲已經把畢生的功力傳輸於你。

3.梯度下降實現代碼

import numpy as np
import pandas as pd
from numpy import *
from pandas import *
import matplotlib.pyplot as plt

a = np.array([[1, 2], [2, 1], [3, 2.5], [4, 3],
              [5, 4], [6, 5], [7, 2.7], [8, 4.5],
              [9, 2]])

m = a.shape[0]
print(m)
print(type(a))
x = a[:, 0]
y = a[:, 1]
plt.scatter(x, y, marker='*', color='r', s=20)

inittheta0 = 0
inittheta1 = 0
iterations = 1500
alpha = 0.01


# 小批量梯度下降
def gradientdescentminibatch(x, y, theta0, theta1, iterations, alpha):
    j_h = np.zeros((iterations, 1))
    for i in range(0, iterations):
        y_hat = theta0 + theta1 * x
        temp0 = theta0 - alpha * ((1 / m) * sum(y_hat - y))
        temp1 = theta1 - alpha * (1 / m) * sum((y_hat - y) * x)
        theta0 = temp0
        theta1 = temp1
        y_hat2 = theta0 + theta1 * x
        aa = sum((y_hat2 - y) ** 2)
        j = aa*(1 / (2 * m))
        j_h[i, :] = j
    return theta0, theta1, j_h


(theta0,theta1,J_h) = gradientdescentminibatch(x,y,inittheta0,inittheta1,iterations,alpha)
print(theta1)
print(theta0)
plt.plot(x,theta0+theta1*x)
plt.title("fittingcurve")
plt.show()
x2=np.arange(iterations)
plt.plot(x2,J_h)
plt.title("costfunction")

plt.show()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章