import cv2
import numpy as np
import xlrd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
def svm_config():
    """Create an untrained OpenCV SVM configured for epsilon-SVR with an RBF kernel.

    Every setter used by the original configuration is kept, so the stored
    parameter set is unchanged. According to the OpenCV SVM reference, only
    ``C``, ``P`` (epsilon) and ``Gamma`` are consulted by an EPS_SVR model
    with an RBF kernel; ``Degree``, ``Coef0`` and ``Nu`` belong to other
    kernel / SVM types and are set here only so they hold known values.

    Returns:
        The configured ``cv2.ml.SVM`` instance.
    """
    svm = cv2.ml.SVM_create()

    # Model type and kernel.
    svm.setType(cv2.ml.SVM_EPS_SVR)
    svm.setKernel(cv2.ml.SVM_RBF)

    # Parameters relevant to EPS_SVR + RBF.
    svm.setC(1)
    svm.setP(5e-3)
    svm.setGamma(0.01)

    # Parameters of other kernel / SVM types, pinned to fixed values.
    svm.setCoef0(0.0)
    svm.setDegree(3)
    svm.setNu(0.5)

    # Stop after 1000 iterations or once the change drops below 1e-3.
    criteria = (cv2.TERM_CRITERIA_MAX_ITER + cv2.TERM_CRITERIA_EPS, 1000, 1e-3)
    svm.setTermCriteria(criteria)
    return svm
# SVM training
def svm_train(svm, features, labels):
    """Fit ``svm`` on the given samples with OpenCV's automatic parameter search.

    ``trainAuto`` is used instead of ``train``; per the OpenCV documentation it
    cross-validates over parameter grids, so the values finally stored in the
    model may differ from the ones that were set on ``svm`` beforehand.

    Args:
        svm: Model object exposing ``trainAuto`` (a ``cv2.ml.SVM`` in this file).
        features: Array-like of samples, one sample per row; converted to float32.
        labels: Array-like of target values; converted to float32.

    Returns:
        The value returned by ``svm.trainAuto`` (documented by OpenCV as a
        boolean success flag). Previously this value was discarded.
    """
    samples = np.array(features, dtype='float32')
    responses = np.array(labels, dtype='float32')
    return svm.trainAuto(samples, cv2.ml.ROW_SAMPLE, responses)
# Save the SVM parameters
def svm_save(svm, name):
    """Write ``svm`` to the file ``name`` by delegating to ``svm.save``.

    Args:
        svm: Object exposing a ``save(path)`` method.
        name: Destination file name handed to ``svm.save`` unchanged.
    """
    svm.save(name)
def loadTrainDataFromExcel(path=r'.\0221_To_Teves訓練數據.xlsx', sheet_name='Sheet1'):
    """Load the training samples and targets from an Excel sheet.

    The first row of the sheet is treated as a header and skipped.

    Args:
        path: Workbook to open. Defaults to the file the script was written for.
        sheet_name: Name of the sheet to read. Defaults to ``'Sheet1'``.

    Returns:
        A ``(data, offset)`` tuple. ``data`` is a float32 array with one row
        per spreadsheet row and three columns in the order background
        distance, foreground distance, foreground area. ``offset`` is the
        plain list of values read from the offset column.
    """
    # NOTE(review): xlrd 2.0 and later only read .xls files; opening an .xlsx
    # workbook requires an older xlrd -- confirm the installed version.
    workbook = xlrd.open_workbook(path)
    # Print the names of the sheets in the workbook.
    print (workbook.sheet_names())
    sheet = workbook.sheet_by_name(sheet_name)
    # Print the sheet's name, row count and column count.
    print (sheet.name,sheet.nrows,sheet.ncols)

    # Read whole columns, starting at row 1 so the header cell is left out.
    # NOTE(review): background distance comes from column 5 while the other
    # fields use columns 1-3 -- confirm against the spreadsheet layout.
    first_data_row = 1
    BackgroundDist = sheet.col_values(5, first_data_row)
    ForegroundDist = sheet.col_values(1, first_data_row)
    ForegroundArea = sheet.col_values(2, first_data_row)
    offset = sheet.col_values(3, first_data_row)

    # One row per sample, one column per feature.
    data = np.column_stack((BackgroundDist, ForegroundDist, ForegroundArea))
    data = np.array(data, dtype='float32')
    return data, offset
def PlotShow(data, offset):
    """Display a 3-D scatter plot of the samples, coloured by their target value.

    Args:
        data: 2-D array indexed as ``data[:, 0]``, ``data[:, 1]``, ``data[:, 2]``
            for the x, y and z coordinates of each point.
        offset: Sequence of values passed as the scatter colour argument, one
            per row of ``data``.
    """
    fig = plt.figure()
    # Constructing Axes3D(fig) directly no longer attaches the axes to the
    # figure in current Matplotlib releases, which leaves the window empty.
    # Requesting the 3d projection through add_subplot works on old and new
    # versions alike.
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title("Input data")
    p = ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=offset)
    ax.set_xlabel("Background Dist", color='r')
    ax.set_ylabel("Foreground Dist", color='g')
    ax.set_zlabel("Foreground Area", color='b')
    # Colour bar for the values in `offset`.
    fig.colorbar(p)
    plt.show()
if __name__ == '__main__':
    # Build the regressor, load and preview the data, then fit and persist it.
    model = svm_config()
    samples, targets = loadTrainDataFromExcel()
    PlotShow(samples, targets)
    svm_train(model, samples, targets)
    svm_save(model, 'SVRModel.xml')