tensorflow支持向量机鸢尾花分类
时间: 2024-11-13 14:15:18 浏览: 23
TensorFlow是一个广泛使用的开源机器学习框架,它不仅支持深度学习,也包括传统的机器学习算法。支持向量机(SVM)是一种监督学习模型,在许多分类任务中非常有效,尤其是对于小规模高维数据集,如鸢尾花(Iris)数据集。
在TensorFlow中,你可以使用tf.estimator API或者直接操作底层的张量来进行SVM分类。以下是使用tf.estimator库简化版的例子:
```python
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 加载鸢尾花数据
iris = load_iris()
X, y = iris.data, iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# 定义特征列和SVM Estimator
feature_columns = [tf.feature_column.numeric_column('features', shape=(4,))]
estimator = tf.estimator.LinearClassifier(feature_columns=feature_columns, model_dir='svm_model')
# 训练模型
estimator.train(input_fn=lambda: tf.estimator.inputs.numpy_input_fn(x={"features": X_train}, y=y_train, num_epochs=None, shuffle=True))
# 预测
predictions = estimator.predict(input_fn=lambda: tf.estimator.inputs.numpy_input_fn(x={"features": X_test}, num_epochs=1, shuffle=False))
# 评估性能
metrics = estimator.evaluate(input_fn=lambda: tf.estimator.inputs.numpy_input_fn(x={"features": X_test}, y=y_test, num_epochs=1, shuffle=False))
```
阅读全文