# 【原】用Excel中的數據進行SVR訓練
# (Original post) Training an SVR with data read from Excel.


 
import cv2
import numpy as np
import xlrd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap

def svm_config(gamma=0.01, C=1, p=5e-3):
    """Create an untrained OpenCV epsilon-SVR model with an RBF kernel.

    Args:
        gamma: RBF kernel coefficient (default 0.01).
        C: SVM optimisation parameter C (default 1).
        p: Epsilon parameter of the EPS_SVR optimisation problem (default 5e-3).

    Returns:
        A configured but untrained ``cv2.ml.SVM`` instance.

    NOTE(review): per the OpenCV docs, ``trainAuto`` (used by ``svm_train``)
    searches over C, gamma, p and similar parameters. The values set here may
    therefore be replaced by the search result; confirm this is intended.
    """
    svm = cv2.ml.SVM_create()

    # Per the OpenCV docs, coef0 and degree only affect POLY/SIGMOID kernels,
    # so they have no effect with the RBF kernel chosen below.
    svm.setCoef0(0.0)
    svm.setDegree(3)

    # Stop after 1000 iterations or when the change drops below 1e-3,
    # whichever comes first.
    criteria = (cv2.TERM_CRITERIA_MAX_ITER + cv2.TERM_CRITERIA_EPS, 1000, 1e-3)
    svm.setTermCriteria(criteria)

    svm.setGamma(gamma)
    svm.setKernel(cv2.ml.SVM_RBF)
    # Nu is only used by the NU_* SVM types; it is kept for parity with the
    # original configuration.
    svm.setNu(0.5)

    svm.setP(p)
    svm.setC(C)
    svm.setType(cv2.ml.SVM_EPS_SVR)

    return svm

# Train the SVM.
def svm_train(svm, features, labels):
    """Fit ``svm`` in place using OpenCV's automatic parameter search.

    Args:
        svm: An OpenCV SVM model, e.g. the one returned by ``svm_config``.
        features: Sample matrix. It is converted to float32 and passed with
            ``cv2.ml.ROW_SAMPLE``, so each row is one sample.
        labels: Per-sample regression targets, converted to float32.

    The return value of ``trainAuto`` is discarded; this function returns None.
    """
    samples = np.array(features, dtype=np.float32)
    responses = np.array(labels, dtype=np.float32)
    layout = cv2.ml.ROW_SAMPLE
    svm.trainAuto(samples, layout, responses)

# Save the SVM parameters.
def svm_save(svm, name):
    """Persist ``svm`` by calling its ``save`` method with the file name ``name``.

    Whatever ``save`` returns is discarded; this function returns None.
    """
    target_path = name
    svm.save(target_path)

def loadTrainDataFromExcel(path='0221_To_Teves訓練數據.xlsx', sheet_name='Sheet1'):
    """Read SVR training features and targets from an Excel worksheet.

    Row 0 of the sheet is treated as a header and skipped. The feature matrix
    is built from columns 5, 1 and 2, in that order: background distance,
    foreground distance and foreground area. The regression target
    ("offset") is taken from column 3.

    Args:
        path: Workbook path, relative to the current working directory by
            default. The original used a Windows-only ``.\\`` prefix, which
            POSIX systems treat as part of the file name.
        sheet_name: Name of the worksheet to read.

    Returns:
        tuple: ``(data, offset)``. ``data`` is a float32 array with one row
        per non-header row and three columns. ``offset`` is a list of the
        column-3 cell values.

    NOTE(review): xlrd 2.0+ only reads legacy ``.xls`` workbooks. Opening an
    ``.xlsx`` file requires xlrd < 2.0 or a different reader; confirm the
    pinned version.
    """
    workbook = xlrd.open_workbook(path)
    # Print the sheet names of the workbook.
    print(workbook.sheet_names())

    sheet = workbook.sheet_by_name(sheet_name)
    # Print the sheet name, row count and column count.
    print(sheet.name, sheet.nrows, sheet.ncols)

    # Read whole columns, skipping the header row. Unlike list.pop(0), this
    # does not raise on an empty sheet.
    background_dist = sheet.col_values(5, start_rowx=1)
    foreground_dist = sheet.col_values(1, start_rowx=1)
    foreground_area = sheet.col_values(2, start_rowx=1)
    offset = sheet.col_values(3, start_rowx=1)

    # One sample per row: (background_dist, foreground_dist, foreground_area).
    data = np.array(
        np.vstack((background_dist, foreground_dist, foreground_area)).transpose(),
        dtype='float32',
    )

    return data, offset

def PlotShow(data, offset):
    """Display a 3-D scatter plot of the training features, coloured by target.

    Args:
        data: 2-D array whose columns 0, 1 and 2 are plotted on the x
            (background distance), y (foreground distance) and z
            (foreground area) axes.
        offset: Per-sample values used as the marker colour, one per row of
            ``data``; a colour bar for them is added to the figure.
    """
    fig = plt.figure()

    # Axes3D(fig) is deprecated and, from Matplotlib 3.6, no longer adds the
    # axes to the figure, which leaves the plot empty. add_subplot with the
    # '3d' projection (registered by the mpl_toolkits.mplot3d import) works
    # on both old and new versions.
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title("Input data")

    p = ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=offset)
    ax.set_xlabel("Background Dist", color='r')
    ax.set_ylabel("Foreground Dist", color='g')
    ax.set_zlabel("Foreground Area", color='b')

    fig.colorbar(p)

    plt.show()

if __name__ == '__main__':
    # Configure the model, load and visualise the Excel data, then train and
    # save. All of this stays under the guard so importing the module has no
    # side effects and does not reference undefined names.
    SVR = svm_config()

    features, labels = loadTrainDataFromExcel()
    PlotShow(features, labels)
    svm_train(SVR, features, labels)
    svm_save(SVR, 'SVRModel.xml')

 
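# 發表評論 (Post a comment)
# 所有評論 (All comments)
# 還沒有人評論，想成為第一個評論的人麼？請在上方評論欄輸入並且點擊發布。
# (No comments yet. Want to be the first? Type in the comment box above and click "Publish".)
# 相關文章 (Related articles)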