基于tensorflow的猫狗图识别python代码
时间: 2023-08-09 15:12:21 浏览: 157
tensorflow实现猫狗识别
5星 · 资源好评率100%
以下是基于 TensorFlow 的猫狗图识别 Python 代码:
首先,导入所需的模块:
```python
import tensorflow as tf
import numpy as np
import cv2
```
然后,定义图像的尺寸和类别:
```python
img_size = 128
categories = ["Dog", "Cat"]
```
接着,加载模型:
```python
model = tf.keras.models.load_model("model.h5")
```
其中,`model.h5` 是已经训练好的模型文件。
最后,定义预测函数:
```python
def predict(image_path):
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (img_size, img_size))
img = np.array(img).reshape(-1, img_size, img_size, 1)
prediction = model.predict(img)
result = categories[np.argmax(prediction)]
return result
```
其中,`image_path` 是待识别图像的路径,`cv2.imread` 函数用于读取图像,`cv2.resize` 函数用于调整图像大小,`np.array` 函数用于将图像数据转换为 NumPy 数组,`model.predict` 函数用于预测图像的类别,`np.argmax` 函数用于获取最大概率对应的类别,最后返回预测结果。
完整代码如下:
```python
import tensorflow as tf
import numpy as np
import cv2
img_size = 128
categories = ["Dog", "Cat"]
model = tf.keras.models.load_model("model.h5")
def predict(image_path):
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (img_size, img_size))
img = np.array(img).reshape(-1, img_size, img_size, 1)
prediction = model.predict(img)
result = categories[np.argmax(prediction)]
return result
```
注意:在使用该代码之前,需要将待识别的图像保存在文件系统中,并指定其路径。
阅读全文