AI from Beginner to Beginner: Handwritten Digit Recognition with a Dense (Fully Connected) Neural Network in Java

Preface: give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime. Learn to use something first, then learn the principles behind it, then learn to create with it. You may never need this skill, but you should not be without it. This article covers the "Hello World" of machine learning, a handwritten-digit recognition model, and shows how to implement it in Java, with explanations of the key parts. Most articles about AI, machine learning, and neural networks target Python, which is not very approachable for back-end developers who are more comfortable with Java, so this post shows how to do it in Java and opens the door to a new world. Everything below is my personal understanding; corrections are welcome.

I. Goal: train a handwritten-digit image recognition model on the MNIST dataset

What knowledge do we need before implementing a model?

1. Machine learning basics: supervised learning, unsupervised learning, reinforcement learning, and other core concepts.

2. Data processing and analysis: data cleaning, feature engineering, data visualization, etc.

3. A programming language, such as Python, for implementing machine learning algorithms.

4. Mathematical foundations: linear algebra, probability and statistics, calculus.

5. Machine learning algorithms: linear regression, decision trees, neural networks, support vector machines, etc.

6. Deep learning frameworks, such as TensorFlow or PyTorch, for building and training deep learning models.

7. Model evaluation and optimization: cross-validation, hyperparameter tuning, evaluation metrics.

8. Hands-on experience: build up experience through real projects and competitions to keep improving.

Our machine-learning HelloWorld is handwritten-digit recognition, built with the TensorFlow framework.

The main things to understand are:

1. Understand the handwritten-digit dataset: the training images have shape (60000, 28, 28) and the training labels have shape (60000,), one digit label per image.

2. Understand what activation functions do.

3. Understand forward propagation and back-propagation, and how they are implemented.

4. Train the model and save it.

5. Load the saved model and use it for prediction.

II. Comparing the Java and Python code

The Python code has already been explained many times online, so it is not explained again in detail here.

1. Loading the dataset

In Python:

def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)
    return (train_x, train_y), (test_x, test_y)
(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (train_x.shape, train_y.shape, test_x.shape, test_y.shape))
print(train_x.ndim)  # number of dimensions of the dataset
print(train_x.shape)  # shape of the dataset
print(len(train_x))  # size of the dataset
print(train_x)  # the dataset itself
print("--- inspect a single image")
print(train_x[0])
print(len(train_x[0]))
print(len(train_x[0][1]))
print(train_x[0][6])
print("--- inspect a single label")
print(train_y[3])
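
For orientation: the standard MNIST archives hold 60,000 training images and 10,000 test images, so the prints above should report shapes along these lines (sketched here from the known dataset sizes, not pasted from a run):

 train_x:(60000, 28, 28), train_y:(60000,), test_x:(10000, 28, 28), test_y:(10000,)
3
(60000, 28, 28)
60000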


In Java:

SimpleMnist.class

 private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
    private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
    private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
    private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";
//load the data
MnistDataset validationDataset = MnistDataset.getOneValidationImage(3, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);

MnistDataset.class

  /**
     * @param trainingImagesArchive path to the training images archive
     * @param trainingLabelsArchive path to the training labels archive
     * @param testImagesArchive     path to the test images archive
     * @param testLabelsArchive     path to the test labels archive
     */
    public static MnistDataset getOneValidationImage(int index, String trainingImagesArchive, String trainingLabelsArchive,String testImagesArchive, String testLabelsArchive) {
        try {
            ByteNdArray trainingImages = readArchive(trainingImagesArchive);
            ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
            ByteNdArray testImages = readArchive(testImagesArchive);
            ByteNdArray testLabels = readArchive(testLabelsArchive);
            trainingImages.slice(sliceFrom(0));
            trainingLabels.slice(sliceTo(0));
            // slicing
            Index range = Indices.range(index, index + 1); // start and end indices of the slice
            ByteNdArray validationImage = trainingImages.slice(range); // take the slice
            ByteNdArray validationLable = trainingLabels.slice(range); // take the slice
            if (index >= 0) {
                return new MnistDataset(trainingImages,trainingLabels,validationImage,validationLable,testImages,testLabels);
            } else {
                return null;
            }
        } catch (IOException e) {
            throw new AssertionError(e);
        }
    }  
    private static ByteNdArray readArchive(String archiveName) throws IOException {
        System.out.println("archiveName = " + archiveName);
        DataInputStream archiveStream = new DataInputStream(new GZIPInputStream(MnistDataset.class.getClassLoader().getResourceAsStream(archiveName))
        );
        archiveStream.readShort(); // first two bytes are always 0
        byte magic = archiveStream.readByte();
        if (magic != TYPE_UBYTE) {
            throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive");
        }
        int numDims = archiveStream.readByte();
        long[] dimSizes = new long[numDims];
        int size = 1;  // for simplicity, we assume that total size does not exceed Integer.MAX_VALUE
        for (int i = 0; i < dimSizes.length; ++i) {
            dimSizes[i] = archiveStream.readInt();
            size *= dimSizes[i];
        }
        byte[] bytes = new byte[size];
        archiveStream.readFully(bytes);
        return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, false, false));
    }
    /**
     * MNIST dataset constructor
     */
    private MnistDataset(ByteNdArray trainingImages, ByteNdArray trainingLabels,ByteNdArray validationImages,ByteNdArray validationLabels,ByteNdArray testImages,ByteNdArray testLabels
    ) {
        this.trainingImages = trainingImages;
        this.trainingLabels = trainingLabels;
        this.validationImages = validationImages;
        this.validationLabels = validationLabels;
        this.testImages = testImages;
        this.testLabels = testLabels;
        this.imageSize = trainingImages.get(0).shape().size();
        System.out.println(String.format("train_x:%s,train_y:%s, test_x:%s, test_y:%s", trainingImages.shape(), trainingLabels.shape(), testImages.shape(), testLabels.shape()));
        System.out.println("dataset rank: " + trainingImages.rank());
        System.out.println("dataset shape = " + trainingImages.shape());
        System.out.println("dataset size = " + trainingImages.shape().get(0));
        System.out.println("single image = " + trainingImages.get(0));
    }
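
Where do the offsets 8 and 16 in the Python code and the readShort/readByte calls in readArchive come from? The MNIST archives use the IDX format: bytes 0-1 are always zero, byte 2 encodes the element type (0x08 = unsigned byte), byte 3 is the number of dimensions, and each dimension size follows as a big-endian 4-byte integer. That makes the header 8 bytes for the label files (1 dimension) and 16 bytes for the image files (3 dimensions). Below is a minimal sketch that only dumps the header; it assumes the archives sit on disk under src/main/resources/mnist/ and that the enclosing method declares throws IOException.

// Sketch: dump the IDX header of an MNIST archive (uses only java.io / java.util.zip).
try (DataInputStream in = new DataInputStream(new GZIPInputStream(
        new FileInputStream("src/main/resources/mnist/train-images-idx3-ubyte.gz")))) {
    in.readShort();                    // bytes 0-1: always 0
    int type = in.readByte();          // byte 2: element type, 0x08 means unsigned byte
    int numDims = in.readByte();       // byte 3: number of dimensions (3 for images, 1 for labels)
    System.out.println("type = 0x" + Integer.toHexString(type) + ", dims = " + numDims);
    for (int i = 0; i < numDims; i++) {
        System.out.println("dim " + i + " = " + in.readInt());   // one big-endian int32 per dimension
    }
}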

2. Building the model

In Python:

model = tensorflow.keras.Sequential()
model.add(tensorflow.keras.layers.Flatten(input_shape=(28, 28)))  # Flatten layer declaring the input shape
model.add(tensorflow.keras.layers.Dense(128, activation='relu'))  # hidden layer: fully connected, 128 nodes, relu activation
model.add(tensorflow.keras.layers.Dense(10, activation='softmax'))  # output layer: fully connected, 10 nodes, softmax activation
print("print the model structure")
# use summary to print the model structure
print('\n', model.summary())  # inspect the network structure and parameter counts
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])

In Java:

SimpleMnist.class

        Ops tf = Ops.create(graph);
        // Create placeholders and variables, which should fit batches of an unknown number of images
        Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
        Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);

        // Create weights with an initial value of 0
        Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
        Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));
        
        // Create biases with an initial value of 0
        Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
        Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));

        // Predict the class of each image in the batch and compute the loss.
        // tf.linalg.matMul multiplies the image matrix `images` by the weight matrix `weights`, then the bias term `biases` is added: wx + b
        MatMul<TFloat32> matMul = tf.linalg.matMul(images, weights);
        Add<TFloat32> add = tf.math.add(matMul, biases);
        // Softmax is a common activation function that turns its input into a probability distribution:
        // it exponentiates every element of the input vector, sums the exponentials, and divides each exponential by that sum.
        // It is typically used for multi-class problems to output a probability for each class.
        Softmax<TFloat32> softmax = tf.nn.softmax(add);

        // Build the cross-entropy loss as a Mean op
        Mean<TFloat32> crossEntropy =
                tf.math.mean(  // mean of the tensor
                        tf.math.neg(  // negate the tensor
                                tf.reduceSum(  // sum of the tensor
                                        tf.math.mul(labels, tf.math.log(softmax)),  // labels times the log of the softmax predictions
                                        tf.array(1)  // sum along axis 1
                                )
                        ),
                        tf.array(0)  // mean along axis 0
                );

        // Back-propagate gradients to variables for training:
        // create a gradient-descent optimizer and use it to minimize the cross-entropy loss crossEntropy.
        Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
        Op minimize = optimizer.minimize(crossEntropy);
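
To make the two formulas behind this graph concrete: softmax turns the logits z = wx + b into probabilities p_i = exp(z_i) / sum_j exp(z_j), and the cross-entropy against a one-hot label y is loss = -sum_i y_i * log(p_i). The following plain-Java sketch (independent of TensorFlow, with made-up numbers) computes exactly that for one sample:

// Plain-Java illustration of softmax + cross-entropy for a single sample (illustrative values only).
public class SoftmaxCrossEntropyDemo {

    static float[] softmax(float[] logits) {
        double sum = 0;
        double[] exps = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            exps[i] = Math.exp(logits[i]);       // exponentiate every element
            sum += exps[i];
        }
        float[] probs = new float[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) (exps[i] / sum);  // divide each exponential by the total
        }
        return probs;
    }

    static float crossEntropy(float[] oneHotLabel, float[] probs) {
        double loss = 0;
        for (int i = 0; i < probs.length; i++) {
            loss += oneHotLabel[i] * Math.log(probs[i]);  // only the true class contributes
        }
        return (float) -loss;
    }

    public static void main(String[] args) {
        float[] logits = {2.0f, 1.0f, 0.1f};   // pretend these are wx + b for 3 classes
        float[] label = {1.0f, 0.0f, 0.0f};    // one-hot label: the true class is class 0
        float[] probs = softmax(logits);
        System.out.println(java.util.Arrays.toString(probs));
        System.out.println("loss = " + crossEntropy(label, probs));
    }
}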

3. Training the model

In Python:

history = model.fit(train_x, train_y, batch_size=64, epochs=5, validation_split=0.2)

In Java:

SimpleMnist.class

            // Train the model
            for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
                try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
                     TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
                    // create a session runner
                    session.runner()
                            // add the target to minimize
                            .addTarget(minimize)
                            // feed the image batch into the model
                            .feed(images.asOutput(), batchImages)
                            // feed the label batch into the model
                            .feed(labels.asOutput(), batchLabels)
                            // run the session
                            .run();
                }
            }
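
Note one difference from the Python side: model.fit above runs 5 epochs, while this Java loop makes a single pass over the training batches. If you want several epochs in Java, a straightforward variation (a sketch I have not benchmarked; EPOCHS is an assumed constant) is to wrap the same loop:

            // Sketch: repeat the single-pass batch loop for several epochs.
            int EPOCHS = 5;
            for (int epoch = 0; epoch < EPOCHS; epoch++) {
                for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
                    try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
                         TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
                        session.runner()
                                .addTarget(minimize)
                                .feed(images.asOutput(), batchImages)
                                .feed(labels.asOutput(), batchLabels)
                                .run();
                    }
                }
            }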

4. Evaluating the model

In Python:

test_loss, test_acc = model.evaluate(test_x, test_y)
model.evaluate(test_x, test_y, verbose=2)  # print one record per evaluation to judge whether the model generalizes well
print('Test loss: %.3f' % test_loss)
print('Test accuracy: %.3f' % test_acc)

In Java:

SimpleMnist.class

   // Test the model
            ImageBatch testBatch = dataset.testBatch();
            try (TFloat32 testImages = preprocessImages(testBatch.images());
                 TFloat32 testLabels = preprocessLabels(testBatch.labels());
                 // a TFloat32 variable accuracyValue that stores the computed accuracy
                 TFloat32 accuracyValue = (TFloat32) session.runner()
                         // fetch the accuracy value from the session
                         .fetch(accuracy)
                         .fetch(predicted)
                         .fetch(expected)
                         // feed testImages into the images placeholder
                         .feed(images.asOutput(), testImages)
                         // feed testLabels into the labels placeholder
                         .feed(labels.asOutput(), testLabels)
                         // run the session and collect the results
                         .run()
                         // take the first result and store it in accuracyValue
                         .get(0)) {
                System.out.println("Accuracy: " + accuracyValue.getFloat());
            }
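
The accuracy fetched here is built in the full listing (section IV) as mean(cast(equal(argMax(softmax, 1), argMax(labels, 1)))): take the predicted class and the true class of every test image, compare them, and average the 0/1 results. In plain Java that boils down to the following sketch (predicted and expected are assumed to already hold the per-sample class indices):

    // Plain-Java equivalent of the accuracy op, given per-sample predicted and expected class indices.
    static float accuracy(long[] predicted, long[] expected) {
        int correct = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] == expected[i]) {
                correct++;                         // equal -> 1, otherwise 0
            }
        }
        return (float) correct / predicted.length; // mean of the 0/1 comparisons
    }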

5. Saving the model

In Python:

# save the complete model with save_model
# save_model(model, '/media/cfs/<ERP username>/ea/saved_model', save_format='pb')
save_model(model, 'D:\\pythonProject\\mnistDemo\\number_model', save_format='pb')

In Java:

SimpleMnist.class

            // save the model
            SavedModelBundle.Exporter exporter = SavedModelBundle.exporter("D:\\ai\\ai-demo").withSession(session);
            Signature.Builder builder = Signature.builder();
            builder.input("images", images);
            builder.input("labels", labels);
            builder.output("accuracy", accuracy);
            builder.output("expected", expected);
            builder.output("predicted", predicted);
            Signature signature = builder.build();
            SessionFunction sessionFunction = SessionFunction.create(signature, session);
            exporter.withFunction(sessionFunction);
            exporter.export();

6. Loading the model

In Python:

 # load the .pb model file
    global load_model
    load_model = load_model('D:\\pythonProject\\mnistDemo\\number_model')
    load_model.summary()
    demo = tensorflow.reshape(test_x, (1, 28, 28))
    input_data = np.array(demo)  # prepare the input data
    input_data = tensorflow.convert_to_tensor(input_data, dtype=tensorflow.float32)
    predictValue = load_model.predict(input_data)
    print("predictValue")
    print(predictValue)
    y_pred = np.argmax(predictValue)
    print('label: ' + str(test_y) + '\npredicted: ' + str(y_pred))
    return y_pred, test_y,

In Java:

SimpleMnist.class

    // load the model and predict
    public void loadModel(String exportDir) {
        // load saved model
        SavedModelBundle model = SavedModelBundle.load(exportDir, "serve");
        try {
            printSignature(model);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        ByteNdArray validationImages = dataset.getValidationImages();
        ByteNdArray validationLabels = dataset.getValidationLabels();
        TFloat32 testImages = preprocessImages(validationImages);
        System.out.println("testImages = " + testImages.shape());
        TFloat32 testLabels = preprocessLabels(validationLabels);
        System.out.println("testLabels = " + testLabels.shape());
        Result run = model.session().runner()
                .feed("Placeholder:0", testImages)
                .feed("Placeholder_1:0", testLabels)
                .fetch("ArgMax:0")
                .fetch("ArgMax_1:0")
                .fetch("Mean_1:0")
                .run();
        // handle the outputs
        Optional<Tensor> tensor1 = run.get("ArgMax:0");
        Optional<Tensor> tensor2 = run.get("ArgMax_1:0");
        Optional<Tensor> tensor3 = run.get("Mean_1:0");
        TInt64 predicted = (TInt64) tensor1.get();
        Long predictedValue = predicted.getObject(0);
        System.out.println("predictedValue = " + predictedValue);
        TInt64 expected = (TInt64) tensor2.get();
        Long expectedValue = expected.getObject(0);
        System.out.println("expectedValue = " + expectedValue);
        TFloat32 accuracy = (TFloat32) tensor3.get();
        System.out.println("accuracy = " + accuracy.getFloat());
    }
    // print model information
    private static void printSignature(SavedModelBundle model) throws Exception {
        MetaGraphDef m = model.metaGraphDef();
        SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
        int numInputs = sig.getInputsCount();
        int i = 1;
        System.out.println("MODEL SIGNATURE");
        System.out.println("Inputs:");
        for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
            TensorInfo t = entry.getValue();
            System.out.printf(
                    "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
        }
        int numOutputs = sig.getOutputsCount();
        i = 1;
        System.out.println("Outputs:");
        for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
            TensorInfo t = entry.getValue();
            System.out.printf(
                    "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
        }
    }
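
A note on the feed/fetch names in loadModel: they are the node names the ops received when the graph was built in run(), not the logical names passed to Signature.builder(). Assuming the graph is constructed exactly as shown above, the mapping is the following (an inference from the creation order, which is also why printSignature is called first: it lets you confirm the actual names from the serving_default signature instead of guessing):

// logical signature name -> node name in the exported graph (assumed mapping)
// images    -> Placeholder:0    (first tf.placeholder created)
// labels    -> Placeholder_1:0  (second tf.placeholder created)
// predicted -> ArgMax:0         (first tf.math.argMax)
// expected  -> ArgMax_1:0       (second tf.math.argMax)
// accuracy  -> Mean_1:0         (second tf.math.mean; Mean:0 is the cross-entropy)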

III. The complete Python code

The environment used for this project:

Python: 3.7.9

https://www.python.org/downloads/windows/

Anaconda: Python 3.11 Anaconda3-2023.09-0-Windows-x86_64

https://www.anaconda.com/download#downloads

tensorflow:2.0.0

Installed directly through Anaconda.

mnistTrainDemo.py

import gzip
import os.path
import tensorflow as tensorflow
from tensorflow import keras
# for visualizing images
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.models import save_model

# load the data
# mnist = keras.datasets.mnist
# mnistData = mnist.load_data() #Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz: None -- unknown url type: https
"""
You could simply use
mnist = keras.datasets.mnist
mnistData = mnist.load_data() to load the data, but the download sometimes fails, so the data is loaded from local files instead
"""
def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))

    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)

    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)

    return (train_x, train_y), (test_x, test_y)

(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s' % (train_x.shape, train_y.shape, test_x.shape, test_y.shape))
print(train_x.ndim)  # number of dimensions of the dataset
print(train_x.shape)  # shape of the dataset
print(len(train_x))  # size of the dataset
print(train_x)  # the dataset itself
print("--- inspect a single image")
print(train_x[0])
print(len(train_x[0]))
print(len(train_x[0][1]))
print(train_x[0][6])
# visualize one image
# plt.imshow(train_x[0].reshape(28, 28), cmap="binary")
# plt.show()
print("--- inspect a single label")
print(train_y[0])

# Data preprocessing
# Normalize and convert to float32 tensors.  --- note that normalization can also hurt recognition accuracy in some setups
# train_x, test_x = tensorflow.cast(train_x / 255.0, tensorflow.float32), tensorflow.cast(test_x / 255.0,
#                                                                                         tensorflow.float32),
# train_y, test_y = tensorflow.cast(train_y, tensorflow.int16), tensorflow.cast(test_y, tensorflow.int16)
# print("--- inspect a single value after normalization")
# print(train_x[0][6])  # 30/255=0.11764706  --- normalization divides each value by 255
# print(train_y[0])

# Step 2: configure the network and build the model
'''
The code below defines a simple multilayer perceptron.
MNIST consists of grayscale images of the handwritten digits 0-9, so there are 10 classes
and the final output layer has 10 nodes. The output layer uses a Softmax activation,
so it effectively acts as a classifier. Counting the input layer, the structure is:
input layer -->> hidden layer -->> output layer.
Activation functions: https://zhuanlan.zhihu.com/p/337902763
'''
# build the model
# model = keras.Sequential([
#     # In the first layer the input shape is 28*28, i.e. the width and height of each image.
#     keras.layers.Flatten(input_shape=(28, 28)),
#     # A neural network works a bit like a filter: feed in a 28*28-pixel image and it outputs a judgement over 10 classes. What is the hidden-layer size for?
#     # Imagine the network contains that many functions, each with its own parameters.
#     # Number them f0, f1, ... When the pixels of an image are fed through these functions, their combination should eventually output a label value, here a digit between 0 and 9.
#     # To get there, the computer has to work out the concrete parameters of these functions before it can compute a label for each image. Once those parameters are learned, it can recognize the 10 classes.
#     keras.layers.Dense(100, activation=tensorflow.nn.relu),
#     # The last layer has size 10 because the dataset contains 10 classes.
#     keras.layers.Dense(10, activation=tensorflow.nn.softmax)
# ])

model = tensorflow.keras.Sequential()
model.add(tensorflow.keras.layers.Flatten(input_shape=(28, 28)))  # Flatten layer declaring the input shape
model.add(tensorflow.keras.layers.Dense(128, activation='relu'))  # hidden layer: fully connected, 128 nodes, relu activation
model.add(tensorflow.keras.layers.Dense(10, activation='softmax'))  # output layer: fully connected, 10 nodes, softmax activation
print("print the model structure")
# use summary to print the model structure
# print(model.summary())
print('\n', model.summary())  # inspect the network structure and parameter counts

'''
Next, configure the model. In this step we specify the optimizer and the loss function
used during training; accuracy-related metrics can also be defined here.
Optimizers: https://zhuanlan.zhihu.com/p/27449596
'''
# configure the model (set the training method)
# Set the optimizer and loss function of the network: Adam for optimization, cross-entropy for the loss, accuracy as the metric
# model.compile(optimizer=tensorflow.optimizers.Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Adam with Keras' default public parameters; the loss is sparse categorical cross-entropy and the metric is sparse categorical accuracy
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])

# Step 3: train the model
# start training
# model.fit(x_train,  # training data
#           y_train,
#           epochs=5,  # number of epochs
#           batch_size=64,  # batch size
#           verbose=1)  # logging verbosity
# batch size 64, 5 epochs, validation split 0.2 (48000 training samples, 12000 validation samples)
history = model.fit(train_x, train_y, batch_size=64, epochs=5, validation_split=0.2)

# Step 4: evaluate the model
# Evaluate the model: this outputs loss and accuracy rather than predictions (test_loss = loss, test_acc = accuracy)
test_loss, test_acc = model.evaluate(test_x, test_y)
model.evaluate(test_x, test_y, verbose=2)  # print one record per evaluation to judge whether the model generalizes well
# model.evaluate(test_dataset, verbose=1)
print('Test loss: %.3f' % test_loss)
print('Test accuracy: %.3f' % test_acc)
# visualize the results
print(history.history)
loss = history.history['loss']  # training loss
val_loss = history.history['val_loss']  # validation loss
acc = history.history['sparse_categorical_accuracy']  # training accuracy
val_acc = history.history['val_sparse_categorical_accuracy']  # validation accuracy

plt.figure(figsize=(10, 3))
plt.subplot(121)
plt.plot(loss, color='b', label='train')
plt.plot(val_loss, color='r', label='test')
plt.ylabel('loss')
plt.legend()

plt.subplot(122)
plt.plot(acc, color='b', label='train')
plt.plot(val_acc, color='r', label='test')
plt.ylabel('Accuracy')
plt.legend()

# Pause for 5 seconds and then close the figure; otherwise the open figure keeps occupying GPU memory
# plt.ion()  # enable interactive mode
# plt.show()
# plt.pause(5)
# plt.close()
# plt.show()

# Step 5: prediction - feed in test data and output the predicted result
for i in range(1):
    num = np.random.randint(1, 10000)  # pick a random integer between 1 and 10000
    plt.subplot(2, 5, i + 1)
    plt.axis('off')
    plt.imshow(test_x[num], cmap='gray')
    demo = tensorflow.reshape(test_x[num], (1, 28, 28))
    y_pred = np.argmax(model.predict(demo))
    plt.title('label: ' + str(test_y[num]) + '\npredicted: ' + str(y_pred))
# plt.show()

'''
Save the model
A trained model can be loaded later to predict on new input data, so the trained model must be saved first
'''
# save the complete model with save_model
save_model(model, 'D:\\pythonProject\\mnistDemo\\number_model', save_format='pb')

mnistPredictDemo.py

import numpy as np
import tensorflow as tensorflow
import gzip
import os.path
from tensorflow.keras.models import load_model
# prediction
def predict(test_x, test_y):
    test_x, test_y = test_x, test_y
    '''
    Model evaluation
    Load the trained model first, then use it to predict on new data and compute evaluation metrics
    '''
    # load the model
    # load the .pb model file
    global load_model
    # load_model = load_model('./saved_model')
    load_model = load_model('D:\\pythonProject\\mnistDemo\\number_model')
    load_model.summary()
    # make a prediction
    print("test_x")
    print(test_x)
    print(test_x.ndim)
    print(test_x.shape)

    demo = tensorflow.reshape(test_x, (1, 28, 28))
    input_data = np.array(demo)  # prepare the input data
    input_data = tensorflow.convert_to_tensor(input_data, dtype=tensorflow.float32)
    # test_x = tensorflow.cast(test_x / 255.0, tensorflow.float32)
    # test_y = tensorflow.cast(test_y, tensorflow.int16)
    predictValue = load_model.predict(input_data)
    print("predictValue")
    print(predictValue)
    y_pred = np.argmax(predictValue)
    print('label: ' + str(test_y) + '\npredicted: ' + str(y_pred))
    return y_pred, test_y,
  
def load_data(data_folder):
    files = ["train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
             "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"]
    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder, fname))
    with gzip.open(paths[0], 'rb') as lbpath:
        train_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[1], 'rb') as imgpath:
        train_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(train_y), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        test_y = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        test_x = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(test_y), 28, 28)
    return (train_x, train_y), (test_x, test_y)

(train_x, train_y), (test_x, test_y) = load_data("mnistDataSet/")
print(train_x[0])
predict(train_x[0], train_y)

IV. The complete Java code

TensorFlow / Java version support table: https://github.com/tensorflow/java/#tensorflow-version-support

The environment used for this project:

JDK version: openjdk-21

The pom dependencies are as follows:


    <dependencies>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.6.0-SNAPSHOT</version>
        </dependency>

        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-framework</artifactId>
            <version>0.6.0-SNAPSHOT</version>
        </dependency>
    </dependencies>

    <repositories>
        <repository>
            <id>tensorflow-snapshots</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
            <snapshots>
                <enabled>true</enabled>
            </snapshots>
        </repository>
    </repositories>

The class that creates and parses the dataset:

MnistDataset.class

package org.example.tensorDemo.datasets.mnist;

import org.example.tensorDemo.datasets.ImageBatch;
import org.example.tensorDemo.datasets.ImageBatchIterator;
import org.tensorflow.ndarray.*;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.index.Indices;

import java.io.DataInputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;

import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.sliceTo;



public class MnistDataset {
    public static final int NUM_CLASSES = 10;

    private static final int TYPE_UBYTE = 0x08;

    /**
     * Training images as a byte multi-dimensional array
     */
    private final ByteNdArray trainingImages;

    /**
     * Training labels as a byte multi-dimensional array
     */
    private final ByteNdArray trainingLabels;

    /**
     * Validation images as a byte multi-dimensional array
     */
    public final ByteNdArray validationImages;

    /**
     * Validation labels as a byte multi-dimensional array
     */
    public final ByteNdArray validationLabels;

    /**
     * Test images as a byte multi-dimensional array
     */
    private final ByteNdArray testImages;

    /**
     * Test labels as a byte multi-dimensional array
     */
    private final ByteNdArray testLabels;

    /**
     * Size of a single image
     */
    private final long imageSize;


    /**
     * MNIST dataset constructor
     */
    private MnistDataset(
            ByteNdArray trainingImages,
            ByteNdArray trainingLabels,
            ByteNdArray validationImages,
            ByteNdArray validationLabels,
            ByteNdArray testImages,
            ByteNdArray testLabels
    ) {
        this.trainingImages = trainingImages;
        this.trainingLabels = trainingLabels;
        this.validationImages = validationImages;
        this.validationLabels = validationLabels;
        this.testImages = testImages;
        this.testLabels = testLabels;
        // shape of the first image and its total size; each image has 28x28 pixels, so this should be 784
        this.imageSize = trainingImages.get(0).shape().size();
//        System.out.println("imageSize = " + imageSize);


//        System.out.println(String.format("train_x:%s,train_y:%s, test_x:%s, test_y:%s", trainingImages.shape(), trainingLabels.shape(), testImages.shape(), testLabels.shape()));
//        System.out.println("dataset rank: " + trainingImages.rank());
//        System.out.println("dataset shape = " + trainingImages.shape());
//        System.out.println("dataset size = " + trainingImages.shape().get(0));
//        System.out.println("dataset = ");
//        for (int i = 0; i < trainingImages.shape().get(0); i++) {
//            for (int j = 0; j < trainingImages.shape().get(1); j++) {
//                for (int k = 0; k < trainingImages.shape().get(2); k++) {
//                    System.out.print(trainingImages.getObject(i, j, k) + " ");
//                }
//                System.out.println();
//            }
//            System.out.println();
//        }
//        System.out.println("single image = " + trainingImages.get(0));
//        for (int j = 0; j < trainingImages.shape().get(1); j++) {
//            for (int k = 0; k < trainingImages.shape().get(2); k++) {
//                System.out.print(trainingImages.getObject(0, j, k) + " ");
//            }
//            System.out.println();
//        }
//        System.out.println("size of a single image = " + trainingImages.get(0).size());
//        System.out.println("size of the 2nd row of the first image in trainingImages = " + trainingImages.get(0).get(1).size());
//        System.out.println("element [0, 6, 8] of trainingImages = " + trainingImages.getObject(0, 6, 8));
//        System.out.println("trainingLabels = " + trainingLabels.getObject(1));
    }

    /**
     * @param validationSize        number of training samples reserved for validation
     * @param trainingImagesArchive path to the training images archive
     * @param trainingLabelsArchive path to the training labels archive
     * @param testImagesArchive     path to the test images archive
     * @param testLabelsArchive     path to the test labels archive
     */
    public static MnistDataset create(int validationSize, String trainingImagesArchive, String trainingLabelsArchive,
                                      String testImagesArchive, String testLabelsArchive) {
        try {
            ByteNdArray trainingImages = readArchive(trainingImagesArchive);
            ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
            ByteNdArray testImages = readArchive(testImagesArchive);
            ByteNdArray testLabels = readArchive(testLabelsArchive);

            if (validationSize > 0) {
                return new MnistDataset(
                        trainingImages.slice(sliceFrom(validationSize)),
                        trainingLabels.slice(sliceFrom(validationSize)),
                        trainingImages.slice(sliceTo(validationSize)),
                        trainingLabels.slice(sliceTo(validationSize)),
                        testImages,
                        testLabels
                );
            }
            return new MnistDataset(trainingImages, trainingLabels, null, null, testImages, testLabels);

        } catch (IOException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * @param trainingImagesArchive path to the training images archive
     * @param trainingLabelsArchive path to the training labels archive
     * @param testImagesArchive     path to the test images archive
     * @param testLabelsArchive     path to the test labels archive
     */
    public static MnistDataset getOneValidationImage(int index, String trainingImagesArchive, String trainingLabelsArchive,
                                                     String testImagesArchive, String testLabelsArchive) {
        try {
            ByteNdArray trainingImages = readArchive(trainingImagesArchive);
            ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
            ByteNdArray testImages = readArchive(testImagesArchive);
            ByteNdArray testLabels = readArchive(testLabelsArchive);
            trainingImages.slice(sliceFrom(0));
            trainingLabels.slice(sliceTo(0));
            // slicing
            Index range = Indices.range(index, index + 1); // start and end indices of the slice
            ByteNdArray validationImage = trainingImages.slice(range); // take the slice
            ByteNdArray validationLable = trainingLabels.slice(range); // take the slice


            if (index >= 0) {
                return new MnistDataset(
                        trainingImages,
                        trainingLabels,
                        validationImage,
                        validationLable,
                        testImages,
                        testLabels
                );
            } else {
                return null;
            }
        } catch (IOException e) {
            throw new AssertionError(e);
        }
    }

    private static ByteNdArray readArchive(String archiveName) throws IOException {
        System.out.println("archiveName = " + archiveName);
        DataInputStream archiveStream = new DataInputStream(
                //new GZIPInputStream(new java.io.FileInputStream("src/main/resources/"+archiveName))
                new GZIPInputStream(MnistDataset.class.getClassLoader().getResourceAsStream(archiveName))
        );
        // todo: it is not yet clear how this parsing maps onto the actual internal file structure
        archiveStream.readShort(); // first two bytes are always 0
        byte magic = archiveStream.readByte();
        if (magic != TYPE_UBYTE) {
            throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive");
        }
        int numDims = archiveStream.readByte();
        long[] dimSizes = new long[numDims];
        int size = 1;  // for simplicity, we assume that total size does not exceed Integer.MAX_VALUE
        for (int i = 0; i < dimSizes.length; ++i) {
            dimSizes[i] = archiveStream.readInt();
            size *= dimSizes[i];
        }
        byte[] bytes = new byte[size];
        archiveStream.readFully(bytes);
        return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, false, false));
    }

    public Iterable<ImageBatch> trainingBatches(int batchSize) {
        return () -> new ImageBatchIterator(batchSize, trainingImages, trainingLabels);
    }

    public Iterable<ImageBatch> validationBatches(int batchSize) {
        return () -> new ImageBatchIterator(batchSize, validationImages, validationLabels);
    }

    public Iterable<ImageBatch> testBatches(int batchSize) {
        return () -> new ImageBatchIterator(batchSize, testImages, testLabels);
    }

    public ImageBatch testBatch() {
        return new ImageBatch(testImages, testLabels);
    }

    public long imageSize() {
        return imageSize;
    }

    public long numTrainingExamples() {
        return trainingLabels.shape().size(0);
    }

    public long numTestingExamples() {
        return testLabels.shape().size(0);
    }

    public long numValidationExamples() {
        return validationLabels.shape().size(0);
    }

    public ByteNdArray getValidationImages() {
        return validationImages;
    }

    public ByteNdArray getValidationLabels() {
        return validationLabels;
    }
}

SimpleMnist.class

package org.example.tensorDemo.dense;
import org.example.tensorDemo.datasets.ImageBatch;
import org.example.tensorDemo.datasets.mnist.MnistDataset;
import org.tensorflow.*;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.nn.Softmax;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;

public class SimpleMnist implements Runnable {
    private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
    private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
    private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
    private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";

    public static void main(String[] args) {
        // load the dataset
//        MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
//                TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
        MnistDataset validationDataset = MnistDataset.getOneValidationImage(3, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
                TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
        // create a Graph object named graph
        try (Graph graph = new Graph()) {
            SimpleMnist mnist = new SimpleMnist(graph, validationDataset);
            mnist.run(); // build and train the model
            mnist.loadModel("D:\\ai\\ai-demo");
        }
    }

    @Override
    public void run() {
        Ops tf = Ops.create(graph);
        // Create placeholders and variables, which should fit batches of an unknown number of images
        Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
        Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);

        // Create weights with an initial value of 0
        Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
        Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));

        // Create biases with an initial value of 0
        Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
        Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));

        // Predict the class of each image in the batch and compute the loss.
        // tf.linalg.matMul multiplies the image matrix `images` by the weight matrix `weights`, then the bias term `biases` is added: wx + b
        MatMul<TFloat32> matMul = tf.linalg.matMul(images, weights);
        Add<TFloat32> add = tf.math.add(matMul, biases);

        // Softmax activation: turns the input into a probability distribution by exponentiating every element,
        // summing the exponentials, and dividing each exponential by that sum.
        // It is typically used for multi-class problems to output a probability for each class.
        Softmax<TFloat32> softmax = tf.nn.softmax(add);

        // Build the cross-entropy loss (the loss function) as a Mean op
        Mean<TFloat32> crossEntropy =
                tf.math.mean(  // mean of the tensor
                        tf.math.neg(  // negate the tensor
                                tf.reduceSum(  // sum of the tensor
                                        tf.math.mul(labels, tf.math.log(softmax)),  // labels times the log of the softmax predictions
                                        tf.array(1)  // sum along axis 1
                                )
                        ),
                        tf.array(0)  // mean along axis 0
                );

        // Back-propagate gradients to variables for training:
        // create a gradient-descent optimizer and use it to minimize the cross-entropy loss crossEntropy.
        // Gradient descent: https://www.cnblogs.com/guoyaohua/p/8542554.html
        Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
        Op minimize = optimizer.minimize(crossEntropy);

        // Compute the accuracy of the model.
        // argMax finds the index of the largest value along the given axis.
        Operand<TInt64> predicted = tf.math.argMax(softmax, tf.constant(1));
        Operand<TInt64> expected = tf.math.argMax(labels, tf.constant(1));
        // equal compares the predicted labels with the actual labels, cast converts the booleans to floats, and mean averages them to obtain the accuracy.
        Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0));

        // Run the graph
        try (Session session = new Session(graph)) {
            // Train the model
            for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
                try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
                     TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
                    System.out.println("batchImages = " + batchImages.shape());
                    System.out.println("batchLabels = " + batchLabels.shape());
                    // create a session runner
                    session.runner()
                            // add the target to minimize
                            .addTarget(minimize)
                            // feed the image batch into the model
                            .feed(images.asOutput(), batchImages)
                            // feed the label batch into the model
                            .feed(labels.asOutput(), batchLabels)
                            // run the session
                            .run();
                }
            }

            // Test the model
            ImageBatch testBatch = dataset.testBatch();
            try (TFloat32 testImages = preprocessImages(testBatch.images());
                 TFloat32 testLabels = preprocessLabels(testBatch.labels());
                 // a TFloat32 variable accuracyValue that stores the computed accuracy
                 TFloat32 accuracyValue = (TFloat32) session.runner()
                         // fetch the accuracy value from the session
                         .fetch(accuracy)
                         .fetch(predicted)
                         .fetch(expected)
                         // feed testImages into the images placeholder
                         .feed(images.asOutput(), testImages)
                         // feed testLabels into the labels placeholder
                         .feed(labels.asOutput(), testLabels)
                         // run the session and collect the results
                         .run()
                         // take the first result and store it in accuracyValue
                         .get(0)) {
                System.out.println("Accuracy: " + accuracyValue.getFloat());
            }
            // save the model
            SavedModelBundle.Exporter exporter = SavedModelBundle.exporter("D:\\ai\\ai-demo").withSession(session);
            Signature.Builder builder = Signature.builder();
            builder.input("images", images);
            builder.input("labels", labels);
            builder.output("accuracy", accuracy);
            builder.output("expected", expected);
            builder.output("predicted", predicted);
            Signature signature = builder.build();
            SessionFunction sessionFunction = SessionFunction.create(signature, session);
            exporter.withFunction(sessionFunction);
            exporter.export();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

    }

    private static final int VALIDATION_SIZE = 5;
    private static final int TRAINING_BATCH_SIZE = 100;
    private static final float LEARNING_RATE = 0.2f;

    private static TFloat32 preprocessImages(ByteNdArray rawImages) {
        Ops tf = Ops.create();
        // Flatten images in a single dimension and normalize their pixels as floats.
        long imageSize = rawImages.get(0).shape().size();
        return tf.math.div(
                tf.reshape(
                        tf.dtypes.cast(tf.constant(rawImages), TFloat32.class),
                        tf.array(-1L, imageSize)
                ),
                tf.constant(255.0f)
        ).asTensor();
    }

    private static TFloat32 preprocessLabels(ByteNdArray rawLabels) {
        Ops tf = Ops.create();
        // Map labels to one-hot vectors where only the expected class has a value of 1.0
        return tf.oneHot(
                tf.constant(rawLabels),
                tf.constant(MnistDataset.NUM_CLASSES),
                tf.constant(1.0f),
                tf.constant(0.0f)
        ).asTensor();
    }

    private final Graph graph;
    private final MnistDataset dataset;

    private SimpleMnist(Graph graph, MnistDataset dataset) {
        this.graph = graph;
        this.dataset = dataset;
    }

    public void loadModel(String exportDir) {
        // load saved model
        SavedModelBundle model = SavedModelBundle.load(exportDir, "serve");
        try {
            printSignature(model);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        ByteNdArray validationImages = dataset.getValidationImages();
        ByteNdArray validationLabels = dataset.getValidationLabels();
        TFloat32 testImages = preprocessImages(validationImages);
        System.out.println("testImages = " + testImages.shape());
        TFloat32 testLabels = preprocessLabels(validationLabels);
        System.out.println("testLabels = " + testLabels.shape());
        Result run = model.session().runner()
                .feed("Placeholder:0", testImages)
                .feed("Placeholder_1:0", testLabels)
                .fetch("ArgMax:0")
                .fetch("ArgMax_1:0")
                .fetch("Mean_1:0")
                .run();
        // handle the outputs
        Optional<Tensor> tensor1 = run.get("ArgMax:0");
        Optional<Tensor> tensor2 = run.get("ArgMax_1:0");
        Optional<Tensor> tensor3 = run.get("Mean_1:0");
        TInt64 predicted = (TInt64) tensor1.get();
        Long predictedValue = predicted.getObject(0);
        System.out.println("predictedValue = " + predictedValue);
        TInt64 expected = (TInt64) tensor2.get();
        Long expectedValue = expected.getObject(0);
        System.out.println("expectedValue = " + expectedValue);
        TFloat32 accuracy = (TFloat32) tensor3.get();
        System.out.println("accuracy = " + accuracy.getFloat());
    }

    private static void printSignature(SavedModelBundle model) throws Exception {
        MetaGraphDef m = model.metaGraphDef();
        SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
        int numInputs = sig.getInputsCount();
        int i = 1;
        System.out.println("MODEL SIGNATURE");
        System.out.println("Inputs:");
        for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
            TensorInfo t = entry.getValue();
            System.out.printf(
                    "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
        }
        int numOutputs = sig.getOutputsCount();
        i = 1;
        System.out.println("Outputs:");
        for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
            TensorInfo t = entry.getValue();
            System.out.printf(
                    "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
        }
        System.out.println("-----------------------------------------------");
    }
}

V. Run results of the two programs


VI. Things left to improve

1. This demo does not expose a web service for submitting images, nor does it handle preprocessing such as binarizing the image data. Interested readers can try that themselves.

2. No convolutional neural network is used; the model is just wx+b followed by an activation function, trained with gradient descent and a cross-entropy loss.

3. No deeper, multi-layer architecture is designed.
