没有合适的资源?快使用搜索试试~ 我知道了~
首页自定义逻辑斯蒂判别式算法:鸢尾花数据集多分类实战
自定义逻辑斯蒂判别式算法:鸢尾花数据集多分类实战
27 下载量 52 浏览量
更新于2023-05-04
4
收藏 48KB PDF 举报
本文主要探讨了如何基于鸢尾花数据集实现线性判别式多分类方法。鸢尾花数据集是一个经典的多类分类问题,它包含三种不同种类的鸢尾花,每种鸢尾花都有四个特征(如花瓣长度、花瓣宽度等)。为了评估模型性能,文章将数据集划分为训练集(70%)和测试集(30%),这遵循了常见的机器学习实践,以确保模型泛化能力。 首先,文章定义了一个名为`randomdata`的函数,用于从整个数据集中随机选择70%的数据作为训练数据。这个函数通过生成0到49的随机整数,并检查是否已存在于数组中,重复此过程直到达到所需的数量,确保了数据的随机性和代表性。 接下来,文章的核心部分是训练函数`lda`,该函数接收训练数据`datas`和对应的标签`labels`作为输入。它首先计算每个属性的均值和标准差,然后对数据进行标准化处理。参数`w`是一个大小为`(K, M)`的矩阵,其中`K`是类别数(本例中为3),`M`是特征数加1(因为有一列是常数项)。使用梯度下降法和逻辑斯蒂函数,通过迭代优化`w`来最小化损失函数,使得模型能够准确地将数据分配到各个类别。 在训练过程中,函数会更新`w`并打印进度,以监控模型参数的学习过程。最后,当训练达到预设的迭代次数时,返回优化后的`w`参数。 预测函数虽然没有直接给出,但可以推断其作用是使用训练好的`w`参数,根据新的鸢尾花数据计算其类别概率,然后根据某种决策规则(如最大似然或阈值)进行分类预测。 本文提供了一个实践性的例子,展示了如何利用鸢尾花数据集和自定义的逻辑斯蒂判别式算法进行多分类,强调了数据预处理、参数优化和模型预测的重要步骤。这对于理解线性判别分析在实际问题中的应用具有很高的价值。
资源详情
资源推荐
基于鸢尾花数据集实现线性判别式多分类基于鸢尾花数据集实现线性判别式多分类
基于鸢尾花数据集实现线性判别式多分类基于鸢尾花数据集实现线性判别式多分类
本文在自己编写梯度下降的逻辑斯蒂判别式算法的基础上,对鸢尾花数据集实现多分类。
鸢尾花数据集公包含三类数据,每条数据四个特征,从中随机选取70%的数据作为训练集,30%的数据作为测试集。
主要包含三个函数:随机生成70%测试集函数、训练函数、预测函数
随机生成随机生成70%测试集函数测试集函数 randomdata
输入:无
输出:0-49之间的35个随机数
代码:
def randomdata():
array = set()
while(len(array) < 50*0.7):
n = random.randint(0,49)
if(n in array):
continue
else:
array.add(n)
return array
训练函数训练函数 lda
输入:数据datas,标签labels,
输出:参数w
代码:
def lda(datas,labels):
means=datas.mean(axis=0) #各个属性的均值
stds=datas.std(axis=0) #各个属性的标准差
N,M= datas.shape[0],datas.shape[1]+1 #N是样本数,M是参数向量的维 150 2
K=3 #k=3是类别数
data=np.ones((N,M))
data[:,1:]=(datas-means)/stds #对原始数据进行标准差归一化 第一列是常数
w=np.zeros((K,M)) #存储参数矩阵 np.dot(x,y) (特征数,种类数)
w = np.random.rand(K,M)*2-1
w = w/100
learningrate = 0.02
labels = np.array(labels)
for k in range(1000):
sita = np.dot(w,data.T)
y = np.exp(sita)
for j in range(N):
sum = 0
for i in range(len(y)):
sum += y[i,j] for i in range(len(y)):
if(sum == 0):
y[i,j] = 0
else:
y[i,j] = y[i,j]/sum
y = labels.T - y
deltaw = np.dot(y,data)
w = w + learningrate*deltaw
if(k % 100 == 0):
print(w)
return w
预测函数预测函数
输入:测试数据集datas,参数w,表现labels,
输出:预测准确率per
代码:
def predict(datas,w,labels):
means=datas.mean(axis=0) #各个属性的均值
stds=datas.std(axis=0) #各个属性的标准差
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功