JavaScript之機器學習3:Tensorflow.js 邏輯迴歸操作

邏輯迴歸操作

  1. 使用預先準備好的腳本生成二分類數據集
  2. 可視化二分類數據集
  3. 定義模型結構:帶有激活函數的單個神經元
    • 初始化一個神經網絡模型
    • 爲神經網絡模型添加層
    • 設計層的神經元個數,inputShape,激活函數
  4. 訓練模型並可視化訓練過程
    • 將訓練數據轉爲Tensor
    • 訓練模型
    • 使用tfvis可視化訓練過程
  5. 進行預測
    • 編寫前端界面輸入待預測數據
    • 使用訓練好的模型進行預測
    • 將輸出的Tensor轉爲普通數據顯示
      在這裏插入圖片描述
      在這裏插入圖片描述
      在這裏插入圖片描述
      演示代碼:
<!-- index.html  -->
<form action="" onsubmit="predict(this);return false;">
    x: <input type="text" name="x">
    y: <input type="text" name="y">
    <button type="submit">預測</button>
</form>
// index.js
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';
window.onload = async() => {
    // 1. 使用預先準備好的腳本生成二分類數據集
    const data = getData(400);  // 獲取400個點
    console.log(data);

    // 2. 可視化二分類數據集
    tfvis.render.scatterplot(
        {name:'邏輯迴歸訓練數據'},
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0)
            ]
        }
    );

    // 3.定義模型結構:帶有激活函數的單個神經元
    const model = tf.sequential(); // sequential 連續的
    model.add(tf.layers.dense({   // 添加全連接層 output = activation(dot(input, kernel) + bias)
        units: 1, // 1個神經元
        inputShape: [2],
        activation: 'sigmoid'  // 激活函數,sigmoid作用:把輸出值壓縮到0-1之間
    })); 
    // 設置損失函數和優化器
    model.compile({loss: tf.losses.logLoss, optimizer: tf.train.adam(0.1)});

    // 4. 訓練模型並可視化訓練過程
    const inputs = tf.tensor(data.map(p=>[p.x,p.y]));
    const labels = tf.tensor(data.map(p => p.label));

    await model.fit(inputs, labels, {
        batchSize: 40,
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: '訓練效果' },
            ['loss']
        )
    });

    window.predict = (form) => {
        const pred = model.predict(tf.tensor([[form.x.value*1, form.y.value*1]]));
        alert(`預測結果:${pred.dataSync()[0]}`)
    }
};
// data.js
export function getData(numSamples) {
    let points = [];
  
    function genGauss(cx, cy, label) {
      for (let i = 0; i < numSamples / 2; i++) {
        let x = normalRandom(cx);
        let y = normalRandom(cy);
        points.push({ x, y, label });
      }
    }
  
    genGauss(2, 2, 1);
    genGauss(-2, -2, 0);
    return points;
  }
  
  /**
   * Samples from a normal distribution. Uses the seedrandom library as the
   * random generator.
   *
   * @param mean The mean. Default is 0.
   * @param variance The variance. Default is 1. 設的越大,範圍越廣
   */
  function normalRandom(mean = 0, variance = 1) {
    let v1, v2, s;
    do {
      v1 = 2 * Math.random() - 1;
      v2 = 2 * Math.random() - 1;
      s = v1 * v1 + v2 * v2;
    } while (s > 1);
  
    let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
    return mean + Math.sqrt(variance) * result;
  }
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章