新手學習-機器學習。 用KNN 算法預測足彩賠率:
當前簡單模型,只用K值。未加權重,初級距離計算公式是 -歐式距離
代碼都有註釋:
主體類:
public static void main(String[] args) throws Exception {
//導入CSV 工具類
String path = "F:/deeplearnsuanfa/test.csv";
List<String> dataList=CvsUtil.getCvs(new File(path));
dataList.remove(0);
//獲取訓練集 測試集 和
Map<String,List<String>> result = ListUtil.trainTestUtil(dataList, 0.2);
//定義 xTrain yTrain xTest yTest
Map<String,List> trainAndTestData = ListUtil.getData(result.get("train"), 53);
List<List<Double>> xTrain = trainAndTestData.get("data");
List<String> yTrain = trainAndTestData.get("lable");
trainAndTestData = ListUtil.getData(result.get("test"), 53);
List<List<Double>> xTest = trainAndTestData.get("data");
List<String> yTest = trainAndTestData.get("lable");
KnnAlgorithms.getKnn(xTrain, xTest, yTrain, yTest, 5, 53);
}
2:CVS讀取類
/**
* 獲取CVS 數據
* path 路徑
*/
public static List<String> getCvs(File file) throws IOException {
List<String> dataList=new ArrayList<String>();
BufferedReader br=null;
try {
br = new BufferedReader(new FileReader(file));
String line = "";
while ((line = br.readLine()) != null) {
dataList.add(line);
}
}catch (Exception e) {
}finally{
if(br!=null){
try {
br.close();
br=null;
} catch (IOException e) {
e.printStackTrace();
}
}
}
return dataList;
}
3:因爲JAVA 暫時沒有找到和PYTHON numpy 對應list 操作庫。只好用for寫操作list
/**
* 集合操作工具類
* @author join
*
*/
public class ListUtil {
/**
* 分成 測試集和 訓練集
* list 傳來的CVS
* number 測試集的佔比
*/
public static Map<String,List<String>> trainTestUtil(List<String> list,Double number ){
//隨機打亂數據集
Collections.shuffle(list);
//大的數據集 裏面包括兩個數據集
Map<String,List<String>> result = new HashMap<String,List<String>>();
//訓練集
List<String> train = new ArrayList<String>();
//測試集
List<String> test = new ArrayList<String>();
if(list.size()>0) {
//得到測試集的 數量下標
int testNumber = list.size() -(int)(list.size() *number);
for(int i = 0;i<list.size();i++) {
//則是訓練集
if(i<testNumber) {
train.add(list.get(i));
}else {
test.add(list.get(i));
}
}
}else {
return null;
}
result.put("train", train);
result.put("test", test);
return result;
}
/**
* 得到 每一個 cvs 的 lable 和 訓練數據
* @param list
* @return
*/
public static Map<String,List> getData(List<String> list,int lableNumber) {
Map<String,List> result = new HashMap<String,List>();
List<List<Double>> datas = new ArrayList<List<Double>>();
List<String> lableData = new ArrayList<String>();
for(int i=0;i<list.size();i++) {
String [] data = list.get(i).split(",");
List<Double> da = new ArrayList<Double>();
for(int k=0;k<data.length;k++) {
if(k == lableNumber) {
lableData.add(data[k]);
}else {
if(data[k] != null || !data[k].equals("")) {
da.add(Double.valueOf(data[k]));
}else {
da.add(0.3);
}
}
}
//放入數據中
datas.add(da);
}
result.put("data", datas);
result.put("lable", lableData);
return result;
}
/**
* xTest 和 所有的 xTrain的距離
* @param xTrain 訓練集
* @param xTest 測試集
* @param yTrain 訓練集標籤
* @throws Exception
*/
public static List<List<Map<String,Object>>> allDistance(List<List<Double>> xTrain,List<List<Double>> xTest,List<String> yTrain,int dataNumber) throws Exception {
List<List<Map<String,Object>>> result = new ArrayList<List<Map<String,Object>>>();
//FOR 測試集
for(List<Double> test:xTest) {
//每個訓練集到測試集的樣本的距離 和 訓練集距離的標籤
List<Map<String,Object>> testList = new ArrayList<Map<String,Object>>();
//FOR 訓練集
for(int i = 0;i<xTrain.size();i++) {
//每個訓練集到測試集的樣本的距離 和 訓練集距離的標籤
Map<String,Object> map = new HashMap<String,Object>();
Double euclideanDistance = MathUtil.getEuclideanDistance(xTrain.get(i),test);
String lable =yTrain.get(i);
map.put("lable", lable);
map.put("distance", euclideanDistance);
testList.add(map);
}
//這裏做排序
Collections.sort(testList, new Comparator<Map<String, Object>>() {
public int compare(Map<String, Object> o1, Map<String, Object> o2) {
Double name1 = Double.valueOf(o2.get("distance").toString()) ;//name1是從你list裏面拿出來的一個
Double name2 = Double.valueOf(o1.get("distance").toString()) ; //name1是從你list裏面拿出來的第二個name
return name1.compareTo(name2);
}
});
result.add(testList);
}
return result;
}
}
4:數學公式類:
/**
* 數據公式 工具類
* @author join
*
*/
public class MathUtil{
/**
* 歐式距離公式
* 公式:取得參數 互相減 的平方 和 再開放
* train 訓練數據集的 一個樣本
* test 測試數據的一個樣本
* @return
* @throws Exception
*/
public static Double getEuclideanDistance(List<Double> train,List<Double> test) throws Exception {
Double sum = 0.00;
if(train.size() != test.size()) {
throw new Exception("兩個集合的大小不一致");
}
for(int i=0;i<train.size();i++) {
sum += (train.get(i)-test.get(i)) *(train.get(i)-test.get(i));
}
return Math.sqrt(sum);
}
}
數據是從網上爬的比較小:代碼粘貼即可用。供學習。也請各位大佬指教一下.