推薦學習網站:http://playground.tensorflow.org/,這個網站就是用Tensorflow.js
寫出來的;
多層神經網絡:XOR邏輯迴歸
同爲0,異爲1
操作步驟:
- 加載XOR數據集
- 定義模型結構:多層神經網絡
- 初始化一個神經網絡模型
- 爲神經網絡模型添加兩個層
- 設計層的神經元個數,inputShape,激活函數
- 訓練模型並預測
演示代碼:
<!-- index.html -->
<form action="" onsubmit="predict(this);return false;">
x: <input type="text" name="x">
y: <input type="text" name="y">
<button type="submit">預測</button>
</form>
// index.js
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';
window.onload = async() => {
const data = getData(400); // 獲取400個點
tfvis.render.scatterplot(
{name:'XOR 訓練數據'},
{
values: [
data.filter(p => p.label === 1),
data.filter(p => p.label === 0)
]
}
);
const model = tf.sequential();
// 設置隱藏層
model.add(tf.layers.dense({
units: 4, // 神經元個數爲4
inputShape: [2], // 長度爲2的一維數組,數據特徵爲2:x,y
activation: 'relu' // 激活函數 非線性
}));
// 設置輸出層
model.add(tf.layers.dense({
units:1, // 只需要輸出一個概率
activation: 'sigmoid' // 輸出0-1之間的概率
}));
// 設置損失函數和優化器
model.compile({
loss:tf.losses.logLoss,
optimizer: tf.train.adam(0.1)
});
const inputs = tf.tensor(data.map(p => [p.x,p.y]));
const labels = tf.tensor(data.map(p => p.label));
await model.fit(inputs,labels,{
epochs:10,
callbacks:tfvis.show.fitCallbacks(
{name:'訓練過程'},
['loss']
)
});
window.predict = async (form) => {
const pred = await model.predict(tf.tensor([[form.x.value*1,form.y.value*1]]))
alert(`預測結果:${pred.dataSync()[0]}`);
}
}
// data.js
export function getData(numSamples) {
let points = [];
function genGauss(cx, cy, label) {
for (let i = 0; i < numSamples / 2; i++) {
let x = normalRandom(cx);
let y = normalRandom(cy);
points.push({ x, y, label });
}
}
genGauss(2, 2, 0);
genGauss(-2, -2, 0);
genGauss(-2, 2, 1);
genGauss(2, -2, 1);
return points;
}
/**
* Samples from a normal distribution. Uses the seedrandom library as the
* random generator.
*
* @param mean The mean. Default is 0.
* @param variance The variance. Default is 1.
*/
function normalRandom(mean = 0, variance = 1) {
let v1, v2, s;
do {
v1 = 2 * Math.random() - 1;
v2 = 2 * Math.random() - 1;
s = v1 * v1 + v2 * v2;
} while (s > 1);
let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
return mean + Math.sqrt(variance) * result;
}