如何在 Android 上部署 `lite-model_movenet_singlepose_lightning_tflite_int8_4.tflite` 模型并用一张图片进行测试(详细代码)
时间: 2024-05-13 09:15:35 浏览: 177
基于Android运行人脸表情识别的tflite模型源码+项目说明.zip
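在 Android 上部署 MoveNet SinglePose Lightning 模型需要以下步骤。先明确模型的输入输出格式:

- int8 量化版的输入是形状为 `[1, 192, 192, 3]` 的 UINT8 RGB 图像,像素值为 0–255,不做归一化。192 是 Lightning 的输入尺寸,Thunder 为 256。
- 输出是形状为 `[1, 1, 17, 3]` 的 float32 张量,即 17 个关键点。每个关键点为 `(y, x, score)`,其中 y、x 已归一化到 [0, 1]。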
1. 从 TensorFlow Hub 下载模型文件 `lite-model_movenet_singlepose_lightning_tflite_int8_4.tflite`。该文件已经是 TensorFlow Lite 格式,无需再转换。
2. 在 Android Studio 中创建一个新项目,并在 `app/build.gradle` 的 `dependencies` 块中添加以下依赖项:
```
implementation 'org.tensorflow:tensorflow-lite:2.5.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.5.0'
// Required because the complete example imports org.tensorflow.lite.support.tensorbuffer.* classes.
implementation 'org.tensorflow:tensorflow-lite-support:0.2.0'
```
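3. 将模型文件放入 Android 项目的 `app/src/main/assets` 目录下。

模型必须以不压缩的方式打包进 APK,否则 `openFd` 会抛出 `FileNotFoundException: This file can not be opened as a file descriptor; it is probably compressed`。较新的 Android Gradle 插件默认不压缩 `.tflite` 文件;若遇到该异常,请在 `android { }` 中添加 `aaptOptions { noCompress "tflite" }`。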
4. 加载模型文件并初始化解释器:
```java
// 加载模型文件
AssetFileDescriptor assetFileDescriptor = getAssets().openFd("lite-model_movenet_singlepose_lightning_tflite_int8_4.tflite");
FileInputStream inputStream = new FileInputStream(assetFileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = assetFileDescriptor.getStartOffset();
long declaredLength = assetFileDescriptor.getDeclaredLength();
MappedByteBuffer modelBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
// 初始化解释器
Interpreter.Options options = new Interpreter.Options();
GpuDelegate delegate = new GpuDelegate();
options.addDelegate(delegate);
Interpreter interpreter = new Interpreter(modelBuffer, options);
```
5. 加载图片并预处理,将其转换为模型所需的输入格式:
```java
// Load the test image. decodeStream returns null (rather than throwing) for undecodable data.
Bitmap bitmap;
try (InputStream imageStream = getAssets().open("test_image.png")) {
    bitmap = BitmapFactory.decodeStream(imageStream);
}
if (bitmap == null) {
    throw new IOException("Failed to decode test_image.png");
}
// Read the expected input size and type from the model itself instead of hard-coding them.
// For MoveNet SinglePose Lightning this is [1, 192, 192, 3] UINT8 (Thunder uses 256x256).
int[] inputShape = interpreter.getInputTensor(0).shape();
DataType inputDataType = interpreter.getInputTensor(0).dataType();
if (inputDataType != DataType.UINT8) {
    throw new IllegalStateException("Expected a UINT8 input tensor but the model uses " + inputDataType);
}
int inputHeight = inputShape[1];
int inputWidth = inputShape[2];
// Resize the image to the model input size.
Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, inputWidth, inputHeight, true);
// Pack the pixels as interleaved R, G, B bytes with raw values in [0, 255];
// this model does not take normalized floats.
int[] pixels = new int[inputWidth * inputHeight];
resizedBitmap.getPixels(pixels, 0, inputWidth, 0, 0, inputWidth, inputHeight);
ByteBuffer inputTensorBuffer = ByteBuffer.allocateDirect(pixels.length * 3).order(ByteOrder.nativeOrder());
for (int pixel : pixels) {
    inputTensorBuffer.put((byte) ((pixel >> 16) & 0xFF)); // R
    inputTensorBuffer.put((byte) ((pixel >> 8) & 0xFF));  // G
    inputTensorBuffer.put((byte) (pixel & 0xFF));         // B
}
// Reset the position so the interpreter reads the buffer from the start.
inputTensorBuffer.rewind();
```
6. 运行模型并获取输出:
```java
// Run the model. Output 0 of MoveNet SinglePose is a float32 tensor of shape [1, 1, 17, 3]:
// 17 keypoints, each stored as (y, x, score) with y and x normalized to [0, 1].
// The Java array is sized from the model's reported output shape.
int[] outputShape = interpreter.getOutputTensor(0).shape();
float[][][][] outputTensor = new float[outputShape[0]][outputShape[1]][outputShape[2]][outputShape[3]];
interpreter.run(inputTensorBuffer, outputTensor);
// Extract the keypoints: keypoints[i] = {y, x, score}.
float[][] keypoints = outputTensor[0][0];
// Release native resources: close the interpreter first, then the GPU delegate (if any).
interpreter.close();
if (delegate != null) {
    delegate.close();
}
```
7. 使用获取的关键点坐标绘制姿势:
```java
// Draw the pose. Bitmaps decoded by BitmapFactory are immutable by default and cannot
// back a Canvas, so draw on a mutable ARGB_8888 copy instead.
Bitmap outputBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true);
Canvas canvas = new Canvas(outputBitmap);
Paint paint = new Paint();
paint.setColor(Color.RED);
paint.setStrokeWidth(10);
paint.setStrokeCap(Paint.Cap.ROUND);
// Skip keypoints the model is not confident about.
float minScore = 0.3f;
for (float[] keypoint : keypoints) {
    if (keypoint[2] < minScore) {
        continue;
    }
    // keypoint = {y, x, score}; y and x are normalized, so scale by the image size.
    float x = keypoint[1] * outputBitmap.getWidth();
    float y = keypoint[0] * outputBitmap.getHeight();
    canvas.drawPoint(x, y, paint);
}
imageView.setImageBitmap(outputBitmap);
```
完整代码如下:
```java
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import androidx.appcompat.app.AppCompatActivity;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.HashMap;
import java.util.Map;
public class MainActivity extends AppCompatActivity {
private static final String TAG = "MoveNetDemo";
private ImageView imageView;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
imageView = findViewById(R.id.imageView);
try {
// 加载模型文件
AssetFileDescriptor assetFileDescriptor = getAssets().openFd("lite-model_movenet_singlepose_lightning_tflite_int8_4.tflite");
FileInputStream inputStream = new FileInputStream(assetFileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = assetFileDescriptor.getStartOffset();
long declaredLength = assetFileDescriptor.getDeclaredLength();
MappedByteBuffer modelBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
// 初始化解释器
Interpreter.Options options = new Interpreter.Options();
GpuDelegate delegate = new GpuDelegate();
options.addDelegate(delegate);
Interpreter interpreter = new Interpreter(modelBuffer, options);
// 加载图片
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("test_image.png"));
// 缩放图片
int inputSize = 256;
Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, inputSize, inputSize, true);
// 将图片转换为 Tensor
float[][][][] inputTensor = new float[1][inputSize][inputSize][3];
for (int i = 0; i < inputSize; i++) {
for (int j = 0; j < inputSize; j++) {
int pixel = resizedBitmap.getPixel(j, i);
inputTensor[0][i][j][0] = (float) ((pixel >> 16) & 0xFF) / 255.0f;
inputTensor[0][i][j][1] = (float) ((pixel >> 8) & 0xFF) / 255.0f;
inputTensor[0][i][j][2] = (float) (pixel & 0xFF) / 255.0f;
}
}
// 创建输入 Tensor
int[] inputShape = interpreter.getInputTensor(0).shape();
DataType inputDataType = interpreter.getInputTensor(0).dataType();
Tensor inputTensorBuffer = Tensor.fromBuffer(inputTensor, inputDataType, inputShape);
// 运行模型
Map<Integer, Object> outputMap = new HashMap<>();
int[] outputShape = interpreter.getOutputTensor(0).shape();
DataType outputDataType = interpreter.getOutputTensor(0).dataType();
Tensor outputTensorBuffer = Tensor.allocate(outputDataType, outputShape);
outputMap.put(0, outputTensorBuffer);
interpreter.runForMultipleInputsOutputs(new Object[]{inputTensorBuffer}, outputMap);
// 获取关键点坐标
float[][][] outputTensor = new float[1][17][3];
outputTensorBuffer.copyTo(outputTensor);
float[][] keypoints = outputTensor[0];
// 绘制姿势
Canvas canvas = new Canvas(bitmap);
Paint paint = new Paint();
paint.setColor(Color.RED);
paint.setStrokeWidth(10);
for (int i = 0; i < keypoints.length; i++) {
float x = keypoints[i][1] * bitmap.getWidth();
float y = keypoints[i][0] * bitmap.getHeight();
canvas.drawPoint(x, y, paint);
}
imageView.setImageBitmap(bitmap);
} catch (Exception e) {
Log.e(TAG, "Error", e);
}
}
}
```
在 `assets` 目录下添加一张名为 `test_image.png` 的测试图片,并确保布局文件 `activity_main.xml` 中有一个 id 为 `imageView` 的 `ImageView`。完成后运行应用程序,即可看到图片上绘制出的人体关键点。
阅读全文