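Machine Learning: Linear Classification and the Fisher Linear Discriminant
1. What are a linear classifier and the Fisher discriminant?
In machine learning, the goal of classification is to assign objects to groups according to their features. A linear classifier makes this decision from a linear combination of those features. An object's features are usually recorded as feature values, and those values are collected into a feature vector.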
Definition of a linear classifier:
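The definition itself is not reproduced in this text. A common formulation, taken here as an assumption, scores a feature vector x with g(x) = w·x + b and assigns the class from the sign of the score. A minimal sketch, with the weights `w`, the bias `b`, and the sample points chosen purely for illustration:

```python
import numpy as np

def linear_classify(X, w, b):
    """Return +1 or -1 for each row of X, according to the sign of w.x + b."""
    scores = X @ w + b                  # one linear score per sample
    return np.where(scores >= 0, 1, -1)

# Hypothetical weights: the decision boundary is the line x1 + x2 = 3
w = np.array([1.0, 1.0])
b = -3.0
X = np.array([[0.5, 1.0],
              [2.0, 2.0],
              [4.0, 0.0]])
labels = linear_classify(X, w, b)
print(labels)
```

Everything the classifier knows is held in `w` and `b`; the methods discussed in this post differ only in how those numbers are chosen.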
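Fisher linear discriminant:
The Fisher discriminant is one of the methods of discriminant analysis and borrows the idea of analysis of variance. From p-dimensional observations sampled from each known population, it constructs one or more linear discriminant functions y = l′x, where l = (l1, l2, …, lp)′ and x = (x1, x2, …, xp)′. The discriminant coefficients l = (l1, l2, …, lp)′ are chosen so that the scatter between populations (denoted B) is as large as possible while the scatter within each population (denoted E) is as small as possible.
It can be shown that the coefficient vectors are eigenvectors of the generalized eigenproblem (B − λE)l = 0. The characteristic roots of |B − λE| = 0 are written λ1 ≥ λ2 ≥ … ≥ λr > 0, and the corresponding eigenvectors l1, l2, …, lr each give one linear discriminant function. In some problems the single function y1 = l′1x built from the eigenvector l1 of λ1 does not separate the populations well; the eigenvector l2 of λ2 then gives a second function y2 = l′2x, and so on if that is still not enough. Once the discriminant functions are available, a classification rule chosen by the analyst (for example a weighted or an unweighted rule) assigns a new sample x to a population.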
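The eigenproblem described above can be sketched numerically. The example below is an illustration under stated assumptions, not the author's code: the data are synthetic, and the names `groups`, `centers`, and `grand_mean` are hypothetical. It builds B and E, solves |B − λE| = 0 through the equivalent eigenproblem of E⁻¹B, and takes the leading eigenvector as the first discriminant.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic data (an assumption): 3 groups, 40 samples each, p = 4 features
centers = np.array([[0, 0, 0, 0], [3, 1, 0, 0], [0, 2, 3, 0]], dtype=float)
groups = [c + rng.normal(size=(40, 4)) for c in centers]

grand_mean = np.vstack(groups).mean(axis=0)
B = np.zeros((4, 4))  # between-group scatter
E = np.zeros((4, 4))  # within-group scatter
for g in groups:
    m = g.mean(axis=0)
    B += len(g) * np.outer(m - grand_mean, m - grand_mean)
    E += (g - m).T @ (g - m)

# |B - lambda E| = 0 has the same roots as the eigenproblem of E^{-1} B
eigvals, eigvecs = np.linalg.eig(np.linalg.solve(E, B))
order = np.argsort(eigvals.real)[::-1]   # sort roots in decreasing order
eigvals = eigvals.real[order]
eigvecs = eigvecs.real[:, order]

l1 = eigvecs[:, 0]  # coefficients of the first discriminant y1 = l1' x
print("roots:", eigvals)
print("first discriminant coefficients:", l1)
```

With three groups, at most two roots are non-zero, so at most two discriminant functions are useful here; the remaining roots are zero up to rounding error.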
Basic introduction:
Fisher discriminant function for two populations:
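The formulas for this case are not reproduced in the text. The standard two-class form, which is also what the Python script later in this post computes, can be written as follows. The symbols ω1, ω2 (the two populations), m1, m2 (their sample means), and S_w (the pooled within-class scatter) are notation assumed here for illustration.

```latex
\begin{aligned}
S_w &= \sum_{x \in \omega_1} (x - m_1)(x - m_1)^{T} + \sum_{x \in \omega_2} (x - m_2)(x - m_2)^{T}, \\
w &= S_w^{-1}(m_1 - m_2), \\
g(x) &= w^{T}x + T, \qquad T = -\tfrac{1}{2}\,(m_1 + m_2)^{T} S_w^{-1}(m_1 - m_2).
\end{aligned}
```

A sample x is assigned to ω1 when g(x) > 0 and to ω2 when g(x) < 0; the threshold T places the boundary at the midpoint of the two projected means.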
Fisher discriminant functions for multiple populations:
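For k populations, the between-population scatter B and the within-population scatter E mentioned earlier are usually defined as below. This is a sketch in assumed notation: n_i is the size of sample i, x̄_i its mean, and x̄ the overall mean.

```latex
\begin{aligned}
B &= \sum_{i=1}^{k} n_i\,(\bar{x}_i - \bar{x})(\bar{x}_i - \bar{x})', \qquad
E = \sum_{i=1}^{k} \sum_{j=1}^{n_i} (x_{ij} - \bar{x}_i)(x_{ij} - \bar{x}_i)', \\
\max_{l}\ & \frac{l' B\, l}{l' E\, l}
\quad\Longrightarrow\quad (B - \lambda E)\,l = 0, \qquad |B - \lambda E| = 0 .
\end{aligned}
```

Maximizing the ratio of between-population to within-population scatter along the direction l leads to the eigenproblem on the right, whose roots λ and eigenvectors l give the discriminant functions y = l′x.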
Decision rule:
2. Deciding which class a new pattern belongs to
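No worked example survives in this section, so the following is only a sketch of how a new pattern could be assigned in the two-class case. The training points and the helper name `which_class` are hypothetical; the rule is the usual one, g(x) = w·x + T with w = S_w⁻¹(m1 − m2).

```python
import numpy as np

# Two small hypothetical training sets (rows are samples)
class1 = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 3.0], [2.0, 1.5]])
class2 = np.array([[6.0, 5.0], [7.0, 8.0], [8.0, 6.5], [7.5, 7.0]])

m1, m2 = class1.mean(axis=0), class2.mean(axis=0)
# Pooled within-class scatter matrix
Sw = (class1 - m1).T @ (class1 - m1) + (class2 - m2).T @ (class2 - m2)

w = np.linalg.solve(Sw, m1 - m2)   # projection direction
T = -0.5 * np.dot(m1 + m2, w)      # threshold at the midpoint of the projected means

def which_class(x):
    """Return 1 if g(x) = w.x + T is positive, otherwise 2."""
    return 1 if np.dot(w, x) + T > 0 else 2

print(which_class(np.array([2.0, 2.0])))   # a point near the first cluster
print(which_class(np.array([7.0, 7.0])))   # a point near the second cluster
```

The new pattern is projected onto w and compared with the threshold, so classifying it costs a single dot product.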
3. Working through the Fisher discriminant in Python
Fisher linear classification of the Iris dataset and its accuracy:
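# Import the required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Build the data set: 50 samples per class, 4 features each
path = r'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
df = pd.read_csv(path, header=None)  # iris.data has no header row
Iris1 = df.values[0:50, 0:4].astype(float)     # Iris-setosa
Iris2 = df.values[50:100, 0:4].astype(float)   # Iris-versicolor
Iris3 = df.values[100:150, 0:4].astype(float)  # Iris-virginica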
# Build the within-class scatter matrices from the first 30 samples of each class
m1 = np.mean(Iris1[:30], axis=0)  # class means of the training samples
m2 = np.mean(Iris2[:30], axis=0)
m3 = np.mean(Iris3[:30], axis=0)
d1 = (Iris1[:30] - m1).astype(float)  # centred training samples, shape (30, 4)
d2 = (Iris2[:30] - m2).astype(float)
d3 = (Iris3[:30] - m3).astype(float)
s1 = np.dot(d1.T, d1)  # sum over the samples of (x - m1)(x - m1)^T
s2 = np.dot(d2.T, d2)
s3 = np.dot(d3.T, d3)
# Pooled within-class scatter matrix for each pair of classes
sw12 = s1 + s2
sw13 = s1 + s3
sw23 = s2 + s3
# Projection directions
sw12 = np.array(sw12, dtype='float')
sw13 = np.array(sw13, dtype='float')
sw23 = np.array(sw23, dtype='float')
m1 = np.array(m1, dtype='float')
m2 = np.array(m2, dtype='float')
m3 = np.array(m3, dtype='float')
a = m1 - m2
b = m1 - m3
c = m2 - m3
w12 = np.linalg.solve(sw12, a)  # w = Sw^-1 (m1 - m2), without forming the inverse
w13 = np.linalg.solve(sw13, b)
w23 = np.linalg.solve(sw23, c)
# Thresholds T of the discriminant functions g(x) = w.x + T
T12 = -0.5 * np.dot(m1 + m2, w12)  # g12 is zero at the midpoint of m1 and m2
T13 = -0.5 * np.dot(m1 + m3, w13)
T23 = -0.5 * np.dot(m2 + m3, w23)
# Classify the held-out samples with the pairwise discriminant functions and compute the accuracy
newiris1 = []
newiris2 = []
newiris3 = []

def classify(x):
    # Combine the three pairwise decisions; return 0 if no rule applies
    g12 = np.dot(w12, x) + T12  # > 0: class 1 rather than class 2
    g13 = np.dot(w13, x) + T13  # > 0: class 1 rather than class 3
    g23 = np.dot(w23, x) + T23  # > 0: class 2 rather than class 3
    if g12 > 0 and g13 > 0:
        return 1
    elif g12 < 0 and g23 > 0:
        return 2
    elif g13 < 0 and g23 < 0:
        return 3
    return 0

counts = {1: 0, 2: 0, 3: 0}                         # correct decisions per class
assigned = {1: newiris1, 2: newiris2, 3: newiris3}  # samples grouped by predicted class
tested = 0
for true_label, samples in ((1, Iris1), (2, Iris2), (3, Iris3)):
    for x in samples[30:]:  # samples from index 30 onward are held out for testing
        label = classify(x)
        if label in assigned:
            assigned[label].append(x)
        if label == true_label:
            counts[true_label] += 1
        tested += 1
kind1, kind2, kind3 = counts[1], counts[2], counts[3]
correct = (kind1 + kind2 + kind3) / tested  # divide by the number of samples actually tested
print("Within-class scatter matrix S1:", s1, '\n')
print("Within-class scatter matrix S2:", s2, '\n')
print("Within-class scatter matrix S3:", s3, '\n')
print("Pooled within-class scatter matrix Sw12:", sw12, '\n')
print("Pooled within-class scatter matrix Sw13:", sw13, '\n')
print("Pooled within-class scatter matrix Sw23:", sw23, '\n')
print('Overall accuracy:', correct * 100, '%')
Output recorded from the author's run; exact values depend on how the file is read and on the train/test split:
Within-class scatter matrix S1: [[4.084080000000003 2.9814400000000005 0.5409999999999995 0.4941599999999999]
[2.9814400000000005 3.6879200000000028 -0.025000000000000428 0.5628800000000002]
[0.5409999999999995 -0.025000000000000428 1.0829999999999995 0.19]
[0.4941599999999999 0.5628800000000002 0.19 0.30832000000000004]]
Within-class scatter matrix S2: [[8.316120000000005 2.7365199999999987 5.568960000000003 1.7302799999999998]
[2.7365199999999987 3.09192 2.49916 1.3588799999999999]
[5.568960000000003 2.49916 6.258680000000002 2.2232399999999997]
[1.7302799999999998 1.3588799999999999 2.2232399999999997 1.3543200000000004]]
Within-class scatter matrix S3: [[14.328471470220745 3.1402832153269435 11.94600583090379 1.3147563515201988]
[3.1402832153269435 3.198721366097457 2.239650145772593 1.2317617659308615]
[11.94600583090379 2.239650145772593 11.600816326530618 1.4958892128279884]
[1.3147563515201988 1.2317617659308615 1.4958892128279884 1.6810578925447726]]
Pooled within-class scatter matrix Sw12: [[12.4002 5.71796 6.10996 2.22444]
[ 5.71796 6.77984 2.47416 1.92176]
[ 6.10996 2.47416 7.34168 2.41324]
[ 2.22444 1.92176 2.41324 1.66264]]
Pooled within-class scatter matrix Sw13: [[18.41255147 6.12172322 12.48700583 1.80891635]
[ 6.12172322 6.88664137 2.21465015 1.79464177]
[12.48700583 2.21465015 12.68381633 1.68588921]
[ 1.80891635 1.79464177 1.68588921 1.98937789]]
Pooled within-class scatter matrix Sw23: [[22.64459147 5.87680322 17.51496583 3.04503635]
[ 5.87680322 6.29064137 4.73881015 2.59064177]
[17.51496583 4.73881015 17.85949633 3.71912921]
[ 3.04503635 2.59064177 3.71912921 3.03537789]]
Overall accuracy: 91.66666666666666 %
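As a cross-check on the hand-written script, the same kind of split can be run through scikit-learn's `LinearDiscriminantAnalysis`. This is a swapped-in library method rather than the pairwise procedure used above, and it loads the copy of Iris bundled with scikit-learn instead of the UCI file, so the accuracy need not match the figure reported in this post.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

iris = load_iris()
X, y = iris.data, iris.target   # 150 samples, stored in blocks of 50 per class

# First 30 samples of each class for training, the remaining 20 for testing
idx = np.arange(150)
train = (idx % 50) < 30
test = ~train

lda = LinearDiscriminantAnalysis()
lda.fit(X[train], y[train])
accuracy = lda.score(X[test], y[test])
print("held-out accuracy: %.2f%%" % (100 * accuracy))
```

A large gap between this number and the hand-written result would usually point to a data-loading or indexing problem rather than to the method itself.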
4. Linear classification and data visualization of the Iris dataset
Linear classification of the Iris dataset
# Import the required libraries
import numpy as np
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn import preprocessing
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
# Read the data
df = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)  # iris.data has no header row
x = df.values[:, :-1].astype(float)  # the four measurements
y = df.values[:, -1]                 # the class names
le = preprocessing.LabelEncoder()
le.fit(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'])
y = le.transform(y)                  # encode the class names as 0, 1, 2
# Build the linear model
x = x[:, :2]                         # keep only sepal length and sepal width
x = StandardScaler().fit_transform(x)
lr = LogisticRegression()            # logistic regression model
lr.fit(x, y.ravel())                 # estimate the regression parameters from [x, y]
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
intercept_scaling=1, l1_ratio=None, max_iter=100,
multi_class='warn', n_jobs=None, penalty='l2',
random_state=None, solver='warn', tol=0.0001, verbose=0,
warm_start=False)
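The fitted model is linear in the sense that each class gets one linear score, and the prediction is the class with the largest score. The sketch below illustrates this with the Iris copy bundled in scikit-learn (an assumption made to keep the example offline); `scores` and `manual_pred` are names introduced here.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

iris = load_iris()
X = StandardScaler().fit_transform(iris.data[:, :2])   # sepal length and width only
y = iris.target

lr = LogisticRegression().fit(X, y)

# One linear score per class: scores[i, k] = X[i] . coef_[k] + intercept_[k]
scores = X @ lr.coef_.T + lr.intercept_
manual_pred = lr.classes_[np.argmax(scores, axis=1)]

print(lr.coef_.shape)                                # (number of classes, number of features)
print(np.array_equal(manual_pred, lr.predict(X)))    # manual scores reproduce predict()
```

Because every score is linear in the features, the boundary between any two classes is a straight line in the two-feature plane.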
# Classify a grid of points and visualize the decision regions
N, M = 500, 500  # number of grid samples along each axis
x1_min, x1_max = x[:, 0].min(), x[:, 0].max()  # range of column 0
x2_min, x2_max = x[:, 1].min(), x[:, 1].max()  # range of column 1
t1 = np.linspace(x1_min, x1_max, N)
t2 = np.linspace(x2_min, x2_max, M)
x1, x2 = np.meshgrid(t1, t2)  # grid of sample points
x_test = np.stack((x1.ravel(), x2.ravel()), axis=1)  # test points, one per grid node
cm_light = mpl.colors.ListedColormap(['#77E0A0', '#FF8080', '#A0A0FF'])
cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
y_hat = lr.predict(x_test)  # predicted class of every grid node
y_hat = y_hat.reshape(x1.shape)  # reshape to match the grid
plt.pcolormesh(x1, x2, y_hat, cmap=cm_light)  # show the predicted regions
plt.scatter(x[:, 0], x[:, 1], c=y.ravel(), edgecolors='k', s=50, cmap=cm_dark)
# The model uses columns 0 and 1, which are the standardized sepal measurements
plt.xlabel('sepal length (standardized)')
plt.ylabel('sepal width (standardized)')
plt.xlim(x1_min, x1_max)
plt.ylim(x2_min, x2_max)
plt.grid()
plt.savefig('iris.png')
plt.show()
# Compute the accuracy on the training data
y_hat = lr.predict(x)
y = y.reshape(-1)
result = y_hat == y
acc = np.mean(result)
print('Accuracy: %.2f%%' % (100 * acc))
Accuracy: 79.19%
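The plot shows the plane divided into three regions, one per Iris class, with straight boundaries between them, which is the behaviour expected of a linear classifier. The accuracy is modest, however: only 79.19% in the run recorded above, using just the two sepal measurements.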
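One plausible reason for the modest accuracy is that the model only sees the two sepal measurements. The sketch below compares that setup with a model trained on all four measurements. It uses the Iris copy bundled with scikit-learn and scores on the training data, as the post does; the helper `training_accuracy` and the `max_iter=1000` setting are choices made here for illustration, so the numbers will not exactly match the run above.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

iris = load_iris()
y = iris.target

def training_accuracy(columns):
    """Fit logistic regression on the chosen feature columns and score it on the same data."""
    X = StandardScaler().fit_transform(iris.data[:, columns])
    model = LogisticRegression(max_iter=1000).fit(X, y)
    return model.score(X, y)

acc_sepal = training_accuracy([0, 1])        # sepal length and width, as in the text
acc_all = training_accuracy([0, 1, 2, 3])    # all four measurements
print("sepal features only: %.2f%%" % (100 * acc_sepal))
print("all four features:   %.2f%%" % (100 * acc_all))
```

Training accuracy is an optimistic estimate in both cases; a held-out split would give a fairer comparison.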
Data visualization
from sklearn import datasets
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
iris = datasets.load_iris()
data1=pd.DataFrame(np.concatenate((iris.data,iris.target.reshape(150,1)),axis=1),
columns=np.append(iris.feature_names,'target'))
data1
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0.0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0.0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0.0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0.0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0.0 |
... | ... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 | 2.0 |
146 | 6.3 | 2.5 | 5.0 | 1.9 | 2.0 |
147 | 6.5 | 3.0 | 5.2 | 2.0 | 2.0 |
148 | 6.2 | 3.4 | 5.4 | 2.3 | 2.0 |
149 | 5.9 | 3.0 | 5.1 | 1.8 | 2.0 |
150 rows × 5 columns
# Same table with the species name as the target; the features stay numeric, so no string round-trip is needed
data = pd.DataFrame(iris.data, columns=iris.feature_names)
data['target'] = np.repeat(iris.target_names, 50)
data
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
1 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
2 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
3 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
4 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
... | ... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 | virginica |
146 | 6.3 | 2.5 | 5.0 | 1.9 | virginica |
147 | 6.5 | 3.0 | 5.2 | 2.0 | virginica |
148 | 6.2 | 3.4 | 5.4 | 2.3 | virginica |
149 | 5.9 | 3.0 | 5.1 | 1.8 | virginica |
150 rows × 5 columns
sns.pairplot(data.iloc[:,[0,1,4]],hue='target')
sns.pairplot(data.iloc[:,2:5],hue='target')
plt.scatter(data1.iloc[:,0],data1.iloc[:,1],c=data1.target)
plt.xlabel('sepal length (cm)')
plt.ylabel('sepal width (cm)')