案例:鳶尾花(iris)分類
操作步驟
- 加載IRIS數據集(訓練集與驗證集)
- 定義模型結構:帶有softmax的多層神經網絡
- 初始化一個神經網絡模型
- 爲神經網絡模型添加兩個層
- 設計層的神經元個數,inputShape,激活函數
- 訓練模型並預測
- 交叉熵損失函數與準確度度量
主要示例代碼:
<!-- index.html -->
<form action="" onsubmit="predict(this); return false;">
花萼長度:<input type="text" name="a"><br>
花萼寬度:<input type="text" name="b"><br>
花瓣長度:<input type="text" name="c"><br>
花瓣寬度:<input type="text" name="d"><br>
<button type="submit">預測</button>
</form>
// index.js
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getIrisData, IRIS_CLASSES } from './data';
window.onload = async() => {
//分別代表訓練集和驗證集的特徵和標籤
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15); // 15%的數據用於驗證集
// xTrain.print();
// yTrain.print();
// xTest.print();
// yTest.print();
// console.log(IRIS_CLASSES);
// 定義模型結構
const model = tf.sequential();
model.add(tf.layers.dense({
units: 10,
inputShape:[xTrain.shape[1]], // 特徵長度:4
activation: 'sigmoid'
}));
model.add(tf.layers.dense({
units: 3,
activation:'softmax'
}));
model.compile({
loss:'categoricalCrossentropy',
optimizer: tf.train.adam(0.1),
metrics: ['accuracy']
});
await model.fit(xTrain, yTrain, {
epochs: 100,
validationData: [xTest, yTest],
callbacks: tfvis.show.fitCallbacks(
{name:'訓練效果'},
['loss','val_loss','acc','val_acc'],
{callbacks:['onEpochEnd']}
)
});
window.predict = (form) => {
const input = tf.tensor([[
form.a.value *1,
form.b.value *1,
form.c.value *1,
form.d.value *1,
]]);
const pred = model.predict(input);
alert(`預測結果:${IRIS_CLASSES[pred.argMax(1).dataSync(0)]}`)
}
};