cnn卷积神经网络图片分类代码
时间: 2024-12-25 17:14:23 浏览: 2
CNN(Convolutional Neural Network,卷积神经网络)在图像分类任务中的应用非常广泛。这里简述一下基本的Python代码结构,通常使用深度学习库如TensorFlow或PyTorch来实现。
首先,需要安装必要的库:
```bash
pip install tensorflow numpy pillow
```
然后,我们可以用Keras(基于TensorFlow或PyTorch的高级API)来编写一个简单的卷积神经网络模型:
```python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 定义模型
model = Sequential([
Conv2D(32, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 输入层,假设img_height和img_width是图片尺寸
MaxPooling2D((2, 2)),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D((2, 2)),
Conv2D(128, (3, 3), activation='relu'),
MaxPooling2D((2, 2)),
Flatten(), # 将二维特征展平成一维
Dense(128, activation='relu'), # 全连接层
Dense(num_classes, activation='softmax') # 输出层,num_classes是类别数,激活函数通常是softmax用于多分类
])
# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型(假设有train_data和train_labels)
model.fit(train_data, train_labels, epochs=epochs, validation_split=0.2)
# 测试模型
test_loss, test_acc = model.evaluate(test_data, test_labels)
```
阅读全文