Tensorflow在手機端的部署——官網Android工程源碼分析之TensorFlowYoloDetector.java (1)

本文對TensorFlow提供的官方Android工程源碼進行分析。由於後續需要更改代碼,因此有必要先對其做深入理解。

首先工程文件路徑爲:tensorflow-master\tensorflow\examples\android

由於這個android工程中實現了目標檢測,風格遷移,語音,圖像分類四個功能,其中目標檢測中有用到yolo檢測,有用到ssd-mobilenet v1檢測,還有就是用到multi-box做檢測。本文將針對yolo做檢測需要用到的TensorFlowYoloDetector.java代碼部分進行詳細講解。

 

 

其中,在20個類別上訓練的yolo v2模型,其output中的結果按下圖次序進行排列:

即【第一個框:x,y,w,h,confidence,class1,……,class20】【第二個框:x,y,w,h,confidence,class1,……,class20】……【第13x13x5個框:x,y,w,h,confidence,class1,……,class20】

 

TensorFlowYoloDetector.java部分代碼如下:

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.demo;                //包的名字,可隨意更改

import android.content.res.AssetManager;   //assets文件夾下的文件不會被映射到R.java中,訪問的時候需要AssetManager類
import android.graphics.Bitmap;             //導入安卓系統的圖像處理類Bitmap ,以便進行圖像剪切、旋轉、縮放等操作,並可以指定格式保存圖像文件
import android.graphics.RectF;             //這個類包含一個矩形的四個單精度浮點座標。矩形通過上下左右4個邊的座標來表示一個矩形
import android.os.Trace;                // Android SDK中提供了`android.os.Trace#beginSection`和`android.os.Trace#endSection` 這兩個接口,我們可以在代碼中插入這些調用來分析某個特定的過程。
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;   //tf針對安卓封裝的inference類
import org.tensorflow.demo.env.Logger;      //定義的一個類用於報文生成便於分析
import org.tensorflow.demo.env.SplitTimer;  //定義的一個類用於計算CPU時間

/**
 * An object detector that uses TF and a YOLO model to detect objects.
 *
 * <p>Instances are obtained through {@link #create}. Each call to {@link #recognizeImage} converts
 * the bitmap to normalized floats, runs the model, decodes every predicted box of the output grid
 * and returns at most {@code MAX_RESULTS} detections ordered by descending confidence.
 */
public class TensorFlowYoloDetector implements Classifier {
  private static final Logger LOGGER = new Logger();

  // Only return this many results with at least this confidence.
  private static final int MAX_RESULTS = 5;

  // Number of class scores the model emits per box. Must equal LABELS.length, because the decoded
  // class index is used directly to index LABELS (80 entries here).
  private static final int NUM_CLASSES = 80;

  // Number of boxes predicted for every grid cell; one anchor pair in ANCHORS per box.
  private static final int NUM_BOXES_PER_BLOCK = 5;

  // TODO(andrewharp): allow loading anchors and classes
  // from files.
  // Anchor (width, height) pairs, one pair per box index. They are expressed in grid-cell units:
  // the decoder multiplies them by blockSize to obtain pixels.
  // NOTE(review): presumably obtained by clustering box sizes on the training set; the values must
  // match the ones the model was trained with - confirm when swapping models.
  private static final double[] ANCHORS = {
    1.08, 1.19,
    3.42, 4.41,
    6.63, 11.38,
    9.42, 5.11,
    16.62, 10.52
  };

  // Label names for a 20-class (VOC) model. Not referenced by the code below; kept so that the
  // detector can be switched to a 20-class model together with NUM_CLASSES.
  private static final String[] LABELS_VOC = {
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor"
  };

  // Label names for the 80-class (COCO) model, indexed by the class id decoded from the output.
  private static final String[] LABELS = {
          "person",
          "bicycle",
          "car",
          "motorbike",
          "aeroplane",
          "bus",
          "train",
          "truck",
          "boat",
          "traffic light",
          "fire hydrant",
          "stop sign",
          "parking meter",
          "bench",
          "bird",
          "cat",
          "dog",
          "horse",
          "sheep",
          "cow",
          "elephant",
          "bear",
          "zebra",
          "giraffe",
          "backpack",
          "umbrella",
          "handbag",
          "tie",
          "suitcase",
          "frisbee",
          "skis",
          "snowboard",
          "sports ball",
          "kite",
          "baseball bat",
          "baseball glove",
          "skateboard",
          "surfboard",
          "tennis racket",
          "bottle",
          "wine glass",
          "cup",
          "fork",
          "knife",
          "spoon",
          "bowl",
          "banana",
          "apple",
          "sandwich",
          "orange",
          "broccoli",
          "carrot",
          "hot dog",
          "pizza",
          "donut",
          "cake",
          "chair",
          "sofa",
          "pottedplant",
          "bed",
          "diningtable",
          "toilet",
          "tvmonitor",
          "laptop",
          "mouse",
          "remote",
          "keyboard",
          "cell phone",
          "microwave",
          "oven",
          "toaster",
          "sink",
          "refrigerator",
          "book",
          "clock",
          "vase",
          "scissors",
          "teddy bear",
          "hair drier",
          "toothbrush"
  };

  // Config values.
  private String inputName; // Name of the graph node the image is fed into.
  private int inputSize; // Side length of the square network input, in pixels.

  // Pre-allocated buffers.
  private int[] intValues; // Packed pixel values read from the bitmap (inputSize * inputSize).
  private float[] floatValues; // Normalized channel values (inputSize * inputSize * 3).
  private String[] outputNames; // Names of the graph nodes to run; only the first is fetched.

  // Size in pixels of the image region covered by one output grid cell
  // (e.g. 32 turns a 416x416 input into a 13x13 grid).
  private int blockSize;

  // Forwarded to the inference interface on every run; toggled via enableStatLogging.
  private boolean logStats = false;

  private TensorFlowInferenceInterface inferenceInterface;

  /**
   * Initializes a native TensorFlow session for classifying images.
   *
   * @param assetManager asset manager passed to the inference interface to load the model
   * @param modelFilename name of the model file
   * @param inputSize side length, in pixels, of the square image the network expects
   * @param inputName name of the input node
   * @param outputName comma-separated names of the output node(s)
   * @param blockSize size in pixels of one output grid cell
   * @return a detector with its buffers allocated and its inference interface created
   */
  public static Classifier create(
      final AssetManager assetManager,
      final String modelFilename,
      final int inputSize,
      final String inputName,
      final String outputName,
      final int blockSize) {
    TensorFlowYoloDetector d = new TensorFlowYoloDetector();
    d.inputName = inputName;
    d.inputSize = inputSize;

    // Pre-allocate buffers.
    d.outputNames = outputName.split(",");
    d.intValues = new int[inputSize * inputSize];
    d.floatValues = new float[inputSize * inputSize * 3];
    d.blockSize = blockSize;

    d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

    return d;
  }

  // Instances are only created through create().
  private TensorFlowYoloDetector() {}

  /** Logistic sigmoid: maps {@code x} to {@code 1 / (1 + e^-x)}. */
  private static float expit(final float x) {
    return (float) (1. / (1. + Math.exp(-x)));
  }

  /**
   * Replaces the contents of {@code vals} with their softmax, in place.
   *
   * <p>The maximum is subtracted before exponentiating so that large scores cannot overflow.
   */
  private static void softmax(final float[] vals) {
    float max = Float.NEGATIVE_INFINITY;
    for (final float val : vals) {
      max = Math.max(max, val);
    }
    float sum = 0.0f;
    for (int i = 0; i < vals.length; ++i) {
      vals[i] = (float) Math.exp(vals[i] - max);
      sum += vals[i];
    }
    // Normalize so that the entries sum to 1.
    for (int i = 0; i < vals.length; ++i) {
      vals[i] = vals[i] / sum;
    }
  }

  /**
   * Runs the YOLO model on {@code bitmap} and decodes the detections.
   *
   * @param bitmap image to analyze; its pixels are copied into a buffer of
   *     {@code inputSize * inputSize} entries, so it is expected to be {@code inputSize} pixels
   *     wide and high
   * @return at most {@code MAX_RESULTS} detections whose score (class probability times box
   *     confidence) exceeds 0.01, ordered from highest to lowest score
   */
  @Override
  public List<Recognition> recognizeImage(final Bitmap bitmap) {
    final SplitTimer timer = new SplitTimer("recognizeImage");

    // Log this method so that it can be analyzed with systrace.
    Trace.beginSection("recognizeImage");

    Trace.beginSection("preprocessBitmap");
    // Preprocess the image data from 0-255 int to normalized float based
    // on the provided parameters.
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

    // Unpack each pixel into three consecutive floats in [0, 1]: bits 16-23, 8-15 and 0-7
    // (the R, G and B channels of Android's packed color ints).
    for (int i = 0; i < intValues.length; ++i) {
      floatValues[i * 3 + 0] = ((intValues[i] >> 16) & 0xFF) / 255.0f;
      floatValues[i * 3 + 1] = ((intValues[i] >> 8) & 0xFF) / 255.0f;
      floatValues[i * 3 + 2] = (intValues[i] & 0xFF) / 255.0f;
    }
    Trace.endSection(); // preprocessBitmap

    // Copy the input data into TensorFlow.
    Trace.beginSection("feed");
    // The trailing arguments give the tensor dimensions: 1 x inputSize x inputSize x 3.
    inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
    Trace.endSection();

    timer.endSplit("ready for inference");

    // Run the inference call.
    Trace.beginSection("run");
    inferenceInterface.run(outputNames, logStats);
    Trace.endSection();

    timer.endSplit("ran inference");

    // Copy the output Tensor back into the output array.
    Trace.beginSection("fetch");
    final int gridWidth = bitmap.getWidth() / blockSize;
    final int gridHeight = bitmap.getHeight() / blockSize;
    // One record of (NUM_CLASSES + 5) floats per box, NUM_BOXES_PER_BLOCK boxes per grid cell.
    final float[] output =
        new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK];
    inferenceInterface.fetch(outputNames[0], output);
    Trace.endSection();

    // Find the best detections.
    final PriorityQueue<Recognition> pq =
        new PriorityQueue<Recognition>(
            1, // Initial capacity only; the queue grows as needed.
            new Comparator<Recognition>() {
              @Override
              public int compare(final Recognition lhs, final Recognition rhs) {
                // Intentionally reversed to put high confidence at the head of the queue.
                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
              }
            });

    // Walk the grid row by row, and within each cell every predicted box.
    for (int y = 0; y < gridHeight; ++y) {
      for (int x = 0; x < gridWidth; ++x) {
        for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) {
          // Start of this box's record. Records are stored row-major by (y, x, b) and each holds
          // [x, y, w, h, confidence, class scores...], i.e. NUM_CLASSES + 5 floats.
          final int offset =
              (gridWidth * (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5))) * y
                  + (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5)) * x
                  + (NUM_CLASSES + 5) * b;

          // Box center in pixels: the sigmoid of the raw value is the offset inside the cell
          // (0..1), added to the cell index and scaled by the cell size.
          final float xPos = (x + expit(output[offset + 0])) * blockSize;
          final float yPos = (y + expit(output[offset + 1])) * blockSize;

          // Box size in pixels: exp(raw value) scales the anchor of box b, which is given in
          // grid cells and therefore multiplied by the cell size.
          final float w = (float) (Math.exp(output[offset + 2]) * ANCHORS[2 * b + 0]) * blockSize;
          final float h = (float) (Math.exp(output[offset + 3]) * ANCHORS[2 * b + 1]) * blockSize;

          // Convert center/size to (left, top, right, bottom), clamped to the bitmap bounds.
          final RectF rect =
              new RectF(
                  Math.max(0, xPos - w / 2),
                  Math.max(0, yPos - h / 2),
                  Math.min(bitmap.getWidth() - 1, xPos + w / 2),
                  Math.min(bitmap.getHeight() - 1, yPos + h / 2));
          // Confidence of the box itself, squashed to (0, 1).
          final float confidence = expit(output[offset + 4]);

          int detectedClass = -1;
          float maxClass = 0;

          // Turn this box's raw class scores into a probability distribution.
          final float[] classes = new float[NUM_CLASSES];
          for (int c = 0; c < NUM_CLASSES; ++c) {
            classes[c] = output[offset + 5 + c];
          }
          softmax(classes);

          // Pick the most probable class and remember its probability.
          for (int c = 0; c < NUM_CLASSES; ++c) {
            if (classes[c] > maxClass) {
              detectedClass = c;
              maxClass = classes[c];
            }
          }

          // Final score = class probability * box confidence. Only candidates above 0.01 are
          // logged and queued; the threshold also guarantees detectedClass was set (maxClass > 0).
          final float confidenceInClass = maxClass * confidence;
          if (confidenceInClass > 0.01) {
            LOGGER.i(
                "%s (%d) %f %s", LABELS[detectedClass], detectedClass, confidenceInClass, rect);
            // The record offset doubles as a unique id for the detection.
            pq.add(new Recognition("" + offset, LABELS[detectedClass], confidenceInClass, rect));
          }
        }
      }
    }
    timer.endSplit("decoded results");

    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
    // Fix the number of results before draining the queue: pq.size() shrinks with every poll(),
    // so evaluating it inside the loop condition would end the loop early and drop detections.
    final int numResults = Math.min(pq.size(), MAX_RESULTS);
    for (int i = 0; i < numResults; ++i) {
      recognitions.add(pq.poll());
    }
    Trace.endSection(); // "recognizeImage"

    timer.endSplit("processed results");

    return recognitions;
  }

  /** Sets whether statistics logging is requested on subsequent inference runs. */
  @Override
  public void enableStatLogging(final boolean logStats) {
    this.logStats = logStats;
  }

  /** Returns the statistics string reported by the underlying inference interface. */
  @Override
  public String getStatString() {
    return inferenceInterface.getStatString();
  }

  /** Closes the underlying inference interface. */
  @Override
  public void close() {
    inferenceInterface.close();
  }
}

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章