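這是在 Coursera 課程中學到的多元梯度下降運算過程的解釋，
實際上也就是對代價函數 $J(\theta)=\frac{1}{2m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)})-y^{(i)}\right)^2$ 求偏導數，再按 $\theta_j := \theta_j-\frac{\alpha}{m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)})-y^{(i)}\right)x_j^{(i)}$ 同時更新所有參數。
本測試用例爲二元（截距項 $\theta_0$ 加上一個特徵的係數 $\theta_1$），
但程式同樣適用於多元的數據。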
數據如下：
X：
1,2,3
Y：
1,2,3
代碼如下：
package hello;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;
/**
 * Batch gradient descent for multivariate linear regression.
 *
 * <p>Features are passed column-wise: {@code X.get(k)[j]} is the value of feature {@code k} for
 * sample {@code j}. To fit an intercept term, include a column of {@code 1.0} values, as
 * {@link #main(String[])} does.
 */
public class GradientDescent {

    /** Learning rate used by {@link #getTheta(List, Double[])}. */
    private static final double DEFAULT_LEARNING_RATE = 0.001;

    /** Number of update steps used by {@link #getTheta(List, Double[])}. */
    private static final int DEFAULT_ITERATIONS = 150000;

    /** Feature file read by {@link #main(String[])} when no path is given on the command line. */
    private static final String DEFAULT_X_PATH = "C:/Users/ojama/Desktop/testX.txt";

    /** Target file read by {@link #main(String[])} when no path is given on the command line. */
    private static final String DEFAULT_Y_PATH = "C:/Users/ojama/Desktop/testY.txt";

    /**
     * Fits linear-regression coefficients with learning rate 0.001 and 150000 iterations.
     *
     * @param X feature columns; every column must have exactly {@code y.length} values
     * @param y target values, one per sample
     * @return one fitted coefficient per column of {@code X}, in the same order
     * @throws IllegalArgumentException if {@code y} is empty or a column's length differs from
     *     {@code y.length}
     */
    public static Double[] getTheta(List<Double[]> X, Double[] y) {
        return getTheta(X, y, DEFAULT_LEARNING_RATE, DEFAULT_ITERATIONS);
    }

    /**
     * Fits linear-regression coefficients by full-batch gradient descent, starting from all zeros.
     *
     * <p>Each step applies {@code theta[k] -= alpha / m * sum_j (h(x_j) - y[j]) * X.get(k)[j]},
     * where {@code h(x_j) = sum_k theta[k] * X.get(k)[j]}. This is the gradient of the mean squared
     * error {@code 1/(2m) * sum_j (h(x_j) - y[j])^2}. All coefficients are updated simultaneously.
     *
     * @param X feature columns; every column must have exactly {@code y.length} values
     * @param y target values, one per sample
     * @param alpha learning rate
     * @param iterations number of update steps; must be non-negative
     * @return one fitted coefficient per column of {@code X}, in the same order
     * @throws IllegalArgumentException if {@code y} is empty, {@code iterations} is negative, or a
     *     column's length differs from {@code y.length}
     */
    public static Double[] getTheta(List<Double[]> X, Double[] y, double alpha, int iterations) {
        int m = y.length;
        int n = X.size();
        if (m == 0) {
            // With no samples alpha / m is infinite and every coefficient would silently become NaN.
            throw new IllegalArgumentException("y must contain at least one sample");
        }
        if (iterations < 0) {
            throw new IllegalArgumentException("iterations must be >= 0, got " + iterations);
        }

        // Unbox once up front instead of on every one of the (possibly many) iterations.
        double[][] features = toPrimitiveColumns(X, m);
        double[] target = new double[m];
        for (int j = 0; j < m; j++) {
            target[j] = y[j];
        }

        double[] theta = new double[n];
        double step = alpha / m;
        for (int iter = 0; iter < iterations; iter++) {
            // Accumulate the whole gradient before touching theta so every coefficient is
            // updated from the same previous parameter vector (simultaneous update).
            double[] gradient = new double[n];
            for (int j = 0; j < m; j++) {
                // Prediction for sample j. In the 2-D case this is k*x + b*1, in 3-D a*x + b*y + c*1,
                // and so on.
                double error = 0.0;
                for (int k = 0; k < n; k++) {
                    error += theta[k] * features[k][j];
                }
                error -= target[j];
                for (int k = 0; k < n; k++) {
                    gradient[k] += error * features[k][j];
                }
            }
            for (int k = 0; k < n; k++) {
                theta[k] -= step * gradient[k];
            }
        }

        Double[] result = new Double[n];
        for (int k = 0; k < n; k++) {
            result[k] = theta[k];
        }
        return result;
    }

    /** Copies the boxed feature columns into a primitive matrix, checking each column's length. */
    private static double[][] toPrimitiveColumns(List<Double[]> X, int m) {
        double[][] columns = new double[X.size()][m];
        for (int k = 0; k < columns.length; k++) {
            Double[] column = X.get(k);
            if (column.length != m) {
                throw new IllegalArgumentException(
                        "feature column " + k + " has " + column.length
                                + " values but y has " + m);
            }
            for (int j = 0; j < m; j++) {
                columns[k][j] = column[j];
            }
        }
        return columns;
    }

    /**
     * Reads X and y from two files, fits {@code y = theta0 + theta1 * x} and prints the
     * coefficients, one per line.
     *
     * @param args optional paths: {@code args[0]} is the X file and {@code args[1]} the y file.
     *     Missing arguments fall back to the original hard-coded desktop paths.
     * @throws IOException if either file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String xPath = args.length > 0 ? args[0] : DEFAULT_X_PATH;
        String yPath = args.length > 1 ? args[1] : DEFAULT_Y_PATH;
        Double[] x1 = read(xPath);
        Double[] y = read(yPath);

        // Constant 1.0 column so theta[0] acts as the intercept.
        Double[] x0 = new Double[y.length];
        for (int i = 0; i < x0.length; i++) {
            x0[i] = 1.0;
        }

        List<Double[]> X = new ArrayList<Double[]>();
        X.add(x0);
        X.add(x1);
        for (Double coefficient : getTheta(X, y)) {
            System.out.println(coefficient);
        }
    }

    /**
     * Reads comma-separated numbers from a text file.
     *
     * <p>Both commas and line breaks separate values. Whitespace around each value is ignored, and
     * empty fields (blank lines, trailing commas) are skipped. The file is decoded with the
     * platform default charset.
     *
     * @param fileName path of the file to read
     * @return the parsed values in file order; empty if the file contains no values
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if a field is not a valid number
     */
    public static Double[] read(String fileName) throws IOException {
        List<Double> values = new ArrayList<Double>();
        try (FileReader fileReader = new FileReader(new File(fileName));
                BufferedReader reader = new BufferedReader(fileReader)) {
            String line;
            while ((line = reader.readLine()) != null) {
                // Each line is split on its own: concatenating lines would fuse the last value
                // of one line with the first value of the next.
                for (String field : line.split(",")) {
                    String trimmed = field.trim();
                    if (!trimmed.isEmpty()) {
                        values.add(Double.parseDouble(trimmed));
                    }
                }
            }
        }
        return values.toArray(new Double[0]);
    }
}
運行結果：
theta0 約等於 0，theta1 約等於 1；
雖然有一定的誤差，但已經足夠精確了。