從推公式到寫代碼--聊聊最小二乘法

本專輯內容的閱讀對象是有一定的高數和線性代數基礎,但是缺少編程訓練的人。


1. 前言

在這一講中,我們來聊聊最小二乘及最小二乘方法求解方程參數問題。希望通過這一講,能讓大家瞭解通用參數求解方法的最小二乘是怎麼工作的,如果大家有python基礎,也希望大家能掌握一般方程的參數求解方法,並能依樣畫葫蘆,解決學習工作中的數學模型參數問題。如果你沒有python基礎也不用擔心,我們後面會有python及python科學計算的系列文章,當然,我希望大家能花一兩週時間學一下python,如果你有MATLAB基礎那麼用不了兩天就能掌握。這裏給大家推薦廖雪峯的博客,少有的高質量python教程博客,通過各種比喻和圖解方法將複雜的編程原理簡潔表示出來,文科生也能看懂高深的計算機原理,相信你也能很快掌握。

2. 什麼是最小二乘

什麼是最小二乘,可以參考知乎問答:最小二乘法的本質是什麼?

它的主要思想就是找到一組參數,使得理論值與觀測值之差的平方和達到最小:

(w,b)=argminE=i=1m(f(xi)yi)2

3. 最小二乘是怎麼求解參數的,以多元線性迴歸爲例

下面我們以多元線性迴歸爲例子,講解最小二乘法求解進行迴歸的參數,需要說明的是,多元線性迴歸求解方法有很多,比如矩陣方法等,最小二乘的思想只是其中的一種。

前面說到,最小二乘的思想是誤差最小化,還記得怎麼求最小值嗎?高數裏面的求導啊,我們對誤差函數求導,令梯度爲零即可。

E(w,b)w=2(wx2i(yib)xi)

E(w,b)b=2(mb(yiwxi))

令其導數爲零得到:

w=i=1myi(xix)i=1mx2i1m(i=1mxi)2

b=1mi=1m(yiwxi)

4. 使用scipy求解最小二乘

在實際工作中,是不是每次都這麼麻煩呢,要自己推導出最後的迭代表達式後再寫代碼計算參數呢?如果你的數學底子好,而且代碼水平也高當然沒問題,但對大部分人的大部分問題是不需要的,一些現成的計算機軟件,對一般的數學問題,可以自動求導,或者使用數值導數方法進行迭代求參。前面演示最小二乘推導過程,是想讓大家對其有深入理解,才能更好理解代碼的原理,而不是依樣畫葫蘆,雖然知道怎麼用卻不知道爲什麼,對於做我們這一行的可不行。

好了,接下來我們通過一個例子來講講怎麼用python的scipy庫進行最小二乘求解。

我們以bass模型爲例講解最小二乘求解函數參數的一般方法。

bass模型最早是由美國Frank Bass提出,是一個用來預測消費品銷售情況的模型,Bass模型對消費者購買新產品的決策時間進行了分析,它認爲新產品的購買者受到外部或內部因素的影響,因此將新產品的潛在使用者分爲兩類,第一類稱爲創新羣體,該羣體易受外部影響,及大衆媒體的影響,第二類稱爲模仿羣體,易受到內部影響,即口碑的影響。模型的核心思想是創新羣體羣體的購買決策獨立於社會系統其他成員,而模仿羣體購買新產品的時間受到社會系統的影響,並且這種影響隨購買人數增加而正價,因爲模仿羣體的購買決策時間受到社會系統成員的影響。

使用bass模型,省略推導過程,最終的形式如下,其中0<=p<=1表示創新羣體系數,0<=q<=1表示模仿羣體系數,m爲潛在購買量,n(t)爲當期銷量。

n(t)=mp(p+q)2e(p+q)t[p+qe(p+q)t]2

從上面的公式可以看出,只要我們知道了m,p,q的值,就可以預測t+1的銷量了。

好了,經過前面的介紹,我們可以開始實用scipy的optimize進行參數求解了。上來就這麼多代碼是有點突兀,沒關係,我希望你能照着代碼敲一邊,有個初步印象,這一講的目標就基本達到了。

# 最小二乘法
from math import e  # 引入自然數e
import numpy as np  # 科學計算庫
import matplotlib.pyplot as plt  # 繪圖庫
from scipy.optimize import leastsq  # 引入最小二乘法算法

# 樣本數據(Xi,Yi),需要轉換成數組(列表)形式
ti = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
yi = np.array([8, 11, 15, 19, 22, 23, 22, 19, 15, 11])


# 需要擬合的函數func :指定函數的形狀,即n(t)的計算公式
def func(params, t):
    m, p, q = params
    fz = (p * (p + q) ** 2) * e ** (-(p + q) * t)  # 分子的計算
    fm = (p + q * e ** (-(p + q) * t)) ** 2  # 分母的計算
    nt = m * fz / fm  # nt值
    return nt


# 誤差函數函數:x,y都是列表:這裏的x,y更上面的Xi,Yi中是一一對應的
# 一般第一個參數是需要求的參數組,另外兩個是x,y
def error(params, t, y):
    return func(params, t) - y


# k,b的初始值,可以任意設定, 一般需要根據具體場景確定一個初始值
p0 = [100, 0.3, 0.3]

# 把error函數中除了p0以外的參數打包到args中(使用要求)
params = leastsq(error, p0, args=(ti, yi))
params = params[0]

# 讀取結果
m, p, q = params
print('m=', m)
print('p=', p)
print('q=', q)

# 有了參數後,就是計算不同t情況下的擬合值
y_hat = []
for t in ti:
    y = func(params, t)
    y_hat.append(y)

# 接下來我們繪製實際曲線和擬合曲線
# 由於模擬數據實在太好,兩條曲線幾乎重合了
fig = plt.figure()
plt.scatter(ti, yi, color='r', label='true')
plt.plot(ti, y_hat, color='b', label='predict')
plt.title('BASS model')
plt.legend()

5. 其他參數求解方法

對Python熟悉的朋友,可以看看我寫的另一篇文章scipy數值優化與參數估計,大致彙總了scipy的optimize數值優化和參數估計的常用方法,比如非線性最小二乘,l-bfgs,共軛梯度等等等,都是參數求解的經典方法。

6. 後話

講到這裏,這一篇就差不多講完了,講到這裏,大家應該能明白基本的參數求解方法了。

在接下來的幾篇文章,我們不打算講怎麼將數學公式轉換程代碼,而是先講講python中的數值計算庫numpy,這是一個很重要的庫,可以說python數據科學體系就是基於numpy建立的。當然我們也不會事無鉅細講numpy怎麼用,而是講講numpy的數據結構,常見的矩陣計算,向量化編程這些,對於理解我們以後的課程是足夠了,如果想更深入學習numpy可以到京東淘寶上面買書看。

至於python基礎,我不打算講,因爲python實在太簡單,這方面的入門到精通之類的數據很多,我相信他們寫得肯定比我好,更有體系。這裏推薦廖雪峯的python技術博客,質量上乘,值的學習。

下次見。
master蘇

7. 參考

scipy數值優化與參數估計

數學優化:找到函數的最優解

Scipy教程 - 優化和擬合庫scipy.optimize

發佈了61 篇原創文章 · 獲贊 163 · 訪問量 42萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章