DecisionTreeClassifier.predict源码
时间: 2024-04-07 20:12:38 浏览: 100
DecisionTreeClassifier.predict的源码如下:
```python
def predict(self, X, check_input=True):
"""
Predict class or regression value for X.
For a classification model, the predicted class for each sample in X is
returned. For a regression model, the predicted value based on X is
returned.
Parameters
----------
X : array-like or sparse matrix of shape = [n_samples, n_features]
The input samples. Internally, it will be converted to
``dtype=np.float32`` and if a sparse matrix is provided
to a sparse ``csr_matrix``.
check_input : bool, (default=True)
Allow to bypass several input checking.
Don't use this parameter unless you know what you do.
Returns
-------
y : array of shape = [n_samples] or [n_samples, n_outputs]
The predicted classes, or the predict values.
"""
check_is_fitted(self)
X = self._validate_X_predict(X, check_input)
proba = self.tree_.predict(X)
n_samples = X.shape[0]
if self.n_outputs_ == 1:
return self.classes_.take(np.argmax(proba, axis=1),
axis=0)
else:
class_type = self.classes_[0].__class__
predictions = np.zeros((n_samples, self.n_outputs_),
dtype=class_type)
for k in range(self.n_outputs_):
predictions[:, k] = self.classes_[k].take(
np.argmax(proba[:, k], axis=1),
axis=0)
return predictions
```
首先,它会检查分类器是否已经被拟合(通过调用`check_is_fitted`方法)。然后,它调用`_validate_X_predict`方法来验证输入数据,将其转换为`np.float32`类型,并将稀疏矩阵转换为稀疏`csr_matrix`矩阵。随后,它会调用决策树的`predict`方法,返回每个样本的预测概率。最后,它会根据分类器是否是回归模型,返回相应的预测值。如果分类器是回归模型,返回预测值;如果分类器是分类模型,返回预测的类别。
阅读全文