0


如何将pytorch模型部署到安卓

如何将pytorch模型部署到安卓上

这篇文章演示如何将训练好的pytorch模型部署到安卓设备上。我也是刚开始学安卓,代码写的简单。

环境:

pytorch版本:1.10.0

模型转化

pytorch_android支持的模型是.pt模型,我们训练出来的模型是.pth。所以需要转化才可以用。先看官网上给的转化方式:

  1. import torch
  2. import torchvision
  3. from torch.utils.mobile_optimizer import optimize_for_mobile
  4. model = torchvision.models.mobilenet_v3_small(pretrained=True)
  5. model.eval()
  6. example = torch.rand(1,3,224,224)
  7. traced_script_module = torch.jit.trace(model, example)
  8. optimized_traced_model = optimize_for_mobile(traced_script_module)
  9. optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.ptl")

这个模型在安卓对应的包:

  1. repositories {
  2. jcenter()}
  3. dependencies {
  4. implementation 'org.pytorch:pytorch_android_lite:1.9.0'
  5. implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'}

注:pytorch_android_lite版本和转化模型用的版本要一致,不一致就会报各种错误。

目前用这种方法有点问题,我采用的另一种方法。

转化代码如下:

  1. import torch
  2. import torch.utils.data.distributed
  3. # pytorch环境中
  4. model_pth = 'model_31_0.96.pth' #模型的参数文件
  5. mobile_pt ='model.pt' # 将模型保存为Android可以调用的文件
  6. model = torch.load(model_pth)
  7. model.eval() # 模型设为评估模式
  8. device = torch.device('cpu')
  9. model.to(device)
  10. # 1张3通道224*224的图片
  11. input_tensor = torch.rand(1, 3, 224, 224) # 设定输入数据格式
  12. mobile = torch.jit.trace(model, input_tensor) # 模型转化
  13. mobile.save(mobile_pt) # 保存文件

对应的包:

  1. //pytorch
  2. implementation 'org.pytorch:pytorch_android:1.10.0'
  3. implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'

定义模型文件和转化后的文件路径。

load模型。这里要注意,如果保存模型

  1. torch.save(model,'models.pth')

加载模型则是

  1. model=torch.load('models.pth')

如果保存模型是

  1. torch.save(model.state_dict(),"models.pth")

加载模型则是

  1. model.load_state_dict(torch.load('models.pth'))

定义输入数据格式。

模型转化,然后再保存模型。

安卓部署

新建项目

新建安卓项目,选择Empy Activity,然后选择Next

image-20220210142047786

然后,填写项目信息,选择安卓版本,我用的4.4,点击完成

image-20220210142213719

导入包

导入pytorch_android的包

  1. //pytorch
  2. implementation 'org.pytorch:pytorch_android:1.10.0'
  3. implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'

image-20220210142327206

如果有参数报错请参照我的完整的配置,代码如下:

  1. plugins {
  2. id 'com.android.application'}
  3. android {
  4. compileSdk 32
  5. defaultConfig {
  6. applicationId "com.example.myapplication"
  7. minSdk 21
  8. targetSdk 32
  9. versionCode 1
  10. versionName "1.0"
  11. testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"}
  12. buildTypes {
  13. release {
  14. minifyEnabled false
  15. proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'),'proguard-rules.pro'}}
  16. compileOptions {
  17. sourceCompatibility JavaVersion.VERSION_1_8
  18. targetCompatibility JavaVersion.VERSION_1_8
  19. }}
  20. dependencies {
  21. implementation 'androidx.appcompat:appcompat:1.3.0'
  22. implementation 'com.google.android.material:material:1.4.0'
  23. implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
  24. testImplementation 'junit:junit:4.13.2'
  25. androidTestImplementation 'androidx.test.ext:junit:1.1.3'
  26. androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'//pytorch
  27. implementation 'org.pytorch:pytorch_android:1.10.0'
  28. implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'}

页面文件

页面的配置如下:

  1. <?xml version="1.0" encoding="utf-8"?>
  2. <FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
  3. xmlns:tools="http://schemas.android.com/tools"
  4. android:layout_width="match_parent"
  5. android:layout_height="match_parent"
  6. tools:context=".MainActivity">
  7. <ImageView
  8. android:id="@+id/image"
  9. android:layout_width="match_parent"
  10. android:layout_height="match_parent"
  11. android:scaleType="fitCenter" />
  12. <TextView
  13. android:id="@+id/text"
  14. android:layout_width="match_parent"
  15. android:layout_height="wrap_content"
  16. android:layout_gravity="top"
  17. android:textSize="24sp"
  18. android:background="#80000000"
  19. android:textColor="@android:color/holo_red_light" />
  20. </FrameLayout>

这个页面只有两个空间,一个展示图片,一个显示文字。

image-20220210142827091

模型推理

新增assets文件夹,然后将转化的模型和待测试的图片放进去。

image-20220210143351535

新增ImageNetClasses类,这个类存放类别名字。

image-20220210143105326

代码如下:

  1. package com.example.myapplication;
  2. public classImageNetClasses{
  3. public static String[] IMAGENET_CLASSES = new String[]{"Black-grass","Charlock","Cleavers","Common Chickweed","Common wheat","Fat Hen","Loose Silky-bent","Maize","Scentless Mayweed","Shepherds Purse","Small-flowered Cranesbill","Sugar beet",};}

在MainActivity类中,增加模型推理的逻辑。完成代码如下:

  1. package com.example.myapplication;
  2. import android.content.Context;
  3. import android.graphics.Bitmap;
  4. import android.graphics.BitmapFactory;
  5. import android.os.Bundle;
  6. import android.util.Log;
  7. import android.widget.ImageView;
  8. import android.widget.TextView;
  9. import org.pytorch.IValue;
  10. import org.pytorch.Module;
  11. import org.pytorch.Tensor;
  12. import org.pytorch.torchvision.TensorImageUtils;
  13. import org.pytorch.MemoryFormat;
  14. import java.io.File;
  15. import java.io.FileOutputStream;
  16. import java.io.IOException;
  17. import java.io.InputStream;
  18. import java.io.OutputStream;
  19. import androidx.appcompat.app.AppCompatActivity;
  20. public class MainActivity extends AppCompatActivity {
  21. @Override
  22. protected void onCreate(Bundle savedInstanceState) {
  23. super.onCreate(savedInstanceState);
  24. setContentView(R.layout.activity_main);
  25. Bitmap bitmap = null;
  26. Module module = null;
  27. try {
  28. // creating bitmap from packaged into app android asset 'image.jpg',
  29. // app/src/main/assets/image.jpg
  30. bitmap = BitmapFactory.decodeStream(getAssets().open("1.png"));
  31. // loading serialized torchscript module from packaged into app android asset model.pt,
  32. // app/src/model/assets/model.pt
  33. module = Module.load(assetFilePath(this, "models.pt"));
  34. } catch (IOException e) {
  35. Log.e("PytorchHelloWorld", "Error reading assets", e);
  36. finish();
  37. }
  38. // showing image on UI
  39. ImageView imageView = findViewById(R.id.image);
  40. imageView.setImageBitmap(bitmap);
  41. // preparing input tensor
  42. final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
  43. TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);
  44. // running the model
  45. final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
  46. // getting tensor content as java array of floats
  47. final float[] scores = outputTensor.getDataAsFloatArray();
  48. // searching for the index with maximum score
  49. float maxScore = -Float.MAX_VALUE;
  50. int maxScoreIdx = -1;
  51. for (int i = 0; i < scores.length; i++) {
  52. if (scores[i] > maxScore) {
  53. maxScore = scores[i];
  54. maxScoreIdx = i;
  55. }
  56. }
  57. System.out.println(maxScoreIdx);
  58. String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
  59. // showing className on UI
  60. TextView textView = findViewById(R.id.text);
  61. textView.setText(className);
  62. }
  63. /**
  64. * Copies specified asset to the file in /files app directory and returns this file absolute path.
  65. *
  66. * @return absolute file path
  67. */
  68. public static String assetFilePath(Context context, String assetName) throws IOException {
  69. File file = new File(context.getFilesDir(), assetName);
  70. if (file.exists() && file.length() > 0) {
  71. return file.getAbsolutePath();
  72. }
  73. try (InputStream is = context.getAssets().open(assetName)) {
  74. try (OutputStream os = new FileOutputStream(file)) {
  75. byte[] buffer = new byte[4 * 1024];
  76. int read;
  77. while ((read = is.read(buffer)) != -1) {
  78. os.write(buffer, 0, read);
  79. }
  80. os.flush();
  81. }
  82. return file.getAbsolutePath();
  83. }
  84. }
  85. }

然后运行。

image-20220210143529635


本文转载自: https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/122860445
版权归原作者 AI浩 所有, 如有侵权,请联系我们删除。

“如何将pytorch模型部署到安卓”的评论:

还没有评论