编写一段基于Java的图像识别代码
时间: 2023-12-14 07:39:07 浏览: 106
以下是一个简单的基于Java的图像识别代码示例,用于识别一张图片中是否存在猫的面部特征:
```java
import java.awt.image.BufferedImage;
import java.io.File;
import javax.imageio.ImageIO;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
public class ImageRecognition {
public static void main(String[] args) throws Exception {
// Load the pre-trained model graph
byte[] graphDef = readAllBytesOrExit(Paths.get("model.pb"));
Graph graph = new Graph();
graph.importGraphDef(graphDef);
// Load the input image
BufferedImage image = ImageIO.read(new File("input.jpg"));
// Preprocess the image
float[][][][] input = preprocess(image);
// Create a TensorFlow session
try (Session session = new Session(graph)) {
// Feed the image into the graph
Tensor<Float> inputTensor = Tensor.create(input);
Tensor<Float> outputTensor = session
.runner()
.feed("input", inputTensor)
.fetch("output")
.run()
.get(0)
.expect(Float.class);
// Get the prediction result
float[] result = new float[2];
outputTensor.copyTo(result);
float catScore = result[0];
float nonCatScore = result[1];
// Print the prediction result
if (catScore > nonCatScore) {
System.out.println("The image contains a cat!");
} else {
System.out.println("The image does not contain a cat.");
}
}
}
private static byte[] readAllBytesOrExit(Path path) throws IOException {
byte[] bytes = Files.readAllBytes(path);
return bytes;
}
private static float[][][][] preprocess(BufferedImage image) {
float[][][][] input = new float[1][224][224][3];
for (int y = 0; y < 224; y++) {
for (int x = 0; x < 224; x++) {
int rgb = image.getRGB(x * image.getWidth() / 224, y * image.getHeight() / 224);
input[0][y][x][0] = ((rgb >> 16) & 0xFF) / 255.0f;
input[0][y][x][1] = ((rgb >> 8) & 0xFF) / 255.0f;
input[0][y][x][2] = (rgb & 0xFF) / 255.0f;
}
}
return input;
}
}
```
该代码使用了 TensorFlow Java API,首先读取了一个预训练好的模型图(即一个.pb文件),然后加载输入图片并进行预处理,最后将其输入至模型中进行预测。在本例中,模型会返回一个包含两个元素的浮点数数组,分别表示输入图片中包含猫的概率和不包含猫的概率。最后,根据输出结果打印出相应的预测结果。
阅读全文