给我一个基于鸢尾花数据集的logisitic源码
时间: 2023-04-08 22:01:58 浏览: 103
以下是基于鸢尾花数据集的logistic回归源码:
```
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 定义sigmoid函数
def sigmoid(z):
return 1 / (1 + np.exp(-z))
# 定义损失函数
def loss(h, y):
return (-y * np.log(h) - (1 - y) * np.log(1 - h)).mean()
# 定义训练函数
def train(X, y, learning_rate, num_iters):
# 初始化参数
theta = np.zeros(X.shape[1])
# 迭代更新参数
for i in range(num_iters):
z = np.dot(X, theta)
h = sigmoid(z)
gradient = np.dot(X.T, (h - y)) / y.size
theta -= learning_rate * gradient
# 每迭代100次输出一次损失函数值
if i % 100 == 0:
z = np.dot(X, theta)
h = sigmoid(z)
print('loss: {}'.format(loss(h, y)))
return theta
# 训练模型
theta = train(X_train, y_train, learning_rate=0.1, num_iters=1000)
# 预测测试集
z = np.dot(X_test, theta)
h = sigmoid(z)
y_pred = np.round(h)
# 计算准确率
accuracy = (y_pred == y_test).mean()
print('accuracy: {}'.format(accuracy))
```
希望对你有帮助!
阅读全文