Deploying a PyTorch Model to Android

Final result:
[demo screenshot]
Prerequisites:

  • PyTorch 1.4 (the PyTorch environment must be 1.3 or newer; 1.4 is the latest version at the time of writing)
  • An already-trained PyTorch model
  • The Jetpack component CameraX (used to access the camera)

If needed, you may want to read these two posts of mine first:
If your PyTorch environment does not meet the requirement, upgrade it: Win10 + PyTorch 1.4 + CUDA 10.1 installation, starting from the GPU driver
The Jetpack component CameraX, which you should understand before use: Jetpack CameraX in practice: preview and analysis

Model conversion

# In the PyTorch environment
import os
import torch

model_pth = os.path.join(MODEL_PATH, 'resnet18.pth') # parameter file of the trained resnet18 model
mobile_pt = os.path.join(MODEL_PATH, 'resnet18.pt')  # resnet18 saved in a format Android can load

model = make_model('resnet18') # build the network
model.load_state_dict(torch.load(model_pth)) # load the parameters
model.eval() # set the model to evaluation mode

# one 3-channel 224*224 image
input_tensor = torch.rand(1, 3, 224, 224) # define the input format

mobile = torch.jit.trace(model, input_tensor) # convert (trace) the model
mobile.save(mobile_pt) # save the file
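
# Optional sanity check (my own addition, not in the original post):
# the traced module should produce the same output as the original model
with torch.no_grad():
    assert torch.allclose(model(input_tensor), mobile(input_tensor), atol=1e-5)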

Note: this completes the model conversion, producing the resnet18.pt file.

Setting up CameraX on Android: implementing the preview

Add the dependencies:

// CameraX core library using the camera2 implementation
def camerax_version = "1.0.0-beta01"
implementation "androidx.camera:camera-camera2:${camerax_version}"
implementation "androidx.camera:camera-view:1.0.0-alpha08"
implementation "androidx.camera:camera-extensions:1.0.0-alpha08"
implementation "androidx.camera:camera-lifecycle:${camerax_version}"

// PyTorch
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'

Request the camera permission and navigate:
Add the permission to AndroidManifest.xml: <uses-permission android:name="android.permission.CAMERA" />, then request it at runtime before navigating to the CameraX page (you could also put the runtime permission request in CameraXFragment and ask for it when the camera is opened):

package com.example.gca.leftFragment

import android.Manifest
......
import kotlinx.android.synthetic.main.left_fragment.*

private const val REQUEST_CODE_PERMISSIONS = 10 // permission request code
private val REQUIRED_PERMISSIONS = arrayOf(Manifest.permission.CAMERA) // camera permission

class LeftFragment : Fragment() {

    override fun onCreateView(
        inflater: LayoutInflater, container: ViewGroup?,
        savedInstanceState: Bundle?
    ): View? {
        return inflater.inflate(R.layout.left_fragment, container, false)
    }

    override fun onActivityCreated(savedInstanceState: Bundle?) {
        super.onActivityCreated(savedInstanceState)

        // go to the CameraXFragment page
        buttonCameraX.setOnClickListener {
            // check the camera permission
            if (allPermissionsGranted()) {
                // navigation here uses the Navigation component
                Navigation.findNavController(it).navigate(R.id.action_leftFragment_to_cameraXFragment)
            } else {
                requestPermissions(REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS)
            }
        }
    }

    // callback with the permission request result
    override fun onRequestPermissionsResult(
        requestCode: Int, permissions: Array<String>, grantResults: IntArray
    ) {
        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                // permission granted, navigate
                Navigation.findNavController(requireView()).navigate(R.id.action_leftFragment_to_cameraXFragment)
            } else {
                Toast.makeText(
                    requireContext(),
                    "Permissions not granted by the user.",
                    Toast.LENGTH_SHORT
                ).show()
            }
        }
    }

    // check whether all required permissions are granted
    private fun allPermissionsGranted() = REQUIRED_PERMISSIONS.all {
        ContextCompat.checkSelfPermission(
            requireContext(), it
        ) == PackageManager.PERMISSION_GRANTED
    }
}

Create a new fragment and layout file to host the camera. The layout is as follows (fragment_camera_x.xml):

<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".leftFragment.cameraXFragment.CameraXFragment">

    <androidx.camera.view.PreviewView
        android:id="@+id/previewView"
        android:layout_width="wrap_content"
        android:layout_height="0dp"
        android:layout_marginBottom="16dp"
        app:layout_constraintBottom_toTopOf="@+id/textView2"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.0"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <TextView
        android:id="@+id/textView2"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginBottom="32dp"
        android:text="TextView"
        android:textSize="30sp"
        app:layout_constraintBottom_toTopOf="@+id/textView3"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent" />

    <TextView
        android:id="@+id/textView3"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginBottom="32dp"
        android:text="TextView"
        android:textSize="30sp"
        app:layout_constraintBottom_toTopOf="@+id/textView4"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent" />

    <TextView
        android:id="@+id/textView4"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginBottom="32dp"
        android:text="TextView"
        android:textSize="30sp"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent" />

</androidx.constraintlayout.widget.ConstraintLayout>

Fragment setup (CameraXFragment.kt):

package com.example.gca.leftFragment.cameraXFragment

import android.os.Bundle
......
import java.util.concurrent.Executors

class CameraXFragment : Fragment(), CameraXConfig.Provider {
    override fun getCameraXConfig(): CameraXConfig {
        return Camera2Config.defaultConfig()
    }

    private lateinit var cameraProviderFuture: ListenableFuture<ProcessCameraProvider> // controls the camera
    private lateinit var imagePreview: Preview // image preview
    private lateinit var cameraPreviewView: PreviewView // view that displays the camera feed

    override fun onCreateView(
        inflater: LayoutInflater, container: ViewGroup?,
        savedInstanceState: Bundle?
    ): View? {
        // Inflate the layout for this fragment
        return inflater.inflate(R.layout.fragment_camera_x, container, false)
    }

    override fun onActivityCreated(savedInstanceState: Bundle?) {
        super.onActivityCreated(savedInstanceState)

        cameraProviderFuture = ProcessCameraProvider.getInstance(requireContext()) // obtain the camera provider
        cameraPreviewView = previewView // the view that shows the camera

        // start the camera once the view is ready
        cameraPreviewView.post { startCamera() }
    }

    private fun startCamera() {
        // preview
        imagePreview = Preview.Builder().apply {
            setTargetAspectRatio(AspectRatio.RATIO_16_9)
            setTargetRotation(previewView.display.rotation)
        }.build()
        imagePreview.setSurfaceProvider(previewView.previewSurfaceProvider)

        // bind the use cases to the lifecycle
        val cameraSelector = CameraSelector.Builder().requireLensFacing(CameraSelector.LENS_FACING_BACK).build()
        cameraProviderFuture.addListener(Runnable {
            val cameraProvider = cameraProviderFuture.get()
            cameraProvider.bindToLifecycle(this, cameraSelector, imagePreview)
        }, ContextCompat.getMainExecutor(requireContext()))
    }
}

Note: at this point you can run the project and the camera preview should already work. If it doesn't, refer to Jetpack CameraX in practice: preview and analysis; something has probably been left out.

Deploying the PyTorch model

Add the resource: copy the converted model resnet18.pt into the assets folder (if you don't have an assets folder, see: https://blog.csdn.net/y_dd6011).

Add two constants:

const val MODEL_NAME = "resnet18.pt" // the neural network model file
val IMAGE_CLASSIFICATION = arrayOf(  // the class labels your network can recognize
    "tench, Tinca tinca",
    ......
    "goldfish, Carassius auratus",
)

Create a new Kotlin file (Unit.kt), used to resolve the model's absolute file path (be careful not to confuse this Unit object with Kotlin's built-in Unit type):

package com.example.gca.unit

import android.content.Context
import android.util.Log
import java.io.File
import java.io.FileOutputStream
import java.io.IOException

object Unit {
    // Copy the asset into the app's files directory and return its absolute path
    // (Module.load() needs a real file path; an asset cannot be read directly as a file)
    fun assetFilePath(context: Context, assetName: String): String? {
        val file = File(context.filesDir, assetName)
        try {
            context.assets.open(assetName).use { `is` ->
                FileOutputStream(file).use { os ->
                    val buffer = ByteArray(4 * 1024)
                    while (true) {
                        val length = `is`.read(buffer)
                        if (length <= 0)
                            break
                        os.write(buffer, 0, length)
                    }
                    os.flush()
                }
                return file.absolutePath
            }
        } catch (e: IOException) {
            Log.e("pytorch", "Error process asset $assetName to file path")
        }
        return null
    }
}
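
One possible refinement (my own suggestion, not part of the original post): assetFilePath re-copies the asset on every call. Since the model file never changes, you could skip the copy when the file already exists from an earlier run. A minimal sketch, assuming it sits alongside assetFilePath inside Unit.kt:

fun assetFilePathCached(context: Context, assetName: String): String? {
    val file = File(context.filesDir, assetName) // same destination as assetFilePath
    if (file.exists() && file.length() > 0) {
        return file.absolutePath // already copied on an earlier run, skip the I/O
    }
    return assetFilePath(context, assetName) // first run: copy as usual
}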

Create another Kotlin class (ImageClassificationResult.kt), which wraps the result passed back after image analysis:

package com.example.gca.unit

import com.example.gca.IMAGE_CLASSIFICATION

class ImageClassificationResult(private val index: Int, private val value: Float) {

    fun getImageClassification() = IMAGE_CLASSIFICATION[index]
    fun getGarbageIndex() = index
    fun getGarbageValue() = value
}

The final step is to add an image analyzer to the camera (the complete CameraXFragment.kt follows):

package com.example.gca.leftFragment.cameraXFragment

import android.os.Bundle
import android.util.Log
import android.util.Size
import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import androidx.camera.camera2.Camera2Config
import androidx.camera.core.*
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.camera.view.PreviewView
import androidx.core.content.ContextCompat
import androidx.fragment.app.Fragment
import com.example.gca.MODEL_NAME
import com.example.gca.R
import com.example.gca.unit.ImageClassificationResult
import com.example.gca.unit.Unit.assetFilePath
import com.google.common.util.concurrent.ListenableFuture
import kotlinx.android.synthetic.main.fragment_camera_x.*
import kotlinx.coroutines.MainScope
import kotlinx.coroutines.launch
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.Tensor
import org.pytorch.torchvision.TensorImageUtils
import java.util.concurrent.Executors

typealias ResultListener = (result: ImageClassificationResult) -> Unit // result type returned by the image analyzer; typealias declares a type alias

class CameraXFragment : Fragment(), CameraXConfig.Provider {
    override fun getCameraXConfig(): CameraXConfig {
        return Camera2Config.defaultConfig()
    }

    private lateinit var cameraProviderFuture: ListenableFuture<ProcessCameraProvider> // controls the camera
    private lateinit var imagePreview: Preview // image preview
    private lateinit var imageAnalysis: ImageAnalysis // image analysis
    private val executor = Executors.newSingleThreadExecutor() // background thread for analysis
    private lateinit var cameraPreviewView: PreviewView // view that displays the camera feed
    private lateinit var module: Module // the model

    override fun onCreateView(
        inflater: LayoutInflater, container: ViewGroup?,
        savedInstanceState: Bundle?
    ): View? {
        // Inflate the layout for this fragment
        return inflater.inflate(R.layout.fragment_camera_x, container, false)
    }

    override fun onActivityCreated(savedInstanceState: Bundle?) {
        super.onActivityCreated(savedInstanceState)

        cameraProviderFuture = ProcessCameraProvider.getInstance(requireContext()) // obtain the camera provider
        cameraPreviewView = previewView // the view that shows the camera

        // load the image-recognition model
        try {
            val modulePath = assetFilePath(requireContext(), MODEL_NAME)
            module = Module.load(modulePath)
        } catch (e: Exception) {
            Log.e(CameraXFragment::class.java.simpleName, e.toString())
        }

        // start the camera once the view is ready
        cameraPreviewView.post { startCamera() }
    }

    private fun startCamera() {
        // preview
        imagePreview = Preview.Builder().apply {
            setTargetAspectRatio(AspectRatio.RATIO_16_9)
            setTargetRotation(previewView.display.rotation)
        }.build()
        imagePreview.setSurfaceProvider(previewView.previewSurfaceProvider)

        // analysis
        imageAnalysis = ImageAnalysis.Builder().apply {
            setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST) // keep only the latest frame
            setTargetResolution(Size(224, 224))
        }.build()
        imageAnalysis.setAnalyzer(executor, ImageClassificationAnalyzer(module) {
            MainScope().launch {
                textView2.text = it.getImageClassification()
                textView3.text = it.getGarbageIndex().toString()
                textView4.text = it.getGarbageValue().toString()
            }
            Log.v(CameraXFragment::class.java.simpleName, it.toString())
        })

        // bind the use cases to the lifecycle
        val cameraSelector = CameraSelector.Builder().requireLensFacing(CameraSelector.LENS_FACING_BACK).build()
        cameraProviderFuture.addListener(Runnable {
            val cameraProvider = cameraProviderFuture.get()
            cameraProvider.bindToLifecycle(this, cameraSelector, imagePreview, imageAnalysis)
        }, ContextCompat.getMainExecutor(requireContext()))
    }

    // image classification analyzer
    private class ImageClassificationAnalyzer(module: Module, listener: ResultListener? = null) : ImageAnalysis.Analyzer {

        private val mModule = module
        private val listeners = ArrayList<ResultListener>().apply { listener?.let { add(it) } }

        override fun analyze(imageProxy: ImageProxy) {
            if (listeners.isEmpty()) {
                imageProxy.close()
                return
            }

            // image recognition
            val inputTensorBuffer = Tensor.allocateFloatBuffer(3*224*224) // allocate the input buffer
            val inputTensor = Tensor.fromBlob(inputTensorBuffer, longArrayOf(1, 3, 224, 224)) // wrap the buffer as the input tensor

            TensorImageUtils.imageYUV420CenterCropToFloatBuffer( // fill the buffer: center-crop the YUV frame and normalize it
                imageProxy.image, 0, 224, 224,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
                TensorImageUtils.TORCHVISION_NORM_STD_RGB,
                inputTensorBuffer, 0)

            val outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor() // run the model on the frame
            val scores = outputTensor.dataAsFloatArray
            var topScore = 0.0f
            var topIndex = 0
            for (index in scores.indices) { // keep the class with the highest score
                if (topScore < scores[index]) {
                    topScore = scores[index]
                    topIndex = index
                }
            }

            // Call all listeners with new value
            listeners.forEach { it(ImageClassificationResult(topIndex, topScore)) }

            imageProxy.close()
        }
    }
}
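
Note that resnet18 outputs raw, unnormalized scores (logits), so the topScore shown in textView4 is not a probability. If you would rather display a probability, one option is to pass the scores through softmax before picking the maximum. A minimal sketch (this softmax helper is my own addition, not part of the original code):

import kotlin.math.exp

// Hypothetical helper (not in the original post): turn raw logits into
// softmax probabilities that sum to 1
fun softmax(logits: FloatArray): FloatArray {
    var max = Float.NEGATIVE_INFINITY
    for (v in logits) if (v > max) max = v // subtract the max for numerical stability
    val exps = FloatArray(logits.size)
    var sum = 0f
    for (i in logits.indices) {
        exps[i] = exp(logits[i] - max)
        sum += exps[i]
    }
    for (i in exps.indices) exps[i] /= sum
    return exps
}

With this you could compute val probs = softmax(scores) inside analyze() and pass probs[topIndex] to ImageClassificationResult instead of the raw score.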

Note: this completes the whole PyTorch model deployment. The full workflow and core code are all here; the rest you can extend yourself.
