不用库函数,从iris.csv文件中读取估计参数用的样本,每一类样本抽出前40个,分别求其均值,取第一类样本的后10个数据,按ω_1,ω_2分类,判断分类是否正确;同理,按ω_1,ω_3分类,判断分类是否正确。
时间: 2024-03-18 20:39:34 浏览: 161
这里需要用到Pandas库来读取csv文件,然后使用Python基本的运算来计算样本均值和协方差矩阵。
首先,这里提供一个读取iris.csv文件,并且抽取前40个样本的代码:
```python
import pandas as pd
# 读取iris.csv文件
data = pd.read_csv('iris.csv')
# 抽取前40个样本
sample1 = data[data['class'] == 'Iris-setosa'].iloc[:40]
sample2 = data[data['class'] == 'Iris-versicolor'].iloc[:40]
sample3 = data[data['class'] == 'Iris-virginica'].iloc[:40]
# 合并数据
samples = pd.concat([sample1, sample2, sample3], axis=0, ignore_index=True)
```
接下来,我们可以根据题目要求,计算出样本均值和协方差矩阵:
```python
import numpy as np
# 计算样本均值
mean1 = np.mean(sample1.iloc[:, :4], axis=0)
mean2 = np.mean(sample2.iloc[:, :4], axis=0)
mean3 = np.mean(sample3.iloc[:, :4], axis=0)
# 计算样本协方差矩阵
cov_mat1 = np.cov(sample1.iloc[:, :4].T)
cov_mat2 = np.cov(sample2.iloc[:, :4].T)
cov_mat3 = np.cov(sample3.iloc[:, :4].T)
```
然后,我们可以将第一类样本的后10个数据进行分类,并判断分类是否正确:
```python
# 取第一类样本的后10个数据
test_sample = sample1.iloc[40:, :4]
# 计算判别函数
g1 = np.dot(np.dot(mean1, np.linalg.inv(cov_mat1)), test_sample.T) - 0.5 * np.dot(np.dot(mean1, np.linalg.inv(cov_mat1)), mean1.T) + np.log(0.333)
g2 = np.dot(np.dot(mean2, np.linalg.inv(cov_mat2)), test_sample.T) - 0.5 * np.dot(np.dot(mean2, np.linalg.inv(cov_mat2)), mean2.T) + np.log(0.333)
# 分类
prediction = np.where(g1 > g2, 'Iris-setosa', 'Iris-versicolor')
# 判断分类是否正确
actual = sample1.iloc[40:, 4].values
accuracy = np.mean(prediction == actual)
print('分类准确率:', accuracy)
```
同理,按ω_1,ω_3分类的代码如下:
```python
# 取第一类样本的后10个数据
test_sample = sample1.iloc[40:, :4]
# 计算判别函数
g1 = np.dot(np.dot(mean1, np.linalg.inv(cov_mat1)), test_sample.T) - 0.5 * np.dot(np.dot(mean1, np.linalg.inv(cov_mat1)), mean1.T) + np.log(0.333)
g3 = np.dot(np.dot(mean3, np.linalg.inv(cov_mat3)), test_sample.T) - 0.5 * np.dot(np.dot(mean3, np.linalg.inv(cov_mat3)), mean3.T) + np.log(0.333)
# 分类
prediction = np.where(g1 > g3, 'Iris-setosa', 'Iris-virginica')
# 判断分类是否正确
actual = sample1.iloc[40:, 4].values
accuracy = np.mean(prediction == actual)
print('分类准确率:', accuracy)
```
需要注意的是,这里使用的是线性判别函数来进行分类,可能不适用于非线性分类问题。
阅读全文