tensorflow.js示例筆記 - mnist

使用層來進行數字識別,使用tf.layers api訓練模型識別MNIST數據庫中的手寫數字。

index.html

<html>
    <head>
        <title>MNIST</title>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <link rel="stylesheet" href="../shared/tfjs-examples.css"/>

        <style>
            #train {
                margin-top: 10px;
            }

            label {
                display: inline-block;
                width: 250px;
                padding: 6px 0 6px 0;
            }

            .canvases {
                display: inline-block;
            }

            .prediction-canvas {
                width: 100px;
            }

            .pred {
                font-size: 20px;
                line-height: 25px;
                width: 100px;
            }

            .pred-correct {
                background-color: #00cf00;
            }

            .pred-incorrect {
                background-color: red;
            }

            .pred-container {
                display: inline-block;
                width: 100px;
                margin: 10px;
            }

            #train-epochs {
                width: 82px;
                font-size: 14px;
            }
        </style>
    </head>
    <body>
        <div class="tfjs-example-container">
            <section class="title-area">
                <h1>Digit Recognizer with Layers</h1>
                <p class="subtitle">Train a model to recognize handwritten digits from the MNIST database using the tf.layers
                    api.
                </p>
            </section>
            <section>
                <p class="section-head">Description</p>
                <p>
                    This examples lets you train a handwritten digit recognizer using either a Convolutional Neural Network
                    (also known as a ConvNet or CNN) or a Fully Connected Neural Network (also known as a DenseNet).
                </p>
                <p>The MNIST dataset is used as training data.</p>
            </section>
            <section>
                <p class="section-head">Training Parameters</p>
                <div>
                    <label>Model Type:</label>
                    <select id="model-type">
                        <option>ConvNet</option>
                        <option>DenseNet</option>
                    </select>
                </div>

                <div>
                    <label>Number of training epochs:</label>
                    <input id="train-epochs" type="number" value="3" />
                </div>

                <button id="train" disabled>Train Model</button>
            </section>
            <section>
                <p class="section-head">Training Progress</p>
                <p id="status"></p>
                <p id="message"></p>
                <div id="stats">
                    <div class="canvases">
                        <label id="loss-label"></label>
                        <div id="loss-canvas"></div>
                    </div>
                    <div class="canvases">
                        <label id="accuracy-label"></label>
                        <div id="accuracy-canvas"></div>
                    </div>
                    <br />
                    <div class="canvases">
                        <div id="loss-val-canvas"></div>
                    </div>
                    <div class="canvases">
                        <div id="accuracy-val-canvas"></div>
                    </div>
                </div>
            </section>
            <section>
                <p class="section-head">Inference Examples</p>
                <div id="images"></div>
            </section>
        </div>
        <!-- TODO(cais): Decide. DO NOT SUBMIT. -->
        <!-- <script src="https://cdn.plot.ly/plotly-latest.min.js"></script> -->
        <script type="module" src="index.js"></script>
    </body>
</html>

index.js

import * as tf from '@tensorflow/tfjs';
import {IMAGE_H, IMAGE_W, MnistData} from './data';
import * as ui from './ui';

/**
 * Creates a model consisting of only flatten, dense and dropout layers.
 *
 * The model create here has approximately the same number of parameters
 * (~31k) as the convnet created by `createConvModel()`, but is
 * expected to show a significantly worse accuracy after training, due to the
 * fact that it doesn't utilize the spatial information as the convnet does.
 *
 * This is for comparison with the convolutional network above.
 *
 * @returns {tf.Model} An instance of tf.Model.
 */
function createDenseModel() {
    const model = tf.sequential();
    model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 1]}));
    model.add(tf.layers.dense({units: 42, activation: 'relu'}));
    model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
    return model;
    // 該普通稠密層模型與卷積神經網絡模型相比,參數數量都在32000個左右,基本維持了一個公平的形式。但在最終的損失和準
    // 確率上,普通稠密層模型都比不過後者。雖然與卷積神經網絡的模型相比,準確率差異只有2%左右,但錯誤率卻是幾倍。因此
    // 在處理圖片類業務上,卷積神經網絡模型較稠密層模型有明顯的優勢。
}

/**
 * Creates a convolutional neural network (Convnet) for the MNIST data.
 *
 * @returns {tf.Model} An instance of tf.Model.
 */
function createConvModel() {
    // Create a sequential neural network model. tf.sequential provides an API
    // for creating "stacked" models where the output from one layer is used as
    // the input to the next layer.
    const model = tf.sequential();

    // The first layer of the convolutional neural network plays a dual role:
    // it is both the input layer of the neural network and a layer that performs
    // the first convolution operation on the input. It receives the 28x28 pixels
    // black and white images. This input layer uses 16 filters with a kernel size
    // of 5 pixels each. It uses a simple RELU activation function which pretty
    // much just looks like this: __/
    // 第1層。
    model.add(tf.layers.conv2d({
        inputShape: [IMAGE_H, IMAGE_W, 1],
        // 要應用於輸入數據的滑動卷積過濾器窗口的尺寸。在這裏我們將kernelSize設爲3,以指定方形的3*3卷積窗口。
        kernelSize: 3,
        // 要應用於輸入數據的尺寸爲kernelSize的過濾器窗口數量。在這裏我們將對數據應用16個過濾器。
        filters: 16,
        activation: 'relu'
    }));

    // After the first layer we include a MaxPooling layer. This acts as a sort of
    // downsampling using max values in a region instead of averaging.
    // https://www.quora.com/What-is-max-pooling-in-convolutional-neural-networks
    // 第2層。
    model.add(tf.layers.maxPooling2d({
        poolSize: 2,
        strides: 2
    }));

    // Our third layer is another convolution, this time with 32 filters.
    // 第3層。
    // 第3層和第4層這兩個層是前兩個層的完全重複(除了conv2d層在其過濾器配置中具有更大的值並且不具有inputShape字段)。
    // 這種由卷積層和極化層組成的幾乎重複的"模體"在convnets中是常見的。它在convnet中扮演着關鍵的角色:特徵的分層提取。
    // 爲了理解它的含義,可以考慮一個訓練過的convnet,它的任務是對圖像中的動物進行分類。在convnet的早期階段,卷積層中
    // 的濾波器(即channel)可以編碼諸如直線、曲線和角等低級幾何特徵。這些低級特徵轉化爲更復雜的特徵,如貓的眼睛、鼻子和
    // 耳朵。在convnet的頂層,一個層可能有對整個cat的存在進行編碼的過濾器。級別越高,表示越抽象,從像素級值中移除的特徵
    // 越多。但是,這些抽象的特徵正是convnet任務實現良好精度所需要的,例如,在圖像中時檢測出貓。此外,這些特徵不是手工制
    // 作的,而是通過有監督的學習並以自動方式從數據中提取的。這是一個典型的有代表性的例子,我們也把它描述爲層-層轉換。
    model.add(tf.layers.conv2d({
        kernelSize: 3,
        filters: 32,
        activation: 'relu'
    }));

    // Max pooling again.
    // 第4層。
    model.add(tf.layers.maxPooling2d({
        poolSize: 2,
        strides: 2
    }));

    // Add another convolution layer.
    // 第5層。
    model.add(tf.layers.conv2d({
        kernelSize: 3,
        filters: 32,
        activation: 'relu'
    }));

    // Now we flatten the output from the 2D filters into a 1D vector to prepare
    // it for input into our last layer. This is common practice when feeding
    // higher dimensional data to a final classification output layer.
    // 第6層。
    // 第6層爲扁平層。它將多維張量“壓縮”爲一維張量,從而保持元素的總數。在我們的例子中,形狀爲[3,3,32]的3D張量被展平
    // 爲1D張量[288](沒有批次維度)。擠壓操作的一個明顯問題是如何對元素排序,因爲原始三維空間中沒有內在的順序。答案是:
    // 我們對元素進行排序,這樣,如果你沿着展開的一維張量中的元素向下看,看看它們的原始索引(來自三維張量)如何變化,最後
    // 一個索引變化最快,倒數第二個索引變化第二快,以此類推,而第一個索引變化最慢。
    // 扁平層用來將輸入“壓平”,即把多維的輸入一維化,常用在從卷積層到全連接層的過渡。扁平層不影響batch的大小。
    // 扁平層中沒有權重,它只是將其輸入展開爲一個長數組。
    model.add(tf.layers.flatten({}));

    // Previous flattened 1D vector will be this layer's input.
    // 第7層。
    // 第7層和第8層我們添加了兩個稠密層,爲什麼要添加兩個而不是一個呢?原因是:添加具有非線性激活的層會增加網絡的容量。
    // 通常多層神經網絡內,必須含有級聯的線性函數和非線性函數,有助於表示能力的增強,這意味着模型容量的增大,預測精度的
    // 提高。在boston-housing項目案例中,我們也做出過相同的總結。https://mapleroyals.com/forum/threads/4th-job-skill-changes.135025
    // 實際上,您可以將convnet看作是由兩個模型堆疊在一起:
    // 1. 包含conv2d、maxPooling2d和flatten層的模型,用於從輸入圖像中提取視覺特徵。
    // 2. 一種多層感知器(MLP),具有兩個密集層,以提取的特徵作爲輸入,並基於它進行數字類預測,這就是這兩個密集層的本
    // 質。在深度學習中,許多模型都利用了特徵提取層的這種模式,然後最終預測使用MLP。在本書的其餘部分中,我們將看到更多
    // 這樣的例子,從音頻信號分類器到自然語言處理。
    model.add(tf.layers.dense({
        units: 64,
        activation: 'relu'
    }));

    // Our last layer is a dense layer which has 10 output units, one for each
    // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9). Here the classes actually
    // represent numbers, but it's the same idea if you had classes that
    // represented other entities like dogs and cats (two output classes: 0, 1).
    // We use the softmax function as the activation for the output layer as it
    // creates a probability distribution over our 10 classes so their output
    // values sum to 1.
    // 第8層。
    // 這一層將輸出10個值在0-1之間的數字元素,表示對0-9這10個數字的預測概率,它們的和爲1,最大的概率對應的數值爲預測的傾向值。
    model.add(tf.layers.dense({
        units: 10,
        activation: 'softmax'
    }));

    return model;
}

/**
 * This callback type is used by the `train` function for insertion into the
 * model.fit callback loop.
 *
 * @callback onIterationCallback
 * @param {string} eventType Selector for which type of event to fire on.
 * @param {number} batchOrEpochNumber The current epoch / batch number
 * @param {tf.Logs} logs Logs to append to
 */

/**
 * Compile and train the given model.
 *
 * @param {tf.Model} model The model to train.
 * @param {onIterationCallback} onIteration A callback to execute every 10 batches & epoch end.
 */
async function train(model, onIteration) {
    ui.logStatus('Training model...');

    // We compile our model by specifying an optimizer, a loss function, and a
    // list of metrics that we will use for model evaluation. Here we're using a
    // categorical crossentropy loss, the standard choice for a multi-class
    // classification problem like MNIST digits.
    // The categorical crossentropy loss is differentiable and hence makes
    // model training possible. But it is not amenable to easy interpretation
    // by a human. This is why we include a "metric", namely accuracy, which is
    // simply a measure of how many of the examples are classified correctly.
    // This metric is not differentiable and hence cannot be used as the loss
    // function of the model.
    model.compile({
        // Now that we've defined our model, we will define our optimizer. The
        // optimizer will be used to optimize our model's weight values during
        // training so that we can decrease our training loss and increase our
        // classification accuracy.

        // We are using rmsprop as our optimizer.
        // An optimizer is an iterative method for minimizing a loss function.
        // It tries to find the minimum of our loss function with respect to the
        // model's weight parameters.
        optimizer: 'rmsprop',
        // 損失函數categoricalCrossentropy分類交叉熵,適用於諸如MNIST之類的多分類問題。在iris分類項目案例中,我
        // 們也使用了相同的損失函數,一般情況下,當模型的輸出爲概率分佈時,就會使用此函數。分類交叉熵會生成一個數字,指
        // 示預測向量與真實標籤向量的相似程度。
        loss: 'categoricalCrossentropy',
        // 度量標準函數調用accuracy。假設預測是基於convnet輸出的10個元素中的最大元素進行的,則此函數度量的示例中的一
        // 部分已正確分類。回想一下交叉熵損失和精度度量之間的區別:交叉熵是可微的,因此使基於反向傳播的訓練成爲可能,而精
        // 度度量是不可微的,但卻更容易解釋,因此對於分類問題,這是正確預測在所有預測中所佔的百分比。
        metrics: ['accuracy'],
    });

    // Batch size is another important hyperparameter. It defines the number of
    // examples we group together, or batch, between updates to the model's weights
    // during training. A value that is too low will update weights using too few
    // examples and will not generalize well. Larger batch sizes require more memory
    // resources and aren't guaranteed to perform better.
    // 一般而言,使用較大的批次與較小的批次相比好處是,它對模型的權重產生了更一致且變化較小的漸變更新。但是批次大小越大,訓練
    // 期間就需要更多的內存。您還應該記住,在給定相同數量的訓練數據的情況下,較大的批次大小會導致每個時期的梯度更新數量較少。
    // 因此,如果您使用較大的批量,請確保相應地增加時期數,以免在訓練過程中無意中減少了權重更新的次數。
    const batchSize = 320;

    // Leave out the last 15% of the training data for validation, to monitor
    // overfitting during training.
    // 預留15%的訓練數據用於訓練過程中的驗證。
    const validationSplit = 0.15;

    // Get number of training epochs from the UI.
    const trainEpochs = ui.getTrainEpochs();

    // We'll keep a buffer of loss and accuracy values over time.
    let trainBatchCount = 0;

    const trainData = data.getTrainData();
    const testData = data.getTestData();

    const totalNumBatches = Math.ceil(trainData.xs.shape[0] * (1 - validationSplit) / batchSize) * trainEpochs;

    // During the long-running fit() call for model training, we include callbacks,
    // so that we can plot the loss and accuracy values in the page as the training
    // progresses.
    let valAcc;
    await model.fit(
        trainData.xs,  // 特徵輸入。
        trainData.labels,  // 標籤輸入。
        {
            batchSize,  // 每次梯度更新的樣本數。
            validationSplit,  // 末尾15%的數據用於驗證。
            epochs: trainEpochs, // 在訓練數據上的迭代次數。
            callbacks: {
                onBatchEnd: async (batch, logs) => {
                    trainBatchCount++;
                    ui.logStatus(
                        `Training... (${(trainBatchCount / totalNumBatches * 100).toFixed(1)}% complete). ` +
                        'To stop training, refresh or close page.'
                    );
                    // 繪製損失和準確率圖表。
                    ui.plotLoss(trainBatchCount, logs.loss);
                    ui.plotAccuracy(trainBatchCount, logs.acc);

                    // 每10個batch結束時更新測試集的預測結果。
                    if (onIteration && (batch % 10 === 0)) {
                        onIteration('onBatchEnd', batch, logs);
                    }

                    // 主動讓出線程,允許UI在訓練過程中更新。
                    await tf.nextFrame();
                },
                onEpochEnd: async (epoch, logs) => {
                    valAcc = logs.val_acc;
                    // 繪製損失和準確率圖表。
                    ui.plotValLoss(trainBatchCount, logs.val_loss);
                    ui.plotValAccuracy(trainBatchCount, logs.val_acc);

                    // 每個epoch結束時更新測試集的預測結果。
                    if (onIteration) {
                        onIteration('onEpochEnd', epoch, logs);
                    }

                    await tf.nextFrame();
                }
            }
        }
    );

    // 在fit執行完成後(訓練結束),對模型進行評估。
    const testResult = model.evaluate(testData.xs, testData.labels);
    const testAccPercent = testResult[1].dataSync()[0] * 100;
    const finalValAccPercent = valAcc * 100;
    // 更新最後的驗證準確率和評估(測試)準確率。
    ui.logStatus(
        `Final validation accuracy: ${finalValAccPercent.toFixed(1)}%; ` +
        `Final test accuracy: ${testAccPercent.toFixed(1)}%`
    );
}

/**
 * Show predictions on a number of test examples.
 *
 * @param {tf.Model} model The model to be used for making the predictions.
 */
async function showPredictions(model) {
    const testExamples = 100;
    const examples = data.getTestData(testExamples);

    // Code wrapped in a tf.tidy() function callback will have their tensors freed
    // from GPU memory after execution without having to call dispose().
    // The tf.tidy callback runs synchronously.
    tf.tidy(() => {
        // output爲形狀爲[100, 10]的二維張量,第一維對應100個特徵輸入(手寫數字圖片),第二維對應
        // 10個可能的數字的概率。
        const output = model.predict(examples.xs);

        // tf.argMax() returns the indices of the maximum values in the tensor along
        // a specific axis. Categorical classification tasks like this one often
        // represent classes as one-hot vectors. One-hot vectors are 1D vectors with
        // one element for each output class. All values in the vector are 0
        // except for one, which has a value of 1 (e.g. [0, 0, 0, 1, 0]). The
        // output from model.predict() will be a probability distribution, so we use
        // argMax to get the index of the vector element that has the highest
        // probability. This is our prediction. (e.g. argmax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3)
        // dataSync() synchronously downloads the tf.tensor values from the GPU so
        // that we can use them in our normal CPU JavaScript code
        // (for a non-blocking version of this function, use data()).
        // output中的axis 0的每一行,代表某個圖片可能的10個數字的概率,需要找出其中最大的,將其作爲
        // 模型的預測值。
        // argMax()函數返回沿給定軸的最大值的索引。在這種情況下,此軸是第二維,即const axis = 1。
        // argMax()的返回值是形狀爲[100,1]的張量。通過調用 dataSync(),我們將[100,1]形張量轉
        // 換爲長度爲100的Float32Array。然後Array.from()將Float32Array轉換爲一個普通的JavaScript
        // 數組,該數組由100個介於0和9之間的整數組成。此預測數組的含義非常簡單:這是模型對100個輸入圖
        // 像進行分類的結果。在MNIST數據集中,目標標籤恰好與輸出索引完全匹配。因此,我們甚至不需要將數
        // 組轉換爲字符串標籤。預測數組由下一行使用,該行調用一個UI函數,該函數將分類結果與測試圖像一起呈現。
        const axis = 1;
        const labels = Array.from(examples.labels.argMax(axis).dataSync());
        const predictions = Array.from(output.argMax(axis).dataSync());

        // 更新測試集的預測結果,對100個圖標註對它的預測數字和正確性(綠色)。
        ui.showTestResults(examples, predictions, labels);
    });
}

function createModel() {
    let model;
    const modelType = ui.getModelTypeId();
    if (modelType === 'ConvNet') {
        model = createConvModel();
    } else if (modelType === 'DenseNet') {
        model = createDenseModel();
    } else {
        throw new Error(`Invalid model type: ${modelType}`);
    }
    return model;
}

let data;

// When page ready, loads the MNIST data, trains the model, and then shows what
// the model predicted on unseen test data.
window.onload = async () => {
    ui.logStatus('Loading MNIST data...');
    data = new MnistData();
    await data.load();

    ui.logStatus('Creating model...');
    const model = createModel();

    // 輸出模型規格。
    model.summary();
    // __________________________________________________________________________________________
    // Layer (type)                Input Shape               Output shape              Param #
    // ==========================================================================================
    // conv2d_Conv2D1 (Conv2D)     [[null,28,28,1]]          [null,26,26,16]           160
    // __________________________________________________________________________________________
    // max_pooling2d_MaxPooling2D1 [[null,26,26,16]]         [null,13,13,16]           0
    // __________________________________________________________________________________________
    // conv2d_Conv2D2 (Conv2D)     [[null,13,13,16]]         [null,11,11,32]           4640
    // __________________________________________________________________________________________
    // max_pooling2d_MaxPooling2D2 [[null,11,11,32]]         [null,5,5,32]             0
    // __________________________________________________________________________________________
    // conv2d_Conv2D3 (Conv2D)     [[null,5,5,32]]           [null,3,3,32]             9248
    // __________________________________________________________________________________________
    // flatten_Flatten1 (Flatten)  [[null,3,3,32]]           [null,288]                0
    // __________________________________________________________________________________________
    // dense_Dense1 (Dense)        [[null,288]]              [null,64]                 18496
    // __________________________________________________________________________________________
    // dense_Dense2 (Dense)        [[null,64]]               [null,10]                 650
    // ==========================================================================================
    // Total params: 33194
    // Trainable params: 33194
    // Non-trainable params: 0
    // __________________________________________________________________________________________
    //
    // 從上述圖表信息可見該模型總參數爲33194個,其中池化層、扁平層無參數。
    // 在深度學習中,通常模型自身的參數和模型的輸出會佔用顯存。
    // 有參數的層主要包括:卷積層、全連接層、BatchNorm層、Embedding層等。
    // 無參數的層主要包括:多數的激活層(Sigmoid/ReLU)、池化層、Dropout層、扁平層等。

    ui.logStatus('Model ready.');

    // Enable the "Train Model" button.
    document.getElementById('train').removeAttribute('disabled');

    ui.setTrainButtonCallback(async () => {
        ui.logStatus('Starting model training...');
        document.getElementById('images').innerHTML = '';
        await train(model, () => showPredictions(model));
    });
};

 

data.js

import * as tf from '@tensorflow/tfjs';

// This is a helper class for loading and managing MNIST data specifically.
// It is a useful example of how you could create your own data manager class
// for arbitrary data though. It's worth a look :).

// MnistData類封裝了對MNIST灰度圖片訓練集和標籤集的數據加載和預處理過程,爲了使數據對下層算法和框架有友好性,數據的
// 格式和存儲都以二進制緩衝區和類型化數組爲主。這裏的數據處理相對於以JSON爲主的數據處理的普通web應用而言是稍顯複雜的。

// 安東尼·戈德布盧姆(Kaggle的CEO)曾經這樣說過:有人開玩笑說有80%的數據科學家在清理數據,剩下的20%在抱怨清理數據。
// 在實際的數據科學(大數據、機器學習)工作中,清理數據所佔比例比外人想象的要多得多。一般而言,訓練模型通常只佔機器學習
// 或數據科學家工作的一小部分(少於10%)。當下在機器學習的純應用領域上,以成熟框架和算法作爲基礎的開發上,更是如此。

const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

export const IMAGE_H = 28;
export const IMAGE_W = 28;
const IMAGE_SIZE = IMAGE_H * IMAGE_W;

const NUM_CLASSES = 10;  // 0-9。

const NUM_DATASET_ELEMENTS = 65000;  // MNIST精靈圖有65000個子圖,精靈圖的高度也爲65000像素。
const NUM_TRAIN_ELEMENTS = 55000;    // 我們用於訓練的子圖爲55000個。

// A class that fetches the sprited MNIST dataset and provide data as tf.Tensors.
export class MnistData {
    constructor() {

    }

    async load() {
        // Make a request for the MNIST sprited image.
        let img = new Image();

        // 這裏我們用到了canvas是因爲DOM的img對象並沒有提供給我們獲取其像素的API,DOM這個級別的API是無法操作像素的。
        // 只能通過canvas來做一箇中間層,將img內容分批次繪製到canvas上,再從canvas提取出像素數據,看下面for循環部分。
        const canvas = document.createElement('canvas');
        const ctx = canvas.getContext('2d', {
            willReadFrequently: true,
        });

        // 請求手寫數字灰度圖片,並將圖片轉化爲數據格式。
        const imgRequest = new Promise((resolve) => {
            // CORS配置。
            img.crossOrigin = '';

            // naturalWidth和naturalHeight指圖片的原始大小,在計算時可以強制校正圖片尺寸。
            img.onload = () => {
                img.width = img.naturalWidth;     // 圖片本身寬度784。
                img.height = img.naturalHeight;   // 圖片本身長度65000。

                // 初始化一個新的二進制緩衝區,包含整個MMINST精靈圖的每個像素(每一張子圖的每個像素)。它將圖像總數和每張圖像的尺寸和通道數量相乘。
                // 圖片爲PNG格式,所以有rgba4個通道,最後乘以4。
                const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); // 65000 * (28 * 28) * 4。

                const chunkSize = 5000;
                canvas.width = img.width;
                canvas.height = chunkSize;

                // 分13次處理完,每次處理MNIST精靈圖中的5000個子圖。
                // NUM_DATASET_ELEMENTS和img.height一致,都爲65000。每次取5000像素的高度,分13次就能取完。爲什麼它們兩個的一致?我們可以這
                // 麼看:因爲精靈圖寬度爲784像素,784整好是每個子圖的總像素,因爲單個子圖的寬度爲28*28,整好就是784,所以MNIST一橫排的像素總量正
                // 好就是一個子圖的像素總量,從計數上就可以看成一個子圖佔了MNIST的一排,MNIST有多少排就有多少張子圖,MNIST有65000像素高,也就是
                // 有65000排,也就是我們可以算出子圖有65000張,這個數量和實際的子圖數量是一致的。從獲取到的MNIST精靈圖來看,子圖並不是以28*28的
                // 長寬方位比依次存放在MNIST上的,而是以1*784的長寬方位比來存放的,正好就是和我們想的一致:一排就是一張子圖。將每張子圖按照28來切割
                // 並換行,就可以組成一個28*28的方形原始圖,這下我們就能看明白手寫的數字了,雖然這種排列方式利於眼睛的分辨,但顯然不適合像素的提取,
                // 一旦這樣排列,我們在提取子圖像素的時候,就變成了去處理矩陣數據(單個子圖的正方形就可以看做是一個28*28矩陣)。
                // 下面內容涉及到的兩個for循環,其中外層的循環遍歷chunkSize,內層的循環遍歷某個chunkSize下的像素。
                const chunkNum = NUM_DATASET_ELEMENTS / chunkSize; // => 65000 / 5000 = 13。

                for (let i = 0; i < chunkNum; i++) {
                    // drawImage允許我們從一個圖片源上的裁剪某個特定矩形區域並繪製到canvas的特定位置、大小的區域上。
                    // 這裏我們從image上裁剪同image寬度,高度爲chunkSize的區域繪製到canvas上。下一次循環就裁剪下
                    // 一個chunkSize高度的區域。
                    ctx.drawImage(
                        img, // 繪製到上下文的元素,允許任何畫布圖像源,例如:HTMLImageElement、SVGImageElement等。
                        // 以下參數都基於需要繪製到目標上下文的image的矩形(裁剪)選擇框。
                        0,               // 左上角X軸座標。
                        i * chunkSize,   // 左上角Y軸座標。
                        img.width,       // 寬度。
                        chunkSize,       // 高度。
                        // 以下參數都基於image的左上角在目標畫布上的繪製。
                        0,               // 左上角X軸座標。
                        0,               // 左上角Y軸座標。
                        img.width,       // 寬度。
                        chunkSize        // 高度。
                    );

                    // getImageData返回一個ImageData對象,包含canvas給定的矩形區域的像素數據,這裏的話每個循環裏的canvas上線文
                    // 就是15680000個像素(784 * 5000 * 4 = 15680000,4個通道總的需要乘以4)。ImageData結構示例:
                    // {
                    //        width: 100,
                    //        height: 100,
                    //     colorSpace: 'srgb',
                    //        data: Uint8ClampedArray[40000]   // 像素數據。
                    // }
                    const imageData = ctx.getImageData(
                        // 以下參數都基於將要被提取的圖像數據矩形區域。
                        0,               // 左上角x座標。
                        0,               // 左上角y座標。
                        canvas.width,    // 寬度。
                        canvas.height    // 高度。
                    );

                    // datasetBytesView長度爲3920000,即精靈圖中chunkSize個子圖的像素總數量((28 * 28) * 5000)。
                    // 採用Float32類型數組當做數據視圖來操作它,這個視圖即表示了datasetBytesBuffer的一段空間。
                    const datasetBytesView = new Float32Array(  // length => 3920000
                        // 數據。
                        datasetBytesBuffer,
                        // 起始偏移量。第一次循環的時候i爲0,也就是從0開始,即無偏移。第二次循環代的時候i爲1,即偏移從5000個子圖的四通道數
                        // 量以後開始。剩下的以此類推,每次循環都處理完當前chunkSize的數據,循環完13個chunkSize後,即整個MNIST精靈圖的
                        // 數據就在datasetBytesBuffer中了。也就是說每次對chunkSize的迭代處理中,其實都是針對datasetBytesBuffer進行
                        // 處理,往它裏面存放數據,也就是下面的內層for循環中,針對當前chunkSize下的每個像素,對datasetBytesBuffer從
                        // 當前的偏移值i * IMAGE_SIZE * chunkSize * 4開始存儲,共存儲IMAGE_SIZE * chunkSize個數據。
                        i * IMAGE_SIZE * chunkSize * 4,  // => i * (28 * 28) * 5000 * 4。
                        // 元素數量。
                        IMAGE_SIZE * chunkSize           // => (28 * 28) * 5000 = 392000
                    );

                    // 將該區域內總的像素數量除以4得到單通道像素數量。因爲原數據包含了4個通道的數據,除以4才表示單個通道的數據長度,也就是該矩
                    // 形區域內所有像素的按單通道計算的數量。等同於MAGE_SIZE * chunkSize,也就是本次for循環的datasetBytesView的元素長度。
                    const len = imageData.data.length / 4;  //  => 15680000 / 4 = 3920000。

                    // 遍歷當前chunkSize高度區域內的每一個像素(按單通道計算)。通過datasetBytesView視圖將數據放入datasetBytesBuffer
                    // 中的對應位置,來改變datasetBytesBuffer的值。
                    for (let j = 0; j < len; j++) {
                        // All channels hold an equal value since the image is grayscale, so just read the red channel.
                        // 因爲圖片爲灰度化的,所有rgba通道都有相同的值,因此只需要讀取r(紅色)通道的值即可(間隔4個取1次,就能保證每次取的
                        // 都是r通道的值因此乘以了4)。最後將取出的r通道值除以255,得到0-1之間的數。賦值給datasetBytesView視圖對應的下標,
                        // datasetBytesView和datasetBytesBuffer是連通的,視圖被改,datasetBytesBuffer就對應改變了,整個外層for循
                        // 環結束後,datasetBytesBuffer(this.datasetImages)就是整個MINST精靈圖的所有像素的總數據(灰度數據)。
                        // 這裏除以255是爲了避免訓練數據與預測數據之間出現不匹配的情況,因爲我們的MNIST卷積網絡應使用歸一化到0-1之間的圖像張
                        // 量數據進行訓練。
                        datasetBytesView[j] = imageData.data[j * 4] / 255;
                    }
                    // 當上面這個for循環完成後,我們就得到了該chunkSize高度區域的所有的r通道的像素轉成灰度的值。然後再接着遍歷下一個chunkSize
                    // 高度區域,並獲得它的所有r通道的像素數據。
                }

                // chunkSize高度區域都全部遍歷完以後,datasetBytesBuffer裏就已經有了整個MNIST精靈圖的r通道的灰度像素信息。最終datasetImages
                // 的長度爲50960000,即整個MNIST精靈圖的像素總數量(65000 * (28 * 28))。
                // 採用Float32類型數組當做數據視圖來操作它。
                this.datasetImages = new Float32Array(datasetBytesBuffer);  // length => 50960000

                // 圖片DOM用完就回收。
                img = null;

                resolve();
            };

            // 賦值src觸發圖片onload加載回調。
            img.src = MNIST_IMAGES_SPRITE_PATH;
        });

        // 請求標籤數據。
        const labelsRequest = fetch(MNIST_LABELS_PATH);

        // 等待圖片、標籤數據都加載並處理完成。
        const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);

        // 將標籤數據(uint8)存在一個二進制緩衝區中(字節數組),並採用Uint8類型數組當做數據視圖來操作它。
        // 標籤數據爲針對每個數字的獨熱編碼集合。以數字6舉例:
        // 樣本  0     1     2     3     4     5     6     7     8     9
        // 標籤  0     0     0     0     0     0     1     0     0     0       // 只有一個是1,即數字6獨熱。
        // 預測  0.10  0.01  0.01  0.01  0.09  0.01  0.71  0.01  0.03  0.02
        // datasetLabels爲一個1維數組,數據長度爲65000 * 10,因爲每組獨熱編碼長度爲10,分別對應10個數字的編碼,只有一個爲1,其他爲0。
        // datasetLabels每10個下標的元素即爲一組針對一個數字的真實值的獨熱編碼。所以下面在劃分的時候乘以了NUM_CLASSES(10)。
        this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

        // Slice the images and labels into train and test sets.
        // 將圖片數據和標籤數據劃分爲訓練集和測試集。
        this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);  // 0到(28 * 28) * 55000。
        this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);      // 上面剩下的。

        this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); // 0到10 * 55000。
        this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);     // 上面剩下的。
    }

    /**
     * Get all training data as a data tensor and a label tensor.
     *
     * @returns
     *   xs: The data tensor, of shape `[numTrainExamples, 28, 28, 1]`
     *   labels: The one-hot encoded labels tensor, of shape '[numTrainExamples, 10]'.
     */
    getTrainData() {
        // 這包括表示爲NHWC形狀[N,28、28、1]的4維張量(批次示例的第一維)的輸入MNIST圖像,其中N是圖像總數。
        const xs = tf.tensor4d(
            this.trainImages,
            // NHWC:[55000, 28, 28, 1]。
            [this.trainImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]
        );

        // 這包括輸入標籤,表示爲形狀爲[N,10]的獨熱編碼2維張量。
        const labels = tf.tensor2d(
            this.trainLabels,
            // NHWC:[55000, 10]。
            [this.trainLabels.length / NUM_CLASSES, NUM_CLASSES]
        );

        // 返回的訓練數據就是標準的特徵集和標籤集配對。
        return {xs, labels};
    }

    /**
     * Get all test data as a data tensor and a labels tensor.
     *
     * @param {number} numExamples Optional number of examples to get. If not provided,
     *                             all test examples will be returned.
     * @returns xs: The data tensor, of shape `[numTestExamples, 28, 28, 1]`.
     *              labels: The one-hot encoded labels tensor, of shape `[numTestExamples, 10]`.
     */
    getTestData(numExamples) {
        let xs = tf.tensor4d(
            this.testImages,
            [this.testImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]
        );

        let labels = tf.tensor2d(
            this.testLabels,
            [this.testLabels.length / NUM_CLASSES, NUM_CLASSES]
        );

        if (numExamples != null) {
            // 修正形狀爲[100, 28, 28, 1],即從測試集中取出100張圖片的特徵和標籤。
            xs = xs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]);
            labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]);
        }

        // 返回的測試數據就是標準的特徵集和標籤集配對。這與訓練集是相似的,只是它不包含在訓練集以內,模型未曾接觸到。
        return {xs, labels};
    }
}

ui.js

import * as tfvis from '@tensorflow/tfjs-vis';

// This is a helper class for drawing loss graphs and MNIST images to the
// window. For the purposes of understanding the machine learning bits, you can
// largely ignore it
const statusElement = document.getElementById('status');
const messageElement = document.getElementById('message');
const imagesElement = document.getElementById('images');

const trainButton = document.getElementById('train');
const modelType = document.getElementById('model-type');
const epochInput = document.getElementById('train-epochs');

const lossLabelElement = document.getElementById('loss-label');
const lossContainer = document.getElementById('loss-canvas');
const lossValContainer = document.getElementById('loss-val-canvas');
const accuracyLabelElement = document.getElementById('accuracy-label');
const accuracyContainer = document.getElementById('accuracy-canvas');
const accuracyValContainer = document.getElementById('accuracy-val-canvas');

export function logStatus(message) {
    statusElement.innerText = message;
}

export function trainingLog(message) {
    messageElement.innerText = `${message}\n`;
}

export function showTestResults(batch, predictions, labels) {
    const testExamples = batch.xs.shape[0];
    imagesElement.innerHTML = '';

    for (let i = 0; i < testExamples; i++) {
        const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]);

        const div = document.createElement('div');
        div.className = 'pred-container';

        // 預測的值。
        const pred = document.createElement('div');

        // 判斷預測是否正確。
        const prediction = predictions[i];
        const label = labels[i];
        const correct = prediction === label;

        // 預測正確顯示綠色,錯誤顯示紅色。
        pred.className = `pred ${(correct ? 'pred-correct' : 'pred-incorrect')}`;
        // 預測的值。
        pred.innerText = `pred: ${prediction}`;

        // 創建一個canvas來展示手寫數字灰度圖片。
        const canvas = document.createElement('canvas');
        canvas.className = 'prediction-canvas';
        draw(image.flatten(), canvas);

        div.appendChild(pred);
        div.appendChild(canvas);

        imagesElement.appendChild(div);
    }
}

const lossValues = [];
const lossValValues = [];
const accuracyValues = [];
const accuracyValValues = [];

export function plotLoss(batch, loss) {
    lossValues.push({
        x: batch,
        y: loss
    });

    tfvis.render.linechart(
        lossContainer,
        {
            values: lossValues,
            series: ['train']
        },
        {
            xLabel: 'Batch Number',
            yLabel: 'Loss',
            width: 400,
            height: 300,
        }
    );

    lossLabelElement.innerText = `last loss: ${loss.toFixed(3)}`;
}

export function plotValLoss(batch, loss) {
    lossValValues.push({
        x: batch,
        y: loss
    });

    tfvis.render.linechart(
        lossValContainer,
        {
            values: lossValValues,
            series: ['validation']
        },
        {
            xLabel: 'Batch Number',
            yLabel: 'Loss',
            width: 400,
            height: 300,
            seriesColors: ['#f16528']
        }
    );
}

export function plotAccuracy(batch, accuracy) {
    accuracyValues.push({
        x: batch,
        y: accuracy
    });

    tfvis.render.linechart(
        accuracyContainer,
        {
            values: accuracyValues,
            series: ['train']
        },
        {
            xLabel: 'Batch Number',
            yLabel: 'Accuracy',
            width: 400,
            height: 300,
        }
    );

    accuracyLabelElement.innerText = `last accuracy: ${(accuracy * 100).toFixed(1)}%`;
}

export function plotValAccuracy(batch, accuracy) {
    accuracyValValues.push({
        x: batch,
        y: accuracy
    });

    tfvis.render.linechart(
        accuracyValContainer,
        {
            values: accuracyValValues,
            series: ['validation']
        },
        {
            xLabel: 'Batch Number',
            yLabel: 'Accuracy',
            width: 400,
            height: 300,
            seriesColors: ['#f16528']
        }
    );
}

export function draw(image, canvas) {
    const [width, height] = [28, 28];
    canvas.width = width;
    canvas.height = height;

    const ctx = canvas.getContext('2d');
    const imageData = new ImageData(width, height);
    const data = image.dataSync();

    for (let i = 0; i < height * width; ++i) {
        const j = i * 4;
        imageData.data[j + 0] = data[i] * 255;
        imageData.data[j + 1] = data[i] * 255;
        imageData.data[j + 2] = data[i] * 255;
        imageData.data[j + 3] = 255;
    }

    ctx.putImageData(imageData, 0, 0);
}

export function getModelTypeId() {
    return document.getElementById('model-type').value;
}

export function getTrainEpochs() {
    return Number.parseInt(document.getElementById('train-epochs').value);
}

export function setTrainButtonCallback(train) {
    trainButton.addEventListener('click', async () => {
        // Disable button during the training.
        trainButton.setAttribute('disabled', true);
        modelType.setAttribute('disabled', true);
        epochInput.setAttribute('disabled', true);

        // Start training.
        await train();

        // Release button and reset chart data for next training without refreshing page.
        trainButton.removeAttribute('disabled');
        modelType.removeAttribute('disabled');
        epochInput.removeAttribute('disabled');
        // Rest data array.
        lossValues.splice(0, lossValues.length);
        lossValValues.splice(0, lossValValues.length);
        accuracyValues.splice(0, accuracyValues.length);
        accuracyValValues.splice(0, accuracyValValues.length);
    });
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章