JAVA-Knn算法-測試集驗證集測試準確率

新手學習-機器學習。 用KNN 算法預測足彩賠率:

當前簡單模型,只用K值。未加權重,初級距離計算公式是 -歐式距離

 

代碼都有註釋:

主體類:

public static void main(String[] args) throws Exception {
        
        
        //導入CSV 工具類
        String path = "F:/deeplearnsuanfa/test.csv";
        List<String> dataList=CvsUtil.getCvs(new File(path));
        dataList.remove(0);
        
        //獲取訓練集 測試集 和 
        Map<String,List<String>> result = ListUtil.trainTestUtil(dataList, 0.2);
        
        //定義 xTrain yTrain  xTest yTest
        Map<String,List> trainAndTestData = ListUtil.getData(result.get("train"), 53);
        
        List<List<Double>> xTrain = trainAndTestData.get("data");
        List<String> yTrain = trainAndTestData.get("lable");
        
        trainAndTestData = ListUtil.getData(result.get("test"), 53);
        
        List<List<Double>> xTest = trainAndTestData.get("data");
        List<String> yTest = trainAndTestData.get("lable");
        
        KnnAlgorithms.getKnn(xTrain, xTest, yTrain, yTest, 5, 53);
}

2:CVS讀取類

    /**
     * 獲取CVS 數據
     * path 路徑
     */
    public static List<String>  getCvs(File file) throws IOException {
         List<String> dataList=new ArrayList<String>();
            BufferedReader br=null;
            try { 
                br = new BufferedReader(new FileReader(file));
                String line = ""; 
                while ((line = br.readLine()) != null) { 
                    dataList.add(line);
                }
            }catch (Exception e) {
            }finally{
                if(br!=null){
                    try {
                        br.close();
                        br=null;
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
            }
     
            return dataList;
        }

3:因爲JAVA 暫時沒有找到和PYTHON numpy 對應list 操作庫。只好用for寫操作list

/**
 * 集合操作工具類
 * @author join
 *
 */
public class ListUtil {
    
    
    /**
     * 分成 測試集和 訓練集
     * list 傳來的CVS
     * number 測試集的佔比
     */
    public static Map<String,List<String>> trainTestUtil(List<String> list,Double number ){
        //隨機打亂數據集
        Collections.shuffle(list);
        
        //大的數據集  裏面包括兩個數據集
        Map<String,List<String>> result = new HashMap<String,List<String>>();
        //訓練集
        List<String> train = new ArrayList<String>();
        //測試集
        List<String> test = new ArrayList<String>();
        
        if(list.size()>0) {
            //得到測試集的 數量下標
            int testNumber = list.size() -(int)(list.size() *number);
            
            for(int i = 0;i<list.size();i++) {
                //則是訓練集
                if(i<testNumber) {
                    train.add(list.get(i));
                }else {
                    test.add(list.get(i));
                }
            }
            
        }else {
            return null;
        }
        result.put("train", train);
        result.put("test", test);
        return result;
    }
    
    /**
     * 得到 每一個 cvs 的 lable 和 訓練數據
     * @param list
     * @return
     */
    public static Map<String,List> getData(List<String> list,int lableNumber) {
        
        Map<String,List> result = new HashMap<String,List>();
        
        List<List<Double>> datas = new ArrayList<List<Double>>();
        List<String> lableData = new ArrayList<String>();
        
        
        for(int i=0;i<list.size();i++) {
            String [] data = list.get(i).split(",");
            List<Double> da = new ArrayList<Double>();
            for(int k=0;k<data.length;k++) {
                if(k == lableNumber) {
                    lableData.add(data[k]);
                }else {
                    if(data[k] != null || !data[k].equals("")) {
                        da.add(Double.valueOf(data[k]));
                    }else {
                        da.add(0.3);
                    }
                }
            }
            //放入數據中
            datas.add(da);
        }
        result.put("data", datas);
        result.put("lable", lableData);
        return result;
    }
    
    /**
     * xTest 和 所有的 xTrain的距離
     * @param xTrain 訓練集
     * @param xTest     測試集
     * @param yTrain    訓練集標籤
     * @throws Exception 
     */
    public static List<List<Map<String,Object>>> allDistance(List<List<Double>> xTrain,List<List<Double>> xTest,List<String> yTrain,int dataNumber) throws Exception {
        List<List<Map<String,Object>>> result = new ArrayList<List<Map<String,Object>>>();
        //FOR  測試集
        for(List<Double> test:xTest) {
            //每個訓練集到測試集的樣本的距離  和 訓練集距離的標籤 
            List<Map<String,Object>> testList = new ArrayList<Map<String,Object>>();
            //FOR 訓練集
            for(int i = 0;i<xTrain.size();i++) {
                //每個訓練集到測試集的樣本的距離  和 訓練集距離的標籤 
                Map<String,Object> map = new HashMap<String,Object>();
                Double euclideanDistance = MathUtil.getEuclideanDistance(xTrain.get(i),test);
                String lable =yTrain.get(i);
                map.put("lable", lable);
                map.put("distance", euclideanDistance);
                testList.add(map);
            }
            //這裏做排序 
            Collections.sort(testList, new Comparator<Map<String, Object>>() {
                public int compare(Map<String, Object> o1, Map<String, Object> o2) {
                    Double name1 = Double.valueOf(o2.get("distance").toString()) ;//name1是從你list裏面拿出來的一個 
                    Double name2 = Double.valueOf(o1.get("distance").toString()) ; //name1是從你list裏面拿出來的第二個name
                    return name1.compareTo(name2);
                }
            });
            result.add(testList);
        }
        return result;
            }

}

 

4:數學公式類:

/**
 * 數據公式 工具類
 * @author join
 *
 */

public class MathUtil{
    
    /**
     * 歐式距離公式
     * 公式:取得參數 互相減 的平方 和  再開放 
     * train 訓練數據集的 一個樣本
     * test  測試數據的一個樣本
     * @return
     * @throws Exception 
     */
    public static Double getEuclideanDistance(List<Double> train,List<Double> test) throws Exception {
        
        Double sum = 0.00;
        
        if(train.size() != test.size()) {
            throw new Exception("兩個集合的大小不一致");
        }
        
        for(int i=0;i<train.size();i++) {
            sum += (train.get(i)-test.get(i)) *(train.get(i)-test.get(i));
        }
        return Math.sqrt(sum);
    }
    

}

 

數據是從網上爬的比較小:代碼粘貼即可用。供學習。也請各位大佬指教一下.

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章