在Ignite中使用線性迴歸算法 頂 原 薦

在本系列前面的文章中,簡單介紹了一下Ignite的機器學習網格,下面會趁熱打鐵,結合一些示例,深入介紹Ignite支持的一些機器學習算法。

如果要找合適的數據集,會發現可用的有很多,但是對於線性迴歸來說,一個非常好的備選數據集就是房價,可以非常方便地從UCI網站獲取合適的數據

在本文中會訓練一個線性迴歸模型,並且計算R2得分。

需要先準備一些數據,並且要將數據轉換成Ignite支持的格式,這通常是數據科學家需要花時間做的事。

首先,需要獲取原始數據並將其拆分成訓練數據(80%)和測試數據(20%)。Ignite暫時還不支持專用的數據拆分,路線圖中的未來版本會支持這個功能。但是就目前來說有許多可用的免費和開源工具可以執行這樣的數據拆分,或者也可以用一種Ignite支持的編程語言自己編寫這種代碼。在本文中會使用下面自己編寫的代碼來實現此任務:

from sklearn import datasets
import pandas as pd

# Load Boston housing dataset.
boston_dataset = datasets.load_boston()
x = boston_dataset.data
y = boston_dataset.target

# Split it into train and test subsets.
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=23)

# Save train set.
train_ds = pd.DataFrame(x_train, columns=boston_dataset.feature_names)
train_ds["TARGET"] = y_train
train_ds.to_csv("boston-housing-train.csv", index=False, header=None)
# Save test set.
test_ds = pd.DataFrame(x_test, columns=boston_dataset.feature_names)
test_ds["TARGET"] = y_test
test_ds.to_csv("boston-housing-test.csv", index=False, header=None)

# Train linear regression model.
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(x_train, y_train)

# Score result model.
lr.score(x_test, y_test)

這段代碼從UCI網站上獲取可用的數據集,執行了數據的拆分,然後計算了R2得分。返回值爲0.745021053016975,或者爲74.5%,之後會將此值與Ignite的進行對比。

當訓練和測試數據準備好之後,就可以寫應用了,本文的算法是:

  1. 讀取訓練數據和測試數據;
  2. 在Ignite中保存訓練數據和測試數據;
  3. 使用訓練數據擬合線性迴歸模型;
  4. 將模型應用於測試數據;
  5. 確定模型的R2得分。

由於數據集非常小,可以將其加載到標準Java數據結構中,並直接從Java程序中運行線性迴歸。或者,也可以將數據加載到Ignite存儲中,然後對存儲的數據進行線性迴歸。使用Ignite存儲的優點是數據將分佈在整個集羣中,因此將執行分佈式訓練。對於大規模數據集,使用Ignite存儲就會有很大的好處。在本例中將把數據加載到Ignite存儲中。

讀取訓練數據和測試數據

需要讀取兩個CSV文件,一個是訓練數據,一個是測試數據。通過下面的代碼,可以從CSV文件中讀取數據:

private static void loadData(String fileName, IgniteCache<Integer, HouseObservation> cache)
        throws FileNotFoundException {

   Scanner scanner = new Scanner(new File(fileName));

   int cnt = 0;
   while (scanner.hasNextLine()) {
      String row = scanner.nextLine();
      String[] cells = row.split(",");
      double[] features = new double[cells.length - 1];

      for (int i = 0; i < cells.length - 1; i++)
         features[i] = Double.valueOf(cells[i]);
      double price = Double.valueOf(cells[cells.length - 1]);

      cache.put(cnt++, new HouseObservation(features, price));
   }
}

該代碼簡單地一行行的讀取數據,然後對於每一行,使用CSV的分隔符拆分出字段,每個字段之後將轉換成double類型並且存入Ignite。

將訓練數據和測試數據存入Ignite

前面的代碼將數據存入Ignite,要使用這個代碼,首先要創建Ignite存儲,如下:

IgniteCache<Integer, HouseObservation> trainData = ignite.createCache("BOSTON_HOUSING_TRAIN");
IgniteCache<Integer, HouseObservation> testData = ignite.createCache("BOSTON_HOUSING_TEST");

使用訓練數據創建線性迴歸模型

數據存儲之後,可以像下面這樣創建訓練器:

DatasetTrainer<LinearRegressionModel, Double> trainer = new LinearRegressionLSQRTrainer();

然後擬合訓練數據,如下:

LinearRegressionModel mdl = trainer.fit(
   ignite,
   trainData,
   (k, v) -> v.getFeatures(),  
// Feature extractor.

   (k, v) -> v.getPrice()
// Label extractor.

Ignite將數據保存爲鍵-值(K-V)格式,因此上面的代碼使用了值部分,目標值是Price,而特徵位於其他列中。

將模型應用於測試數據

下一步,就可以用訓練好的線性模型測試測試數據了,在Ignite的機器學習路線圖中,有計劃提供內置的得分計算器,但是就目前來說,可以這樣做:

double meanPrice = getMeanPrice(testData);
double u = 0, v = 0;

try (QueryCursor<Cache.Entry<Integer, HouseObservation>> cursor = testData.query(new ScanQuery<>())) {
   for (Cache.Entry<Integer, HouseObservation> testEntry : cursor) {
      HouseObservation observation = testEntry.getValue();

      double realPrice = observation.getPrice();
      double predictedPrice = mdl.apply(new DenseLocalOnHeapVector(observation.getFeatures()));

      u += Math.pow(realPrice - predictedPrice, 2);
      v += Math.pow(realPrice - meanPrice, 2);
   }
}

這裏計算的是殘差平方和(U)和總平方和(V)。

確定模型的R2得分

可以發現,R2的值爲1 - u / v:

double score = 1 - u / v;

System.out.println("Score : " + score);

輸出值爲0.7450194305206714,或者74.5%,這與之前的值相同。

總結

Apache Ignite提供了一個機器學習算法庫。通過線性迴歸示例,可以看到創建模型、測試模型和確定模型的R2得分的簡單性,也可以用這個模型來做預測。

目前,可用的機器學習工具有很多,但它們不能多節點擴展,只能處理少量數據。相比之下,Ignite所帶來的好處是它有能力擴展下面兩種能力:

  1. 集羣的大小(成百上千臺機器)
  2. 存儲的數據量(GB、TB甚至PB級數據)

因此,Ignite可以大規模地運行機器學習。它可以以分佈式處理的方式,對大數據進行真正的機器學習管理。

在機器學習系列的下一篇中,將研究另一種機器學習算法。敬請期待!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章