给出用 Python 实现高斯朴素贝叶斯分类的代码：处理鸢尾花数据集，不调用 sklearn 库
时间: 2023-06-02 08:03:46 浏览: 165
感谢您的提问。以下是用 Python 实现的高斯朴素贝叶斯分类代码，用于处理鸢尾花数据集（假设 iris.csv 位于当前工作目录，每行包含四个数值特征，最后一列为类别标签）：
```python
import csv
import math
import random
def load_csv(filename):
    """Load a CSV dataset whose last column is the class label.

    Every column except the last is parsed as a float feature. The label is
    kept as a float when it is numeric; otherwise (e.g. ``Iris-setosa`` in the
    UCI iris file) it is replaced by a float class code assigned in order of
    first appearance, so every returned row stays purely numeric and the rest
    of the pipeline (which averages columns) keeps working.

    Blank lines are skipped, and a first non-blank row whose feature columns
    are not numeric is treated as a header and skipped.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A list of rows, each a list of floats: features followed by the label.

    Raises:
        ValueError: If a row other than the first contains a non-numeric
            feature.
    """
    dataset = []
    class_codes = {}  # string label -> float code, in order of first appearance
    is_first_row = True
    # newline='' is what the csv module documentation requires for readers.
    with open(filename, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            if not any(cell.strip() for cell in row):
                continue  # blank line, e.g. the trailing newline in iris.data
            *raw_features, raw_label = row
            try:
                features = [float(x) for x in raw_features]
            except ValueError:
                if is_first_row:
                    # Header such as "sepal_length,...,species".
                    is_first_row = False
                    continue
                raise ValueError(
                    f'{filename}, line {reader.line_num}: non-numeric feature in {row!r}'
                ) from None
            is_first_row = False
            label = raw_label.strip()
            try:
                code = float(label)
            except ValueError:
                code = class_codes.setdefault(label, float(len(class_codes)))
            dataset.append(features + [code])
    return dataset
def split_dataset(dataset, split_ratio, seed=None):
    """Randomly split *dataset* into a training set and a test set.

    Args:
        dataset: Sequence of rows; it is not modified.
        split_ratio: Fraction (0.0-1.0) of rows that go to the training set;
            the training size is ``int(len(dataset) * split_ratio)``.
        seed: Optional seed for a private ``random.Random`` so the split is
            reproducible. When None (the default) the module-level ``random``
            generator is used, so a prior ``random.seed(...)`` still applies.

    Returns:
        ``[train_set, test_set]``: two new lists that together contain every
        row of *dataset* exactly once.

    Raises:
        ValueError: If *split_ratio* is outside [0, 1].
    """
    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f'split_ratio must be in [0, 1], got {split_ratio!r}')
    rng = random if seed is None else random.Random(seed)
    shuffled = list(dataset)
    # A single O(n) shuffle replaces the original O(n^2) loop of
    # list.pop(random_index) calls.
    rng.shuffle(shuffled)
    train_size = int(len(shuffled) * split_ratio)
    return [shuffled[:train_size], shuffled[train_size:]]
def separate_by_class(dataset):
    """Group the rows of *dataset* by class label (the last element of a row).

    Returns:
        dict mapping each label to the list of its rows (the original row
        objects, in input order); labels appear in order of first occurrence.
    """
    groups = {}
    for row in dataset:
        groups.setdefault(row[-1], []).append(row)
    return groups
def mean(numbers):
    """Return the arithmetic mean of the sized sequence *numbers*.

    Raises ZeroDivisionError when *numbers* is empty.
    """
    total = sum(numbers)
    count = float(len(numbers))
    return total / count
def stdev(numbers):
    """Return the sample standard deviation of *numbers* (n - 1 denominator).

    Raises ZeroDivisionError when *numbers* has fewer than two elements.
    """
    n = len(numbers)
    center = sum(numbers) / float(n)
    spread = sum(pow(value - center, 2) for value in numbers)
    return math.sqrt(spread / float(n - 1))
def summarize(dataset):
    """Return per-feature ``(mean, stdev)`` pairs for *dataset*.

    The last column of every row is the class label. It is dropped before any
    statistics are computed, so non-numeric labels are accepted and no work
    is wasted summarizing a column that is thrown away.

    Args:
        dataset: Rows of equal length; all columns but the last are numeric.

    Returns:
        A list with one ``(mean, stdev)`` tuple per feature column, in column
        order (an empty list when *dataset* is empty).
    """
    feature_columns = list(zip(*dataset))[:-1]  # drop the label column
    return [(mean(column), stdev(column)) for column in feature_columns]
def summarize_by_class(dataset):
    """Compute per-feature summaries separately for each class in *dataset*.

    Returns:
        dict mapping each class label to ``summarize(rows_of_that_class)``,
        i.e. a list of ``(mean, stdev)`` tuples, one per feature.
    """
    grouped = separate_by_class(dataset)
    return {label: summarize(rows) for label, rows in grouped.items()}
def calculate_probability(x, mean, stdev, min_stdev=1e-9):
    """Return the Gaussian probability density of *x* under N(mean, stdev**2).

    Args:
        x: Observed feature value.
        mean: Mean of the feature within one class.
        stdev: Standard deviation of the feature within one class.
        min_stdev: Lower bound applied to *stdev*. A feature that is constant
            within a class in the training data has stdev 0, which previously
            caused a ZeroDivisionError; clamping turns that case into a very
            sharp peak at *mean* instead. Values of *stdev* at or above the
            bound are used unchanged.

    Returns:
        The density as a float >= 0 (it may underflow to 0.0 far from *mean*).
    """
    stdev = max(stdev, min_stdev)
    exponent = math.exp(-(math.pow(x - mean, 2) / (2 * math.pow(stdev, 2))))
    return (1 / (math.sqrt(2 * math.pi) * stdev)) * exponent
def calculate_class_probabilities(summaries, input_vector):
    """Score *input_vector* against every class in *summaries*.

    For each class, the Gaussian densities of the individual features are
    multiplied together (naive independence assumption). No class prior is
    multiplied in, so the scores are likelihoods; ranking them matches ranking
    posteriors only when all classes have equal priors.

    Args:
        summaries: dict mapping class label -> list of ``(mean, stdev)`` per
            feature, as produced by ``summarize_by_class``.
        input_vector: Row indexed by feature position; entries beyond the
            number of summarized features (e.g. a trailing label) are ignored.

    Returns:
        dict mapping each class label to its likelihood score.
    """
    scores = {}
    for label, feature_stats in summaries.items():
        score = 1
        for position, (mu, sigma) in enumerate(feature_stats):
            score *= calculate_probability(input_vector[position], mu, sigma)
        scores[label] = score
    return scores
def predict(summaries, input_vector):
    """Return the class label whose score for *input_vector* is highest.

    On ties, the label that comes first in the iteration order of
    *summaries* wins. Returns None when *summaries* is empty.
    """
    scores = calculate_class_probabilities(summaries, input_vector)
    # max() keeps the first maximal key, matching a strict ">" scan.
    return max(scores, key=scores.get, default=None)
def get_predictions(summaries, test_set):
    """Return the predicted class label for every row of *test_set*, in order."""
    return [predict(summaries, row) for row in test_set]
def get_accuracy(test_set, predictions):
    """Return the percentage (0-100) of rows whose true label was predicted.

    Args:
        test_set: Rows whose last element is the true class label.
        predictions: Predicted labels, aligned index-by-index with *test_set*.

    Returns:
        Accuracy as a float percentage.

    Raises:
        ValueError: If the two sequences differ in length, or if *test_set*
            is empty (e.g. after a split with ``split_ratio=1.0``), instead of
            an IndexError or ZeroDivisionError.
    """
    if len(test_set) != len(predictions):
        raise ValueError(
            f'test_set has {len(test_set)} rows but predictions has {len(predictions)} labels'
        )
    if not test_set:
        raise ValueError('test_set is empty; cannot compute accuracy')
    correct = sum(1 for row, predicted in zip(test_set, predictions) if row[-1] == predicted)
    return (correct / float(len(test_set))) * 100.0
def main(filename='iris.csv', split_ratio=0.67):
    """Train and evaluate the Gaussian naive Bayes classifier on a CSV file.

    Loads *filename*, splits it randomly into training and test sets, fits
    per-class feature summaries on the training set, and prints the split
    sizes and the test-set accuracy.

    Args:
        filename: CSV dataset to load (last column = class label). The default
            is resolved relative to the current working directory.
        split_ratio: Fraction of rows used for training; the rest are tested.
    """
    dataset = load_csv(filename)
    training_set, test_set = split_dataset(dataset, split_ratio)
    print(f'Split {len(dataset)} rows into train={len(training_set)} and test={len(test_set)} rows')
    summaries = summarize_by_class(training_set)
    predictions = get_predictions(summaries, test_set)
    accuracy = get_accuracy(test_set, predictions)
    print(f'Accuracy: {accuracy}%')


if __name__ == '__main__':
    main()
```
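这是一个非常简单的高斯朴素贝叶斯分类器实现，用于对鸢尾花数据集进行分类。

它的流程如下：

- 将数据集随机拆分为训练集和测试集。
- 按类别计算训练集中每个特征的摘要统计量（均值和样本标准差）。
- 对每个测试向量，假设各特征在给定类别下相互独立且服从高斯分布，将各特征的概率密度相乘，得到该向量在每个类别下的似然。
- 选择似然最高的类别作为预测结果。
- 计算分类器在测试集上的准确率并输出结果。

注意，这里没有乘以类别先验概率，相当于假设各类别先验相等。鸢尾花数据集的三个类别各有 50 个样本，因此这一假设大致成立。

该实现没有使用任何第三方库，只用到了 Python 标准库中的 csv、math 和 random 模块。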
阅读全文