①欠擬合: 模型太簡單
②好的擬合: 模型剛剛好
③過擬合: 模型太複雜
操作步驟:
- 加載帶有噪音的二分類數據集(訓練集和驗證集)
- 使用不同神經網絡演示欠擬合和過擬合
- 過擬合應對法:早停法,權重衰減,丟棄法
// index.js
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data';
// Entry point: once the page has loaded, plot the noisy two-class data set,
// build a small binary classifier and train it while charting the training
// and validation loss.
window.onload = async () => {
  // 200 points in total, generated with a noise variance of 2.
  const data = getData(200, 2);

  // One scatterplot series per class.
  tfvis.render.scatterplot(
    { name: '訓練數據' },
    {
      values: [
        data.filter((p) => p.label === 1),
        data.filter((p) => p.label === 0),
      ],
    }
  );

  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 10,
    inputShape: [2],
    activation: 'tanh',
    // Weight decay (alternative counter-measure against overfitting):
    // kernelRegularizer: tf.regularizers.l2({ l2: 1 })
  }));
  // Dropout: randomly zeroes 90% of the hidden activations while training.
  model.add(tf.layers.dropout({ rate: 0.9 }));
  model.add(tf.layers.dense({
    units: 1,
    activation: 'sigmoid',
  }));
  model.compile({
    loss: tf.losses.logLoss,
    optimizer: tf.train.adam(0.1),
  });

  // getData returns every label-1 point followed by every label-0 point, and
  // validationSplit holds out the LAST 20% of the samples. Without shuffling,
  // the validation set would contain label-0 points only and the training
  // set would be unbalanced, so shuffle a copy (the plot above is unaffected).
  const samples = [...data];
  tf.util.shuffle(samples);

  const inputs = tf.tensor(samples.map((p) => [p.x, p.y]));
  const labels = tf.tensor(samples.map((p) => p.label));

  try {
    await model.fit(inputs, labels, {
      validationSplit: 0.2,
      epochs: 200,
      // Chart loss and val_loss once per epoch.
      callbacks: tfvis.show.fitCallbacks(
        { name: '訓練效果' },
        ['loss', 'val_loss'],
        { callbacks: ['onEpochEnd'] }
      ),
    });
  } finally {
    // Tensors are not garbage collected; release them explicitly.
    inputs.dispose();
    labels.dispose();
  }
};
// data.js
/**
 * Generates a noisy two-class data set made of two Gaussian clusters:
 * label 1 centred on (2, 2), followed by label 0 centred on (-2, -2).
 *
 * @param numSamples Total sample count; each cluster loops while
 *   i < numSamples / 2, so an odd count yields one extra point per cluster.
 * @param variance Variance of the noise; the larger the value, the noisier
 *   the data.
 * @returns Array of { x, y, label } points, label-1 points first.
 */
export function getData(numSamples, variance) {
  const perCluster = numSamples / 2;
  const clusters = [
    { cx: 2, cy: 2, label: 1 },
    { cx: -2, cy: -2, label: 0 },
  ];
  const samples = [];

  for (const { cx, cy, label } of clusters) {
    for (let k = 0; k < perCluster; k += 1) {
      // Draw x before y so the random sequence is consumed in the same order.
      const x = normalRandom(cx, variance);
      const y = normalRandom(cy, variance);
      samples.push({ x, y, label });
    }
  }

  return samples;
}
/**
 * Samples from a normal distribution using the Marsaglia polar method, with
 * Math.random as the uniform random generator.
 *
 * @param mean The mean. Default is 0.
 * @param variance The variance. Default is 1.
 * @returns A random number drawn from the requested normal distribution.
 */
function normalRandom(mean = 0, variance = 1) {
  let v1;
  let v2;
  let s;
  do {
    // Pick a point uniformly in the square [-1, 1) x [-1, 1).
    v1 = 2 * Math.random() - 1;
    v2 = 2 * Math.random() - 1;
    s = v1 * v1 + v2 * v2;
    // Reject points outside the unit circle, and the exact origin: with
    // s === 0, Math.log(s) is -Infinity and the final product becomes NaN.
  } while (s > 1 || s === 0);
  const result = Math.sqrt(-2 * Math.log(s) / s) * v1;
  // Scale the standard-normal sample by the standard deviation and shift it.
  return mean + Math.sqrt(variance) * result;
}