写一个Android 基于tensorflow自动检测并拍照的demo
时间: 2023-07-31 16:09:23 浏览: 57
这里是一个基于TensorFlow Lite和Camera2 API的Android自动检测并拍照的Demo。
首先,在gradle文件中添加以下依赖:
```
implementation 'org.tensorflow:tensorflow-lite:2.5.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.5.0'
// Required for the TensorImage / TensorBuffer classes used in CameraPreview.recognize().
implementation 'org.tensorflow:tensorflow-lite-support:0.2.0'
```
接下来,创建一个Camera2的预览类CameraPreview,用于预览相机画面,并在其中初始化TensorFlow Lite模型。
```
/**
 * TextureView-based Camera2 preview that feeds JPEG frames into a TensorFlow
 * Lite image classifier and automatically issues a still capture when the
 * top-scoring label is "cat".
 *
 * Usage: add the view to a layout, call {@link #startBackgroundThread()} in
 * the Activity's onResume and {@link #stopBackgroundThread()} in onPause.
 * The camera is opened/closed with the surface lifecycle.
 *
 * Requires "model.tflite" and "labels.txt" in the app's assets directory and
 * the CAMERA runtime permission (requested by the hosting Activity).
 */
public class CameraPreview extends TextureView implements TextureView.SurfaceTextureListener, ImageReader.OnImageAvailableListener {

    private static final String MODEL_PATH = "model.tflite";
    private static final String LABELS_PATH = "labels.txt";
    // Model input resolution; assumes a 224x224 classifier — confirm against the model.
    private static final int INPUT_SIZE = 224;

    private CameraDevice cameraDevice;
    private CameraCaptureSession cameraCaptureSession;
    private CaptureRequest.Builder captureRequestBuilder;
    private HandlerThread backgroundThread;
    private Handler backgroundHandler;
    private ImageReader imageReader;
    private Interpreter interpreter;
    private List<String> labels;
    // Cached in openCamera(): CameraDevice has no getCameraCharacteristics()
    // method (the original called one), so it must be read via CameraManager.
    private int sensorOrientation;
    // True while an inference is running; frames arriving meanwhile are
    // dropped instead of spawning concurrent Interpreter.run() calls
    // (Interpreter is not safe for concurrent invocation).
    private volatile boolean inferenceInFlight;

    public CameraPreview(Context context) {
        super(context);
        setSurfaceTextureListener(this);
        initModel();
    }

    /** Loads the TFLite model and its label list from assets. */
    private void initModel() {
        try {
            interpreter = new Interpreter(loadModelFile(), new Interpreter.Options());
            labels = loadLabelList();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Memory-maps the model file from assets.
     * try-with-resources closes the stream/channel (the original leaked both);
     * the returned MappedByteBuffer remains valid after the channel is closed.
     */
    private MappedByteBuffer loadModelFile() throws IOException {
        AssetFileDescriptor fileDescriptor = getContext().getAssets().openFd(MODEL_PATH);
        try (FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
             FileChannel fileChannel = inputStream.getChannel()) {
            return fileChannel.map(FileChannel.MapMode.READ_ONLY,
                    fileDescriptor.getStartOffset(), fileDescriptor.getDeclaredLength());
        }
    }

    /** Reads one label per line from the labels asset. */
    private List<String> loadLabelList() throws IOException {
        List<String> labelList = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(getContext().getAssets().open(LABELS_PATH)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                labelList.add(line);
            }
        }
        return labelList;
    }

    @Override
    public void onSurfaceTextureAvailable(SurfaceTexture surfaceTexture, int width, int height) {
        openCamera();
    }

    /** Opens the first camera, caching its sensor orientation, and prepares the JPEG ImageReader. */
    private void openCamera() {
        CameraManager cameraManager = (CameraManager) getContext().getSystemService(Context.CAMERA_SERVICE);
        try {
            String cameraId = cameraManager.getCameraIdList()[0];
            CameraCharacteristics cameraCharacteristics = cameraManager.getCameraCharacteristics(cameraId);
            // Cache now — getOrientation() needs it later and CameraDevice
            // offers no way to query characteristics.
            Integer orientation = cameraCharacteristics.get(CameraCharacteristics.SENSOR_ORIENTATION);
            sensorOrientation = orientation == null ? 0 : orientation;
            Size[] outputSizes = cameraCharacteristics
                    .get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP)
                    .getOutputSizes(ImageFormat.JPEG);
            // maxImages = 2 so acquireLatestImage() can hold one image while
            // the producer writes the next; with 1 the reader can stall.
            imageReader = ImageReader.newInstance(outputSizes[0].getWidth(),
                    outputSizes[0].getHeight(), ImageFormat.JPEG, 2);
            imageReader.setOnImageAvailableListener(this, backgroundHandler);
            if (ActivityCompat.checkSelfPermission(getContext(), Manifest.permission.CAMERA)
                    != PackageManager.PERMISSION_GRANTED) {
                // The hosting Activity is responsible for requesting the permission.
                return;
            }
            cameraManager.openCamera(cameraId, new CameraDevice.StateCallback() {
                @Override
                public void onOpened(CameraDevice camera) {
                    cameraDevice = camera;
                    createCameraPreviewSession();
                }

                @Override
                public void onDisconnected(CameraDevice camera) {
                    camera.close();
                    cameraDevice = null;
                }

                @Override
                public void onError(CameraDevice camera, int error) {
                    camera.close();
                    cameraDevice = null;
                }
            }, backgroundHandler);
        } catch (CameraAccessException e) {
            e.printStackTrace();
        }
    }

    /** Creates a capture session streaming to both the preview surface and the ImageReader. */
    private void createCameraPreviewSession() {
        SurfaceTexture surfaceTexture = getSurfaceTexture();
        surfaceTexture.setDefaultBufferSize(1920, 1080);
        Surface previewSurface = new Surface(surfaceTexture);
        Surface readerSurface = imageReader.getSurface();
        try {
            captureRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
            captureRequestBuilder.addTarget(previewSurface);
            // NOTE(review): targeting the JPEG reader from the repeating
            // preview request drives detection on every frame, which is
            // expensive; a YUV reader would be cheaper — kept as designed.
            captureRequestBuilder.addTarget(readerSurface);
            cameraDevice.createCaptureSession(Arrays.asList(previewSurface, readerSurface),
                    new CameraCaptureSession.StateCallback() {
                        @Override
                        public void onConfigured(CameraCaptureSession session) {
                            cameraCaptureSession = session;
                            updatePreview();
                        }

                        @Override
                        public void onConfigureFailed(CameraCaptureSession session) {
                            // Nothing to recover; the preview simply stays blank.
                        }
                    }, backgroundHandler);
        } catch (CameraAccessException e) {
            e.printStackTrace();
        }
    }

    /** Starts the repeating preview request. */
    private void updatePreview() {
        if (cameraDevice == null) {
            return;
        }
        captureRequestBuilder.set(CaptureRequest.CONTROL_MODE, CameraMetadata.CONTROL_MODE_AUTO);
        try {
            cameraCaptureSession.setRepeatingRequest(captureRequestBuilder.build(), null, backgroundHandler);
        } catch (CameraAccessException e) {
            e.printStackTrace();
        }
    }

    @Override
    public void onSurfaceTextureSizeChanged(SurfaceTexture surfaceTexture, int width, int height) {
        // Buffer size is fixed in createCameraPreviewSession(); nothing to do.
    }

    @Override
    public boolean onSurfaceTextureDestroyed(SurfaceTexture surfaceTexture) {
        closeCamera();
        return true;
    }

    /** Releases the session, camera device and image reader (session first). */
    private void closeCamera() {
        if (cameraCaptureSession != null) {
            cameraCaptureSession.close();
            cameraCaptureSession = null;
        }
        if (cameraDevice != null) {
            cameraDevice.close();
            cameraDevice = null;
        }
        if (imageReader != null) {
            imageReader.close();
            imageReader = null;
        }
    }

    @Override
    public void onSurfaceTextureUpdated(SurfaceTexture surfaceTexture) {
        // Frames are consumed via the ImageReader, not the texture callback.
    }

    /**
     * Called on the background handler for every JPEG frame. Decodes the
     * frame and, if no inference is already running, classifies it on a
     * worker thread; a "cat" result triggers a still capture.
     */
    @Override
    public void onImageAvailable(ImageReader reader) {
        Image image = reader.acquireLatestImage();
        if (image == null) {
            // Queue drained by a previous callback — nothing to process.
            return;
        }
        final Bitmap bitmap;
        try {
            bitmap = getBitmap(image);
        } finally {
            // Always return the buffer to the reader, even if decoding throws.
            image.close();
        }
        if (bitmap == null || interpreter == null || inferenceInFlight) {
            return; // decode failed, model missing, or classifier busy — drop frame
        }
        inferenceInFlight = true;
        new Thread(() -> {
            try {
                String result = recognize(bitmap);
                if ("cat".equals(result)) {
                    takePicture();
                }
            } finally {
                inferenceInFlight = false;
            }
        }).start();
    }

    /** Decodes a JPEG Image into a Bitmap (may return null on decode failure). */
    private Bitmap getBitmap(Image image) {
        ByteBuffer buffer = image.getPlanes()[0].getBuffer();
        byte[] bytes = new byte[buffer.remaining()];
        buffer.get(bytes);
        return BitmapFactory.decodeByteArray(bytes, 0, bytes.length);
    }

    /**
     * Runs the classifier on a bitmap and returns the top-scoring label.
     * synchronized: Interpreter.run() must not execute concurrently.
     */
    private synchronized String recognize(Bitmap bitmap) {
        Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, true);
        TensorImage inputImage = new TensorImage(DataType.FLOAT32);
        inputImage.load(resizedBitmap);
        TensorBuffer outputBuffer =
                TensorBuffer.createFixedSize(new int[]{1, labels.size()}, DataType.FLOAT32);
        interpreter.run(inputImage.getBuffer(), outputBuffer.getBuffer());
        float[] results = outputBuffer.getFloatArray();
        return labels.get(getMaxIndex(results));
    }

    /** Returns the index of the largest element (first occurrence on ties). */
    private int getMaxIndex(float[] array) {
        int maxIndex = 0;
        float max = array[maxIndex];
        for (int i = 1; i < array.length; i++) {
            if (array[i] > max) {
                max = array[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /** Issues a single still-capture request; the JPEG arrives via onImageAvailable. */
    private void takePicture() {
        if (cameraDevice == null || cameraCaptureSession == null || imageReader == null) {
            return; // camera already closed
        }
        try {
            CaptureRequest.Builder builder =
                    cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_STILL_CAPTURE);
            builder.addTarget(imageReader.getSurface());
            builder.set(CaptureRequest.CONTROL_MODE, CameraMetadata.CONTROL_MODE_AUTO);
            builder.set(CaptureRequest.JPEG_ORIENTATION, getOrientation());
            cameraCaptureSession.stopRepeating();
            cameraCaptureSession.abortCaptures();
            // Deliver on the background handler, consistent with the preview.
            cameraCaptureSession.capture(builder.build(), null, backgroundHandler);
        } catch (CameraAccessException | IllegalStateException e) {
            // IllegalStateException: session closed between the guard and the call.
            e.printStackTrace();
        }
    }

    /**
     * Computes the JPEG_ORIENTATION value for the current display rotation.
     * Display.getRotation() returns a Surface.ROTATION_* constant (0..3), not
     * degrees — the original added it to sensorOrientation directly, which
     * produced wrong orientations; convert to degrees first.
     */
    private int getOrientation() {
        int rotation = ((Activity) getContext()).getWindowManager().getDefaultDisplay().getRotation();
        int rotationDegrees;
        switch (rotation) {
            case Surface.ROTATION_90:
                rotationDegrees = 90;
                break;
            case Surface.ROTATION_180:
                rotationDegrees = 180;
                break;
            case Surface.ROTATION_270:
                rotationDegrees = 270;
                break;
            default:
                rotationDegrees = 0;
                break;
        }
        return (rotationDegrees + sensorOrientation + 270) % 360;
    }

    /** Starts the camera callback thread. Call from the Activity's onResume. */
    public void startBackgroundThread() {
        backgroundThread = new HandlerThread("Camera Background");
        backgroundThread.start();
        backgroundHandler = new Handler(backgroundThread.getLooper());
    }

    /** Stops the camera callback thread. Call from the Activity's onPause. */
    public void stopBackgroundThread() {
        if (backgroundThread == null) {
            return;
        }
        backgroundThread.quitSafely();
        try {
            backgroundThread.join();
            backgroundThread = null;
            backgroundHandler = null;
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
        }
    }
}
```
然后,在Activity中使用CameraPreview类,并在onResume和onPause方法中开始和停止后台线程。
```
public class MainActivity extends AppCompatActivity {
private CameraPreview cameraPreview;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
cameraPreview = new CameraPreview(this);
FrameLayout previewLayout = findViewById(R.id.preview_layout);
previewLayout.addView(cameraPreview);
}
@Override
protected void onResume() {
super.onResume();
cameraPreview.startBackgroundThread();
}
@Override
protected void onPause() {
cameraPreview.stopBackgroundThread();
super.onPause();
}
}
```
注意,这个Demo还需要一张名为"labels.txt"的标签文件和一个名为"model.tflite"的TensorFlow Lite模型文件,应该将它们放在assets目录下。此外,还需要在AndroidManifest.xml中声明相机权限:`<uses-permission android:name="android.permission.CAMERA" />`,并在运行时向用户申请该权限(API 23及以上)。