最後效果:
準備:
- pytorch1.4(pytorch環境至少要在1.3以上,當前最新版本1.4)
- 已經訓練好的pytorch模型
- Jetpack組件:CameraX(這個用來調用相機的)
如有需要,可以先看看我這兩篇博文:
如果pytorch環境不滿足,進行pytorch環境升級:win10+pytorch1.4+cuda10.1安裝:從顯卡驅動開始
Jetpack組件:CameraX,使用前一定要先了解:Jetpack CameraX實踐,預覽(preview)及分析(analysis)
模型轉化
# Run in the desktop PyTorch environment: trace the trained ResNet-18 and save
# it as a TorchScript file that the Android PyTorch Mobile runtime can load.
model_pth = os.path.join(MODEL_PATH, 'resnet18.pth') # weight file of the trained resnet18
mobile_pt = os.path.join(MODEL_PATH, 'resnet18.pt') # TorchScript file Android will load
model = make_model('resnet18') # build the network (project helper; architecture must match the weights)
model.load_state_dict(torch.load(model_pth)) # load the trained weights
model.eval() # switch to evaluation mode (required before tracing/inference)
# one 3-channel 224x224 image
input_tensor = torch.rand(1, 3, 224, 224) # example input fixing the traced input shape
mobile = torch.jit.trace(model, input_tensor) # trace the model into TorchScript
mobile.save(mobile_pt) # save the mobile-loadable file
注:這樣就完成了模型的轉化,得到resnet18.pt文件
Android 設置CameraX:實現預覽
添加依賴:
// CameraX core library using the camera2 implementation
def camerax_version = "1.0.0-beta01"
implementation "androidx.camera:camera-camera2:${camerax_version}"
// view/extensions artifacts are versioned on their own alpha track at this release
implementation "androidx.camera:camera-view:1.0.0-alpha08"
implementation "androidx.camera:camera-extensions:1.0.0-alpha08"
implementation "androidx.camera:camera-lifecycle:${camerax_version}"
// PyTorch Mobile runtime + torchvision tensor-conversion helpers
// (should match the PyTorch version used to trace the model)
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
申請相機權限及跳轉:
在AndroidManifest.xml中添加權限:<uses-permission android:name="android.permission.CAMERA" />
,然後跳轉CameraX頁面之前進行動態權限申請(也可以將動態權限申請放在CameraXFragment中,調用相機時再申請):
package com.example.gca.leftFragment
import android.Manifest
......
import kotlinx.android.synthetic.main.left_fragment.*
private const val REQUEST_CODE_PERMISSIONS = 10 // request code used to match the permission callback
private val REQUIRED_PERMISSIONS = arrayOf(Manifest.permission.CAMERA) // camera runtime permission

/**
 * Entry fragment: checks the CAMERA runtime permission and, once granted,
 * navigates to the CameraX screen via the Navigation component.
 */
class LeftFragment : Fragment() {
    override fun onCreateView(
        inflater: LayoutInflater, container: ViewGroup?,
        savedInstanceState: Bundle?
    ): View? {
        return inflater.inflate(R.layout.left_fragment, container, false)
    }

    override fun onActivityCreated(savedInstanceState: Bundle?) {
        super.onActivityCreated(savedInstanceState)
        // Navigate to the CameraXFragment page on button press
        buttonCameraX.setOnClickListener {
            // Check the camera permission first
            if (allPermissionsGranted()) {
                // Navigation component handles the jump
                Navigation.findNavController(it).navigate(R.id.action_leftFragment_to_cameraXFragment)
            } else {
                requestPermissions(REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS)
            }
        }
    }

    // Callback for the runtime permission request issued above
    override fun onRequestPermissionsResult(
        requestCode: Int, permissions: Array<String>, grantResults: IntArray
    ) {
        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                // Permission granted: navigate to the camera screen
                Navigation.findNavController(requireView()).navigate(R.id.action_leftFragment_to_cameraXFragment)
            } else {
                Toast.makeText(
                    requireContext(),
                    "Permissions not granted by the user.",
                    Toast.LENGTH_SHORT
                ).show()
            }
        }
    }

    // True when every permission in REQUIRED_PERMISSIONS has been granted
    private fun allPermissionsGranted() = REQUIRED_PERMISSIONS.all {
        ContextCompat.checkSelfPermission(
            requireContext(), it
        ) == PackageManager.PERMISSION_GRANTED
    }
}
新建一個fragment和佈局文件(用來放置相機的),佈局如下(fragment_camera_x.xml):
<?xml version="1.0" encoding="utf-8"?>
<!-- Layout for CameraXFragment: a camera PreviewView filling most of the screen,
     with three TextViews below it showing the top-1 class label (textView2),
     the class index (textView3), and the confidence score (textView4). -->
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".leftFragment.cameraXFragment.CameraXFragment">

    <!-- Live camera preview surface -->
    <androidx.camera.view.PreviewView
        android:id="@+id/previewView"
        android:layout_width="wrap_content"
        android:layout_height="0dp"
        android:layout_marginBottom="16dp"
        app:layout_constraintBottom_toTopOf="@+id/textView2"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.0"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <!-- Predicted class label -->
    <TextView
        android:id="@+id/textView2"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginBottom="32dp"
        android:text="TextView"
        android:textSize="30sp"
        app:layout_constraintBottom_toTopOf="@+id/textView3"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent" />

    <!-- Predicted class index -->
    <TextView
        android:id="@+id/textView3"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginBottom="32dp"
        android:text="TextView"
        android:textSize="30sp"
        app:layout_constraintBottom_toTopOf="@+id/textView4"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent" />

    <!-- Confidence score of the top prediction -->
    <TextView
        android:id="@+id/textView4"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginBottom="32dp"
        android:text="TextView"
        android:textSize="30sp"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent" />
</androidx.constraintlayout.widget.ConstraintLayout>
fragment設置(CameraXFragment.kt):
package com.example.gca.leftFragment.cameraXFragment
import android.os.Bundle
......
import java.util.concurrent.Executors
/**
 * Preview-only version of the CameraX screen: binds the back camera's
 * Preview use case to this fragment's lifecycle and shows it in a PreviewView.
 */
class CameraXFragment : Fragment(), CameraXConfig.Provider {
    override fun getCameraXConfig(): CameraXConfig {
        return Camera2Config.defaultConfig()
    }

    private lateinit var cameraProviderFuture: ListenableFuture<ProcessCameraProvider> // future for the process camera provider
    private lateinit var imagePreview: Preview // preview use case
    private lateinit var cameraPreviewView: PreviewView // view that displays the camera feed

    override fun onCreateView(
        inflater: LayoutInflater, container: ViewGroup?,
        savedInstanceState: Bundle?
    ): View? {
        // Inflate the layout for this fragment
        return inflater.inflate(R.layout.fragment_camera_x, container, false)
    }

    override fun onActivityCreated(savedInstanceState: Bundle?) {
        super.onActivityCreated(savedInstanceState)
        cameraProviderFuture = ProcessCameraProvider.getInstance(requireContext()) // obtain the camera provider
        cameraPreviewView = previewView // synthetic view binding from the layout
        // Start the camera only after the view has been laid out
        cameraPreviewView.post { startCamera() }
    }

    private fun startCamera() {
        // Configure the preview use case
        imagePreview = Preview.Builder().apply {
            setTargetAspectRatio(AspectRatio.RATIO_16_9)
            setTargetRotation(previewView.display.rotation)
        }.build()
        imagePreview.setSurfaceProvider(previewView.previewSurfaceProvider)
        // Select the back camera and bind the use case to this fragment's lifecycle
        val cameraSelector = CameraSelector.Builder().requireLensFacing(CameraSelector.LENS_FACING_BACK).build()
        cameraProviderFuture.addListener(Runnable {
            val cameraProvider = cameraProviderFuture.get()
            cameraProvider.bindToLifecycle(this, cameraSelector, imagePreview)
        }, ContextCompat.getMainExecutor(requireContext()))
    }
}
注:到這一步,可以運行項目,已經可以調用相機進行預覽了;如果不行,應該是哪裏漏掉了什麼,可參考《Jetpack CameraX實踐,預覽(preview)及分析(analysis)》逐步排查
pytorch模型部署
添加資源,將我們轉化的模型resnet18.pt複製到assets文件夾下(如果你沒有assets文件夾,參考:https://blog.csdn.net/y_dd6011)
添加兩個常量:
const val MODEL_NAME = "resnet18.pt" // TorchScript model file name under assets/
val IMAGE_CLASSIFICATION = arrayOf( // class labels the network can recognize; index matches the model's output index
    "tench, Tinca tinca",
    ......
    "goldfish, Carassius auratus",
)
新建一個kotlin類(Unit.kt):(用來獲取神經網絡的絕對地址)
package com.example.gca.unit
import android.content.Context
import android.util.Log
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
// NOTE(review): this object shadows kotlin.Unit; renaming it (e.g. to AssetUtils)
// would avoid confusion, but the name is kept for caller compatibility.
object Unit {
    /**
     * Copies the named asset into the app's private files directory (overwriting
     * any previous copy) and returns its absolute path, so native code such as
     * PyTorch Mobile's Module.load can open it as a regular file.
     *
     * @param context   used to reach the AssetManager and filesDir
     * @param assetName file name of the asset to materialize
     * @return absolute path of the copied file, or null on I/O failure
     */
    fun assetFilePath(context: Context, assetName: String): String? {
        val file = File(context.filesDir, assetName)
        try {
            // `use` closes both streams even on failure; the original also called
            // os.flush()/os.close() by hand inside `use`, which double-closed the stream.
            context.assets.open(assetName).use { input ->
                FileOutputStream(file).use { output ->
                    // stdlib copy replaces the hand-rolled 4 KiB buffer loop
                    input.copyTo(output)
                }
            }
            return file.absolutePath
        } catch (e: IOException) {
            // pass the throwable so the stack trace is not lost
            Log.e("pytorch", "Error process asset $assetName to file path", e)
        }
        return null
    }
}
再新建一個kotlin類(ImageClassificationResult.kt):(圖像分析之後的回調結果,這裏封裝成類)
package com.example.gca.unit
import com.example.gca.IMAGE_CLASSIFICATION
/**
 * Immutable top-1 classification result: the winning class index and its score.
 */
class ImageClassificationResult(private val index: Int, private val value: Float) {
    /** Human-readable label for the predicted class. */
    fun getImageClassification(): String {
        return IMAGE_CLASSIFICATION[index]
    }

    /** Index of the predicted class in the model's output vector. */
    fun getGarbageIndex(): Int {
        return index
    }

    /** Raw score of the predicted class. */
    fun getGarbageValue(): Float {
        return value
    }
}
最後一步,給相機添加圖像分析器(CameraXFragment.kt完整代碼如下):
package com.example.gca.leftFragment.cameraXFragment
import android.os.Bundle
import android.util.Log
import android.util.Size
import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import androidx.camera.camera2.Camera2Config
import androidx.camera.core.*
import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.camera.view.PreviewView
import androidx.core.content.ContextCompat
import androidx.fragment.app.Fragment
import com.example.gca.MODEL_NAME
import com.example.gca.R
import com.example.gca.unit.ImageClassificationResult
import com.example.gca.unit.Unit.assetFilePath
import com.google.common.util.concurrent.ListenableFuture
import kotlinx.android.synthetic.main.fragment_camera_x.*
import kotlinx.coroutines.MainScope
import kotlinx.coroutines.cancel
import kotlinx.coroutines.launch
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.Tensor
import org.pytorch.torchvision.TensorImageUtils
import java.nio.ByteBuffer
import java.util.concurrent.Executors
typealias ResultListener = (result: ImageClassificationResult) -> Unit // callback type for analyzer results (typealias names the function type)

/**
 * CameraX screen with inference: shows a live preview and feeds frames through
 * a TorchScript ResNet-18, pushing the top-1 result to the three TextViews.
 */
class CameraXFragment : Fragment(), CameraXConfig.Provider {
    override fun getCameraXConfig(): CameraXConfig {
        return Camera2Config.defaultConfig()
    }

    private lateinit var cameraProviderFuture: ListenableFuture<ProcessCameraProvider> // future for the process camera provider
    private lateinit var imagePreview: Preview // preview use case
    private lateinit var imageAnalysis: ImageAnalysis // analysis use case
    private val executor = Executors.newSingleThreadExecutor() // background thread running the analyzer
    // One reusable scope for UI updates. The original built a fresh MainScope()
    // on every analyzed frame, which could never be cancelled.
    private val uiScope = MainScope()
    private lateinit var cameraPreviewView: PreviewView // view that displays the camera feed
    private lateinit var module: Module // TorchScript model

    override fun onCreateView(
        inflater: LayoutInflater, container: ViewGroup?,
        savedInstanceState: Bundle?
    ): View? {
        // Inflate the layout for this fragment
        return inflater.inflate(R.layout.fragment_camera_x, container, false)
    }

    override fun onActivityCreated(savedInstanceState: Bundle?) {
        super.onActivityCreated(savedInstanceState)
        cameraProviderFuture = ProcessCameraProvider.getInstance(requireContext()) // obtain the camera provider
        cameraPreviewView = previewView // synthetic view binding from the layout
        // Load the image classification model from assets
        try {
            val modulePath = assetFilePath(requireContext(), MODEL_NAME)
            module = Module.load(modulePath)
        } catch (e: Exception) {
            Log.e(CameraXFragment::class.java.simpleName, e.toString())
        }
        // Start the camera only after the view has been laid out
        cameraPreviewView.post { startCamera() }
    }

    override fun onDestroy() {
        super.onDestroy()
        uiScope.cancel() // drop any pending UI updates
        executor.shutdown() // release the analysis thread
    }

    private fun startCamera() {
        // Configure the preview use case
        imagePreview = Preview.Builder().apply {
            setTargetAspectRatio(AspectRatio.RATIO_16_9)
            setTargetRotation(previewView.display.rotation)
        }.build()
        imagePreview.setSurfaceProvider(previewView.previewSurfaceProvider)
        // Configure the analysis use case.
        imageAnalysis = ImageAnalysis.Builder().apply {
            // FIX: STRATEGY_KEEP_ONLY_LATEST is a backpressure strategy, not a queue
            // depth; the original passed it to setImageQueueDepth(), which expects a
            // buffer count.
            setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
            setTargetResolution(Size(224, 224))
        }.build()
        // Only attach the analyzer when the model actually loaded; `module` is
        // lateinit and Module.load above may have thrown.
        if (::module.isInitialized) {
            imageAnalysis.setAnalyzer(executor, ImageClassificationAnalyzer(module) { result ->
                uiScope.launch {
                    textView2.text = result.getImageClassification()
                    textView3.text = result.getGarbageIndex().toString()
                    textView4.text = result.getGarbageValue().toString()
                }
                Log.v(CameraXFragment::class.java.simpleName, result.toString())
            })
        }
        // Select the back camera and bind both use cases to this fragment's lifecycle
        val cameraSelector = CameraSelector.Builder().requireLensFacing(CameraSelector.LENS_FACING_BACK).build()
        cameraProviderFuture.addListener(Runnable {
            val cameraProvider = cameraProviderFuture.get()
            cameraProvider.bindToLifecycle(this, cameraSelector, imagePreview, imageAnalysis)
        }, ContextCompat.getMainExecutor(requireContext()))
    }

    // Analyzer: converts each YUV camera frame to a 1x3x224x224 tensor and runs the classifier
    private class ImageClassificationAnalyzer(module: Module, listener: ResultListener? = null) : ImageAnalysis.Analyzer {
        private val mModule = module
        private val listeners = ArrayList<ResultListener>().apply { listener?.let { add(it) } }

        override fun analyze(imageProxy: ImageProxy) {
            if (listeners.isEmpty()) {
                imageProxy.close()
                return
            }
            // (Removed the original's unused ByteBuffer -> ByteArray copy of plane 0.)
            // Build the input tensor directly from the YUV frame with center crop + normalization
            val inputTensorBuffer = Tensor.allocateFloatBuffer(3 * 224 * 224)
            val inputTensor = Tensor.fromBlob(inputTensorBuffer, longArrayOf(1, 3, 224, 224))
            TensorImageUtils.imageYUV420CenterCropToFloatBuffer(
                // FIX: honour the frame's reported rotation instead of hard-coding 0,
                // so the crop is upright regardless of device orientation
                imageProxy.image, imageProxy.imageInfo.rotationDegrees, 224, 224,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
                TensorImageUtils.TORCHVISION_NORM_STD_RGB,
                inputTensorBuffer, 0)
            val outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor() // run inference
            val scores = outputTensor.dataAsFloatArray
            // Pick the class with the highest score (top-1)
            var topScore = 0.0f
            var topIndex = 0
            for (index in scores.indices) {
                if (topScore < scores[index]) {
                    topScore = scores[index]
                    topIndex = index
                }
            }
            // Call all listeners with the new value
            listeners.forEach { it(ImageClassificationResult(topIndex, topScore)) }
            imageProxy.close()
        }
    }
}
注:到此整個pytorch模型部署就完成了,整個流程和核心代碼都在這了,其餘的可以自己擴展