根据提示,在右侧编辑器补充代码,完成OvR类中的fit(self, train_datas, train_labels)与predict(self, test_datas)函数。PS:请使用代码中所提供的tiny_logistic_regression作为二分类分类器。
时间: 2023-06-02 13:03:58 浏览: 296
好的,了解。以下是代码:
```
from tiny_logistic_regression import BinaryLogisticRegression
class OvR:
def __init__(self):
self.classifiers = {}
def fit(self, train_datas, train_labels):
"""
Parameters:
train_datas: list of list, shape = (n_samples, n_features)
Training data.
train_labels: list of int, shape = (n_samples,)
Target classes.
"""
classes = list(set(train_labels))
for c in classes:
binary_labels = [1 if label==c else 0 for label in train_labels]
self.classifiers[c] = BinaryLogisticRegression()
self.classifiers[c].fit(train_datas, binary_labels)
def predict(self, test_datas):
"""
Parameters:
test_datas: list of list, shape = (n_samples, n_features)
Testing data.
Returns:
list of int, shape = (n_samples,)
Predicted classes.
"""
predictions = []
for data in test_datas:
max_prob = 0
max_class = None
for c, clf in self.classifiers.items():
prob = clf.predict_proba([data])[0][1]
if prob > max_prob:
max_prob = prob
max_class = c
predictions.append(max_class)
return predictions
```
注意:本模型只能用于二分类问题的多类别分类。如有更复杂需求,请使用更适合的模型。
阅读全文