使用層來進行數字識別,使用tf.layers api訓練模型識別MNIST數據庫中的手寫數字。
index.html
<html> <head> <title>MNIST</title> <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <link rel="stylesheet" href="../shared/tfjs-examples.css"/> <style> #train { margin-top: 10px; } label { display: inline-block; width: 250px; padding: 6px 0 6px 0; } .canvases { display: inline-block; } .prediction-canvas { width: 100px; } .pred { font-size: 20px; line-height: 25px; width: 100px; } .pred-correct { background-color: #00cf00; } .pred-incorrect { background-color: red; } .pred-container { display: inline-block; width: 100px; margin: 10px; } #train-epochs { width: 82px; font-size: 14px; } </style> </head> <body> <div class="tfjs-example-container"> <section class="title-area"> <h1>Digit Recognizer with Layers</h1> <p class="subtitle">Train a model to recognize handwritten digits from the MNIST database using the tf.layers api. </p> </section> <section> <p class="section-head">Description</p> <p> This examples lets you train a handwritten digit recognizer using either a Convolutional Neural Network (also known as a ConvNet or CNN) or a Fully Connected Neural Network (also known as a DenseNet). </p> <p>The MNIST dataset is used as training data.</p> </section> <section> <p class="section-head">Training Parameters</p> <div> <label>Model Type:</label> <select id="model-type"> <option>ConvNet</option> <option>DenseNet</option> </select> </div> <div> <label>Number of training epochs:</label> <input id="train-epochs" type="number" value="3" /> </div> <button id="train" disabled>Train Model</button> </section> <section> <p class="section-head">Training Progress</p> <p id="status"></p> <p id="message"></p> <div id="stats"> <div class="canvases"> <label id="loss-label"></label> <div id="loss-canvas"></div> </div> <div class="canvases"> <label id="accuracy-label"></label> <div id="accuracy-canvas"></div> </div> <br /> <div class="canvases"> <div id="loss-val-canvas"></div> </div> <div class="canvases"> <div id="accuracy-val-canvas"></div> </div> </div> </section> <section> <p class="section-head">Inference Examples</p> <div id="images"></div> </section> </div> <!-- TODO(cais): Decide. DO NOT SUBMIT. --> <!-- <script src="https://cdn.plot.ly/plotly-latest.min.js"></script> --> <script type="module" src="index.js"></script> </body> </html>
index.js
import * as tf from '@tensorflow/tfjs'; import {IMAGE_H, IMAGE_W, MnistData} from './data'; import * as ui from './ui'; /** * Creates a model consisting of only flatten, dense and dropout layers. * * The model create here has approximately the same number of parameters * (~31k) as the convnet created by `createConvModel()`, but is * expected to show a significantly worse accuracy after training, due to the * fact that it doesn't utilize the spatial information as the convnet does. * * This is for comparison with the convolutional network above. * * @returns {tf.Model} An instance of tf.Model. */ function createDenseModel() { const model = tf.sequential(); model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 1]})); model.add(tf.layers.dense({units: 42, activation: 'relu'})); model.add(tf.layers.dense({units: 10, activation: 'softmax'})); return model; // 該普通稠密層模型與卷積神經網絡模型相比,參數數量都在32000個左右,基本維持了一個公平的形式。但在最終的損失和準 // 確率上,普通稠密層模型都比不過後者。雖然與卷積神經網絡的模型相比,準確率差異只有2%左右,但錯誤率卻是幾倍。因此 // 在處理圖片類業務上,卷積神經網絡模型較稠密層模型有明顯的優勢。 } /** * Creates a convolutional neural network (Convnet) for the MNIST data. * * @returns {tf.Model} An instance of tf.Model. */ function createConvModel() { // Create a sequential neural network model. tf.sequential provides an API // for creating "stacked" models where the output from one layer is used as // the input to the next layer. const model = tf.sequential(); // The first layer of the convolutional neural network plays a dual role: // it is both the input layer of the neural network and a layer that performs // the first convolution operation on the input. It receives the 28x28 pixels // black and white images. This input layer uses 16 filters with a kernel size // of 5 pixels each. It uses a simple RELU activation function which pretty // much just looks like this: __/ // 第1層。 model.add(tf.layers.conv2d({ inputShape: [IMAGE_H, IMAGE_W, 1], // 要應用於輸入數據的滑動卷積過濾器窗口的尺寸。在這裏我們將kernelSize設爲3,以指定方形的3*3卷積窗口。 kernelSize: 3, // 要應用於輸入數據的尺寸爲kernelSize的過濾器窗口數量。在這裏我們將對數據應用16個過濾器。 filters: 16, activation: 'relu' })); // After the first layer we include a MaxPooling layer. This acts as a sort of // downsampling using max values in a region instead of averaging. // https://www.quora.com/What-is-max-pooling-in-convolutional-neural-networks // 第2層。 model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })); // Our third layer is another convolution, this time with 32 filters. // 第3層。 // 第3層和第4層這兩個層是前兩個層的完全重複(除了conv2d層在其過濾器配置中具有更大的值並且不具有inputShape字段)。 // 這種由卷積層和極化層組成的幾乎重複的"模體"在convnets中是常見的。它在convnet中扮演着關鍵的角色:特徵的分層提取。 // 爲了理解它的含義,可以考慮一個訓練過的convnet,它的任務是對圖像中的動物進行分類。在convnet的早期階段,卷積層中 // 的濾波器(即channel)可以編碼諸如直線、曲線和角等低級幾何特徵。這些低級特徵轉化爲更復雜的特徵,如貓的眼睛、鼻子和 // 耳朵。在convnet的頂層,一個層可能有對整個cat的存在進行編碼的過濾器。級別越高,表示越抽象,從像素級值中移除的特徵 // 越多。但是,這些抽象的特徵正是convnet任務實現良好精度所需要的,例如,在圖像中時檢測出貓。此外,這些特徵不是手工制 // 作的,而是通過有監督的學習並以自動方式從數據中提取的。這是一個典型的有代表性的例子,我們也把它描述爲層-層轉換。 model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' })); // Max pooling again. // 第4層。 model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })); // Add another convolution layer. // 第5層。 model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' })); // Now we flatten the output from the 2D filters into a 1D vector to prepare // it for input into our last layer. This is common practice when feeding // higher dimensional data to a final classification output layer. // 第6層。 // 第6層爲扁平層。它將多維張量“壓縮”爲一維張量,從而保持元素的總數。在我們的例子中,形狀爲[3,3,32]的3D張量被展平 // 爲1D張量[288](沒有批次維度)。擠壓操作的一個明顯問題是如何對元素排序,因爲原始三維空間中沒有內在的順序。答案是: // 我們對元素進行排序,這樣,如果你沿着展開的一維張量中的元素向下看,看看它們的原始索引(來自三維張量)如何變化,最後 // 一個索引變化最快,倒數第二個索引變化第二快,以此類推,而第一個索引變化最慢。 // 扁平層用來將輸入“壓平”,即把多維的輸入一維化,常用在從卷積層到全連接層的過渡。扁平層不影響batch的大小。 // 扁平層中沒有權重,它只是將其輸入展開爲一個長數組。 model.add(tf.layers.flatten({})); // Previous flattened 1D vector will be this layer's input. // 第7層。 // 第7層和第8層我們添加了兩個稠密層,爲什麼要添加兩個而不是一個呢?原因是:添加具有非線性激活的層會增加網絡的容量。 // 通常多層神經網絡內,必須含有級聯的線性函數和非線性函數,有助於表示能力的增強,這意味着模型容量的增大,預測精度的 // 提高。在boston-housing項目案例中,我們也做出過相同的總結。https://mapleroyals.com/forum/threads/4th-job-skill-changes.135025 // 實際上,您可以將convnet看作是由兩個模型堆疊在一起: // 1. 包含conv2d、maxPooling2d和flatten層的模型,用於從輸入圖像中提取視覺特徵。 // 2. 一種多層感知器(MLP),具有兩個密集層,以提取的特徵作爲輸入,並基於它進行數字類預測,這就是這兩個密集層的本 // 質。在深度學習中,許多模型都利用了特徵提取層的這種模式,然後最終預測使用MLP。在本書的其餘部分中,我們將看到更多 // 這樣的例子,從音頻信號分類器到自然語言處理。 model.add(tf.layers.dense({ units: 64, activation: 'relu' })); // Our last layer is a dense layer which has 10 output units, one for each // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9). Here the classes actually // represent numbers, but it's the same idea if you had classes that // represented other entities like dogs and cats (two output classes: 0, 1). // We use the softmax function as the activation for the output layer as it // creates a probability distribution over our 10 classes so their output // values sum to 1. // 第8層。 // 這一層將輸出10個值在0-1之間的數字元素,表示對0-9這10個數字的預測概率,它們的和爲1,最大的概率對應的數值爲預測的傾向值。 model.add(tf.layers.dense({ units: 10, activation: 'softmax' })); return model; } /** * This callback type is used by the `train` function for insertion into the * model.fit callback loop. * * @callback onIterationCallback * @param {string} eventType Selector for which type of event to fire on. * @param {number} batchOrEpochNumber The current epoch / batch number * @param {tf.Logs} logs Logs to append to */ /** * Compile and train the given model. * * @param {tf.Model} model The model to train. * @param {onIterationCallback} onIteration A callback to execute every 10 batches & epoch end. */ async function train(model, onIteration) { ui.logStatus('Training model...'); // We compile our model by specifying an optimizer, a loss function, and a // list of metrics that we will use for model evaluation. Here we're using a // categorical crossentropy loss, the standard choice for a multi-class // classification problem like MNIST digits. // The categorical crossentropy loss is differentiable and hence makes // model training possible. But it is not amenable to easy interpretation // by a human. This is why we include a "metric", namely accuracy, which is // simply a measure of how many of the examples are classified correctly. // This metric is not differentiable and hence cannot be used as the loss // function of the model. model.compile({ // Now that we've defined our model, we will define our optimizer. The // optimizer will be used to optimize our model's weight values during // training so that we can decrease our training loss and increase our // classification accuracy. // We are using rmsprop as our optimizer. // An optimizer is an iterative method for minimizing a loss function. // It tries to find the minimum of our loss function with respect to the // model's weight parameters. optimizer: 'rmsprop', // 損失函數categoricalCrossentropy分類交叉熵,適用於諸如MNIST之類的多分類問題。在iris分類項目案例中,我 // 們也使用了相同的損失函數,一般情況下,當模型的輸出爲概率分佈時,就會使用此函數。分類交叉熵會生成一個數字,指 // 示預測向量與真實標籤向量的相似程度。 loss: 'categoricalCrossentropy', // 度量標準函數調用accuracy。假設預測是基於convnet輸出的10個元素中的最大元素進行的,則此函數度量的示例中的一 // 部分已正確分類。回想一下交叉熵損失和精度度量之間的區別:交叉熵是可微的,因此使基於反向傳播的訓練成爲可能,而精 // 度度量是不可微的,但卻更容易解釋,因此對於分類問題,這是正確預測在所有預測中所佔的百分比。 metrics: ['accuracy'], }); // Batch size is another important hyperparameter. It defines the number of // examples we group together, or batch, between updates to the model's weights // during training. A value that is too low will update weights using too few // examples and will not generalize well. Larger batch sizes require more memory // resources and aren't guaranteed to perform better. // 一般而言,使用較大的批次與較小的批次相比好處是,它對模型的權重產生了更一致且變化較小的漸變更新。但是批次大小越大,訓練 // 期間就需要更多的內存。您還應該記住,在給定相同數量的訓練數據的情況下,較大的批次大小會導致每個時期的梯度更新數量較少。 // 因此,如果您使用較大的批量,請確保相應地增加時期數,以免在訓練過程中無意中減少了權重更新的次數。 const batchSize = 320; // Leave out the last 15% of the training data for validation, to monitor // overfitting during training. // 預留15%的訓練數據用於訓練過程中的驗證。 const validationSplit = 0.15; // Get number of training epochs from the UI. const trainEpochs = ui.getTrainEpochs(); // We'll keep a buffer of loss and accuracy values over time. let trainBatchCount = 0; const trainData = data.getTrainData(); const testData = data.getTestData(); const totalNumBatches = Math.ceil(trainData.xs.shape[0] * (1 - validationSplit) / batchSize) * trainEpochs; // During the long-running fit() call for model training, we include callbacks, // so that we can plot the loss and accuracy values in the page as the training // progresses. let valAcc; await model.fit( trainData.xs, // 特徵輸入。 trainData.labels, // 標籤輸入。 { batchSize, // 每次梯度更新的樣本數。 validationSplit, // 末尾15%的數據用於驗證。 epochs: trainEpochs, // 在訓練數據上的迭代次數。 callbacks: { onBatchEnd: async (batch, logs) => { trainBatchCount++; ui.logStatus( `Training... (${(trainBatchCount / totalNumBatches * 100).toFixed(1)}% complete). ` + 'To stop training, refresh or close page.' ); // 繪製損失和準確率圖表。 ui.plotLoss(trainBatchCount, logs.loss); ui.plotAccuracy(trainBatchCount, logs.acc); // 每10個batch結束時更新測試集的預測結果。 if (onIteration && (batch % 10 === 0)) { onIteration('onBatchEnd', batch, logs); } // 主動讓出線程,允許UI在訓練過程中更新。 await tf.nextFrame(); }, onEpochEnd: async (epoch, logs) => { valAcc = logs.val_acc; // 繪製損失和準確率圖表。 ui.plotValLoss(trainBatchCount, logs.val_loss); ui.plotValAccuracy(trainBatchCount, logs.val_acc); // 每個epoch結束時更新測試集的預測結果。 if (onIteration) { onIteration('onEpochEnd', epoch, logs); } await tf.nextFrame(); } } } ); // 在fit執行完成後(訓練結束),對模型進行評估。 const testResult = model.evaluate(testData.xs, testData.labels); const testAccPercent = testResult[1].dataSync()[0] * 100; const finalValAccPercent = valAcc * 100; // 更新最後的驗證準確率和評估(測試)準確率。 ui.logStatus( `Final validation accuracy: ${finalValAccPercent.toFixed(1)}%; ` + `Final test accuracy: ${testAccPercent.toFixed(1)}%` ); } /** * Show predictions on a number of test examples. * * @param {tf.Model} model The model to be used for making the predictions. */ async function showPredictions(model) { const testExamples = 100; const examples = data.getTestData(testExamples); // Code wrapped in a tf.tidy() function callback will have their tensors freed // from GPU memory after execution without having to call dispose(). // The tf.tidy callback runs synchronously. tf.tidy(() => { // output爲形狀爲[100, 10]的二維張量,第一維對應100個特徵輸入(手寫數字圖片),第二維對應 // 10個可能的數字的概率。 const output = model.predict(examples.xs); // tf.argMax() returns the indices of the maximum values in the tensor along // a specific axis. Categorical classification tasks like this one often // represent classes as one-hot vectors. One-hot vectors are 1D vectors with // one element for each output class. All values in the vector are 0 // except for one, which has a value of 1 (e.g. [0, 0, 0, 1, 0]). The // output from model.predict() will be a probability distribution, so we use // argMax to get the index of the vector element that has the highest // probability. This is our prediction. (e.g. argmax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3) // dataSync() synchronously downloads the tf.tensor values from the GPU so // that we can use them in our normal CPU JavaScript code // (for a non-blocking version of this function, use data()). // output中的axis 0的每一行,代表某個圖片可能的10個數字的概率,需要找出其中最大的,將其作爲 // 模型的預測值。 // argMax()函數返回沿給定軸的最大值的索引。在這種情況下,此軸是第二維,即const axis = 1。 // argMax()的返回值是形狀爲[100,1]的張量。通過調用 dataSync(),我們將[100,1]形張量轉 // 換爲長度爲100的Float32Array。然後Array.from()將Float32Array轉換爲一個普通的JavaScript // 數組,該數組由100個介於0和9之間的整數組成。此預測數組的含義非常簡單:這是模型對100個輸入圖 // 像進行分類的結果。在MNIST數據集中,目標標籤恰好與輸出索引完全匹配。因此,我們甚至不需要將數 // 組轉換爲字符串標籤。預測數組由下一行使用,該行調用一個UI函數,該函數將分類結果與測試圖像一起呈現。 const axis = 1; const labels = Array.from(examples.labels.argMax(axis).dataSync()); const predictions = Array.from(output.argMax(axis).dataSync()); // 更新測試集的預測結果,對100個圖標註對它的預測數字和正確性(綠色)。 ui.showTestResults(examples, predictions, labels); }); } function createModel() { let model; const modelType = ui.getModelTypeId(); if (modelType === 'ConvNet') { model = createConvModel(); } else if (modelType === 'DenseNet') { model = createDenseModel(); } else { throw new Error(`Invalid model type: ${modelType}`); } return model; } let data; // When page ready, loads the MNIST data, trains the model, and then shows what // the model predicted on unseen test data. window.onload = async () => { ui.logStatus('Loading MNIST data...'); data = new MnistData(); await data.load(); ui.logStatus('Creating model...'); const model = createModel(); // 輸出模型規格。 model.summary(); // __________________________________________________________________________________________ // Layer (type) Input Shape Output shape Param # // ========================================================================================== // conv2d_Conv2D1 (Conv2D) [[null,28,28,1]] [null,26,26,16] 160 // __________________________________________________________________________________________ // max_pooling2d_MaxPooling2D1 [[null,26,26,16]] [null,13,13,16] 0 // __________________________________________________________________________________________ // conv2d_Conv2D2 (Conv2D) [[null,13,13,16]] [null,11,11,32] 4640 // __________________________________________________________________________________________ // max_pooling2d_MaxPooling2D2 [[null,11,11,32]] [null,5,5,32] 0 // __________________________________________________________________________________________ // conv2d_Conv2D3 (Conv2D) [[null,5,5,32]] [null,3,3,32] 9248 // __________________________________________________________________________________________ // flatten_Flatten1 (Flatten) [[null,3,3,32]] [null,288] 0 // __________________________________________________________________________________________ // dense_Dense1 (Dense) [[null,288]] [null,64] 18496 // __________________________________________________________________________________________ // dense_Dense2 (Dense) [[null,64]] [null,10] 650 // ========================================================================================== // Total params: 33194 // Trainable params: 33194 // Non-trainable params: 0 // __________________________________________________________________________________________ // // 從上述圖表信息可見該模型總參數爲33194個,其中池化層、扁平層無參數。 // 在深度學習中,通常模型自身的參數和模型的輸出會佔用顯存。 // 有參數的層主要包括:卷積層、全連接層、BatchNorm層、Embedding層等。 // 無參數的層主要包括:多數的激活層(Sigmoid/ReLU)、池化層、Dropout層、扁平層等。 ui.logStatus('Model ready.'); // Enable the "Train Model" button. document.getElementById('train').removeAttribute('disabled'); ui.setTrainButtonCallback(async () => { ui.logStatus('Starting model training...'); document.getElementById('images').innerHTML = ''; await train(model, () => showPredictions(model)); }); };
data.js
import * as tf from '@tensorflow/tfjs'; // This is a helper class for loading and managing MNIST data specifically. // It is a useful example of how you could create your own data manager class // for arbitrary data though. It's worth a look :). // MnistData類封裝了對MNIST灰度圖片訓練集和標籤集的數據加載和預處理過程,爲了使數據對下層算法和框架有友好性,數據的 // 格式和存儲都以二進制緩衝區和類型化數組爲主。這裏的數據處理相對於以JSON爲主的數據處理的普通web應用而言是稍顯複雜的。 // 安東尼·戈德布盧姆(Kaggle的CEO)曾經這樣說過:有人開玩笑說有80%的數據科學家在清理數據,剩下的20%在抱怨清理數據。 // 在實際的數據科學(大數據、機器學習)工作中,清理數據所佔比例比外人想象的要多得多。一般而言,訓練模型通常只佔機器學習 // 或數據科學家工作的一小部分(少於10%)。當下在機器學習的純應用領域上,以成熟框架和算法作爲基礎的開發上,更是如此。 const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png'; const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8'; export const IMAGE_H = 28; export const IMAGE_W = 28; const IMAGE_SIZE = IMAGE_H * IMAGE_W; const NUM_CLASSES = 10; // 0-9。 const NUM_DATASET_ELEMENTS = 65000; // MNIST精靈圖有65000個子圖,精靈圖的高度也爲65000像素。 const NUM_TRAIN_ELEMENTS = 55000; // 我們用於訓練的子圖爲55000個。 // A class that fetches the sprited MNIST dataset and provide data as tf.Tensors. export class MnistData { constructor() { } async load() { // Make a request for the MNIST sprited image. let img = new Image(); // 這裏我們用到了canvas是因爲DOM的img對象並沒有提供給我們獲取其像素的API,DOM這個級別的API是無法操作像素的。 // 只能通過canvas來做一箇中間層,將img內容分批次繪製到canvas上,再從canvas提取出像素數據,看下面for循環部分。 const canvas = document.createElement('canvas'); const ctx = canvas.getContext('2d', { willReadFrequently: true, }); // 請求手寫數字灰度圖片,並將圖片轉化爲數據格式。 const imgRequest = new Promise((resolve) => { // CORS配置。 img.crossOrigin = ''; // naturalWidth和naturalHeight指圖片的原始大小,在計算時可以強制校正圖片尺寸。 img.onload = () => { img.width = img.naturalWidth; // 圖片本身寬度784。 img.height = img.naturalHeight; // 圖片本身長度65000。 // 初始化一個新的二進制緩衝區,包含整個MMINST精靈圖的每個像素(每一張子圖的每個像素)。它將圖像總數和每張圖像的尺寸和通道數量相乘。 // 圖片爲PNG格式,所以有rgba4個通道,最後乘以4。 const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); // 65000 * (28 * 28) * 4。 const chunkSize = 5000; canvas.width = img.width; canvas.height = chunkSize; // 分13次處理完,每次處理MNIST精靈圖中的5000個子圖。 // NUM_DATASET_ELEMENTS和img.height一致,都爲65000。每次取5000像素的高度,分13次就能取完。爲什麼它們兩個的一致?我們可以這 // 麼看:因爲精靈圖寬度爲784像素,784整好是每個子圖的總像素,因爲單個子圖的寬度爲28*28,整好就是784,所以MNIST一橫排的像素總量正 // 好就是一個子圖的像素總量,從計數上就可以看成一個子圖佔了MNIST的一排,MNIST有多少排就有多少張子圖,MNIST有65000像素高,也就是 // 有65000排,也就是我們可以算出子圖有65000張,這個數量和實際的子圖數量是一致的。從獲取到的MNIST精靈圖來看,子圖並不是以28*28的 // 長寬方位比依次存放在MNIST上的,而是以1*784的長寬方位比來存放的,正好就是和我們想的一致:一排就是一張子圖。將每張子圖按照28來切割 // 並換行,就可以組成一個28*28的方形原始圖,這下我們就能看明白手寫的數字了,雖然這種排列方式利於眼睛的分辨,但顯然不適合像素的提取, // 一旦這樣排列,我們在提取子圖像素的時候,就變成了去處理矩陣數據(單個子圖的正方形就可以看做是一個28*28矩陣)。 // 下面內容涉及到的兩個for循環,其中外層的循環遍歷chunkSize,內層的循環遍歷某個chunkSize下的像素。 const chunkNum = NUM_DATASET_ELEMENTS / chunkSize; // => 65000 / 5000 = 13。 for (let i = 0; i < chunkNum; i++) { // drawImage允許我們從一個圖片源上的裁剪某個特定矩形區域並繪製到canvas的特定位置、大小的區域上。 // 這裏我們從image上裁剪同image寬度,高度爲chunkSize的區域繪製到canvas上。下一次循環就裁剪下 // 一個chunkSize高度的區域。 ctx.drawImage( img, // 繪製到上下文的元素,允許任何畫布圖像源,例如:HTMLImageElement、SVGImageElement等。 // 以下參數都基於需要繪製到目標上下文的image的矩形(裁剪)選擇框。 0, // 左上角X軸座標。 i * chunkSize, // 左上角Y軸座標。 img.width, // 寬度。 chunkSize, // 高度。 // 以下參數都基於image的左上角在目標畫布上的繪製。 0, // 左上角X軸座標。 0, // 左上角Y軸座標。 img.width, // 寬度。 chunkSize // 高度。 ); // getImageData返回一個ImageData對象,包含canvas給定的矩形區域的像素數據,這裏的話每個循環裏的canvas上線文 // 就是15680000個像素(784 * 5000 * 4 = 15680000,4個通道總的需要乘以4)。ImageData結構示例: // { // width: 100, // height: 100, // colorSpace: 'srgb', // data: Uint8ClampedArray[40000] // 像素數據。 // } const imageData = ctx.getImageData( // 以下參數都基於將要被提取的圖像數據矩形區域。 0, // 左上角x座標。 0, // 左上角y座標。 canvas.width, // 寬度。 canvas.height // 高度。 ); // datasetBytesView長度爲3920000,即精靈圖中chunkSize個子圖的像素總數量((28 * 28) * 5000)。 // 採用Float32類型數組當做數據視圖來操作它,這個視圖即表示了datasetBytesBuffer的一段空間。 const datasetBytesView = new Float32Array( // length => 3920000 // 數據。 datasetBytesBuffer, // 起始偏移量。第一次循環的時候i爲0,也就是從0開始,即無偏移。第二次循環代的時候i爲1,即偏移從5000個子圖的四通道數 // 量以後開始。剩下的以此類推,每次循環都處理完當前chunkSize的數據,循環完13個chunkSize後,即整個MNIST精靈圖的 // 數據就在datasetBytesBuffer中了。也就是說每次對chunkSize的迭代處理中,其實都是針對datasetBytesBuffer進行 // 處理,往它裏面存放數據,也就是下面的內層for循環中,針對當前chunkSize下的每個像素,對datasetBytesBuffer從 // 當前的偏移值i * IMAGE_SIZE * chunkSize * 4開始存儲,共存儲IMAGE_SIZE * chunkSize個數據。 i * IMAGE_SIZE * chunkSize * 4, // => i * (28 * 28) * 5000 * 4。 // 元素數量。 IMAGE_SIZE * chunkSize // => (28 * 28) * 5000 = 392000 ); // 將該區域內總的像素數量除以4得到單通道像素數量。因爲原數據包含了4個通道的數據,除以4才表示單個通道的數據長度,也就是該矩 // 形區域內所有像素的按單通道計算的數量。等同於MAGE_SIZE * chunkSize,也就是本次for循環的datasetBytesView的元素長度。 const len = imageData.data.length / 4; // => 15680000 / 4 = 3920000。 // 遍歷當前chunkSize高度區域內的每一個像素(按單通道計算)。通過datasetBytesView視圖將數據放入datasetBytesBuffer // 中的對應位置,來改變datasetBytesBuffer的值。 for (let j = 0; j < len; j++) { // All channels hold an equal value since the image is grayscale, so just read the red channel. // 因爲圖片爲灰度化的,所有rgba通道都有相同的值,因此只需要讀取r(紅色)通道的值即可(間隔4個取1次,就能保證每次取的 // 都是r通道的值因此乘以了4)。最後將取出的r通道值除以255,得到0-1之間的數。賦值給datasetBytesView視圖對應的下標, // datasetBytesView和datasetBytesBuffer是連通的,視圖被改,datasetBytesBuffer就對應改變了,整個外層for循 // 環結束後,datasetBytesBuffer(this.datasetImages)就是整個MINST精靈圖的所有像素的總數據(灰度數據)。 // 這裏除以255是爲了避免訓練數據與預測數據之間出現不匹配的情況,因爲我們的MNIST卷積網絡應使用歸一化到0-1之間的圖像張 // 量數據進行訓練。 datasetBytesView[j] = imageData.data[j * 4] / 255; } // 當上面這個for循環完成後,我們就得到了該chunkSize高度區域的所有的r通道的像素轉成灰度的值。然後再接着遍歷下一個chunkSize // 高度區域,並獲得它的所有r通道的像素數據。 } // chunkSize高度區域都全部遍歷完以後,datasetBytesBuffer裏就已經有了整個MNIST精靈圖的r通道的灰度像素信息。最終datasetImages // 的長度爲50960000,即整個MNIST精靈圖的像素總數量(65000 * (28 * 28))。 // 採用Float32類型數組當做數據視圖來操作它。 this.datasetImages = new Float32Array(datasetBytesBuffer); // length => 50960000 // 圖片DOM用完就回收。 img = null; resolve(); }; // 賦值src觸發圖片onload加載回調。 img.src = MNIST_IMAGES_SPRITE_PATH; }); // 請求標籤數據。 const labelsRequest = fetch(MNIST_LABELS_PATH); // 等待圖片、標籤數據都加載並處理完成。 const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]); // 將標籤數據(uint8)存在一個二進制緩衝區中(字節數組),並採用Uint8類型數組當做數據視圖來操作它。 // 標籤數據爲針對每個數字的獨熱編碼集合。以數字6舉例: // 樣本 0 1 2 3 4 5 6 7 8 9 // 標籤 0 0 0 0 0 0 1 0 0 0 // 只有一個是1,即數字6獨熱。 // 預測 0.10 0.01 0.01 0.01 0.09 0.01 0.71 0.01 0.03 0.02 // datasetLabels爲一個1維數組,數據長度爲65000 * 10,因爲每組獨熱編碼長度爲10,分別對應10個數字的編碼,只有一個爲1,其他爲0。 // datasetLabels每10個下標的元素即爲一組針對一個數字的真實值的獨熱編碼。所以下面在劃分的時候乘以了NUM_CLASSES(10)。 this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); // Slice the images and labels into train and test sets. // 將圖片數據和標籤數據劃分爲訓練集和測試集。 this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS); // 0到(28 * 28) * 55000。 this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); // 上面剩下的。 this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); // 0到10 * 55000。 this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS); // 上面剩下的。 } /** * Get all training data as a data tensor and a label tensor. * * @returns * xs: The data tensor, of shape `[numTrainExamples, 28, 28, 1]` * labels: The one-hot encoded labels tensor, of shape '[numTrainExamples, 10]'. */ getTrainData() { // 這包括表示爲NHWC形狀[N,28、28、1]的4維張量(批次示例的第一維)的輸入MNIST圖像,其中N是圖像總數。 const xs = tf.tensor4d( this.trainImages, // NHWC:[55000, 28, 28, 1]。 [this.trainImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1] ); // 這包括輸入標籤,表示爲形狀爲[N,10]的獨熱編碼2維張量。 const labels = tf.tensor2d( this.trainLabels, // NHWC:[55000, 10]。 [this.trainLabels.length / NUM_CLASSES, NUM_CLASSES] ); // 返回的訓練數據就是標準的特徵集和標籤集配對。 return {xs, labels}; } /** * Get all test data as a data tensor and a labels tensor. * * @param {number} numExamples Optional number of examples to get. If not provided, * all test examples will be returned. * @returns xs: The data tensor, of shape `[numTestExamples, 28, 28, 1]`. * labels: The one-hot encoded labels tensor, of shape `[numTestExamples, 10]`. */ getTestData(numExamples) { let xs = tf.tensor4d( this.testImages, [this.testImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1] ); let labels = tf.tensor2d( this.testLabels, [this.testLabels.length / NUM_CLASSES, NUM_CLASSES] ); if (numExamples != null) { // 修正形狀爲[100, 28, 28, 1],即從測試集中取出100張圖片的特徵和標籤。 xs = xs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]); labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]); } // 返回的測試數據就是標準的特徵集和標籤集配對。這與訓練集是相似的,只是它不包含在訓練集以內,模型未曾接觸到。 return {xs, labels}; } }
ui.js
import * as tfvis from '@tensorflow/tfjs-vis'; // This is a helper class for drawing loss graphs and MNIST images to the // window. For the purposes of understanding the machine learning bits, you can // largely ignore it const statusElement = document.getElementById('status'); const messageElement = document.getElementById('message'); const imagesElement = document.getElementById('images'); const trainButton = document.getElementById('train'); const modelType = document.getElementById('model-type'); const epochInput = document.getElementById('train-epochs'); const lossLabelElement = document.getElementById('loss-label'); const lossContainer = document.getElementById('loss-canvas'); const lossValContainer = document.getElementById('loss-val-canvas'); const accuracyLabelElement = document.getElementById('accuracy-label'); const accuracyContainer = document.getElementById('accuracy-canvas'); const accuracyValContainer = document.getElementById('accuracy-val-canvas'); export function logStatus(message) { statusElement.innerText = message; } export function trainingLog(message) { messageElement.innerText = `${message}\n`; } export function showTestResults(batch, predictions, labels) { const testExamples = batch.xs.shape[0]; imagesElement.innerHTML = ''; for (let i = 0; i < testExamples; i++) { const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]); const div = document.createElement('div'); div.className = 'pred-container'; // 預測的值。 const pred = document.createElement('div'); // 判斷預測是否正確。 const prediction = predictions[i]; const label = labels[i]; const correct = prediction === label; // 預測正確顯示綠色,錯誤顯示紅色。 pred.className = `pred ${(correct ? 'pred-correct' : 'pred-incorrect')}`; // 預測的值。 pred.innerText = `pred: ${prediction}`; // 創建一個canvas來展示手寫數字灰度圖片。 const canvas = document.createElement('canvas'); canvas.className = 'prediction-canvas'; draw(image.flatten(), canvas); div.appendChild(pred); div.appendChild(canvas); imagesElement.appendChild(div); } } const lossValues = []; const lossValValues = []; const accuracyValues = []; const accuracyValValues = []; export function plotLoss(batch, loss) { lossValues.push({ x: batch, y: loss }); tfvis.render.linechart( lossContainer, { values: lossValues, series: ['train'] }, { xLabel: 'Batch Number', yLabel: 'Loss', width: 400, height: 300, } ); lossLabelElement.innerText = `last loss: ${loss.toFixed(3)}`; } export function plotValLoss(batch, loss) { lossValValues.push({ x: batch, y: loss }); tfvis.render.linechart( lossValContainer, { values: lossValValues, series: ['validation'] }, { xLabel: 'Batch Number', yLabel: 'Loss', width: 400, height: 300, seriesColors: ['#f16528'] } ); } export function plotAccuracy(batch, accuracy) { accuracyValues.push({ x: batch, y: accuracy }); tfvis.render.linechart( accuracyContainer, { values: accuracyValues, series: ['train'] }, { xLabel: 'Batch Number', yLabel: 'Accuracy', width: 400, height: 300, } ); accuracyLabelElement.innerText = `last accuracy: ${(accuracy * 100).toFixed(1)}%`; } export function plotValAccuracy(batch, accuracy) { accuracyValValues.push({ x: batch, y: accuracy }); tfvis.render.linechart( accuracyValContainer, { values: accuracyValValues, series: ['validation'] }, { xLabel: 'Batch Number', yLabel: 'Accuracy', width: 400, height: 300, seriesColors: ['#f16528'] } ); } export function draw(image, canvas) { const [width, height] = [28, 28]; canvas.width = width; canvas.height = height; const ctx = canvas.getContext('2d'); const imageData = new ImageData(width, height); const data = image.dataSync(); for (let i = 0; i < height * width; ++i) { const j = i * 4; imageData.data[j + 0] = data[i] * 255; imageData.data[j + 1] = data[i] * 255; imageData.data[j + 2] = data[i] * 255; imageData.data[j + 3] = 255; } ctx.putImageData(imageData, 0, 0); } export function getModelTypeId() { return document.getElementById('model-type').value; } export function getTrainEpochs() { return Number.parseInt(document.getElementById('train-epochs').value); } export function setTrainButtonCallback(train) { trainButton.addEventListener('click', async () => { // Disable button during the training. trainButton.setAttribute('disabled', true); modelType.setAttribute('disabled', true); epochInput.setAttribute('disabled', true); // Start training. await train(); // Release button and reset chart data for next training without refreshing page. trainButton.removeAttribute('disabled'); modelType.removeAttribute('disabled'); epochInput.removeAttribute('disabled'); // Rest data array. lossValues.splice(0, lossValues.length); lossValValues.splice(0, lossValValues.length); accuracyValues.splice(0, accuracyValues.length); accuracyValValues.splice(0, accuracyValValues.length); }); }