给我写一个 DBSCAN 聚类代码
时间: 2024-02-12 11:07:00 浏览: 60
好的，以下是用 Python 实现 DBSCAN 聚类算法的代码：
```python
from typing import List
import numpy as np
class DBSCAN:
    """Minimal brute-force DBSCAN clustering over Euclidean distance.

    Label convention of :meth:`fit`: ``-1`` marks noise and clusters are
    numbered ``1, 2, ...`` in the order they are discovered. ``0`` is used
    internally for "not visited yet" and never appears in the result.
    """

    def __init__(self, eps: float, min_samples: int):
        """Store the clustering parameters.

        Args:
            eps: Neighborhood radius. Two points are neighbors when their
                Euclidean distance is at most ``eps``.
            min_samples: Minimum neighborhood size (the point itself
                included) for a point to be treated as a core point.
        """
        self.eps = eps
        self.min_samples = min_samples

    def _get_neighbors(self, X: np.ndarray, p: int) -> List[int]:
        """Return the indices of every row of ``X`` within ``eps`` of row ``p``.

        The indices are in ascending order and include ``p`` itself.
        """
        # Flatten each sample so 1-D inputs and samples with more than one
        # trailing axis are handled the same way as ordinary 2-D data.
        deltas = (X - X[p]).reshape(X.shape[0], -1)
        distances = np.linalg.norm(deltas, axis=1)
        # The eps-neighborhood is closed: a point at exactly eps counts.
        return np.flatnonzero(distances <= self.eps).tolist()

    def _expand_cluster(self, X: np.ndarray, labels: List[int], p: int, neighbors: List[int], cluster_id: int) -> bool:
        """Grow cluster ``cluster_id`` outward from core point ``p``.

        Args:
            X: Sample array, indexed by row.
            labels: Per-sample labels, updated in place.
            p: Index of the core point that seeds the cluster.
            neighbors: Indices in the eps-neighborhood of ``p``. This list is
                only read; it is not modified.
            cluster_id: Label written to every point that joins the cluster.

        Returns:
            Always ``True``.
        """
        labels[p] = cluster_id
        # Private copy: the frontier grows during expansion and the caller's
        # list must not change underneath it.
        frontier = list(neighbors)
        cursor = 0
        while cursor < len(frontier):
            n = frontier[cursor]
            cursor += 1
            if labels[n] == -1:
                # Previously judged noise, so not a core point: it joins as a
                # border point and is not expanded further.
                labels[n] = cluster_id
            elif labels[n] == 0:
                labels[n] = cluster_id
                n_neighbors = self._get_neighbors(X, n)
                if len(n_neighbors) >= self.min_samples:
                    # Only unvisited (0) or noise (-1) points can still change,
                    # so points that already carry a cluster label are skipped.
                    frontier.extend(q for q in n_neighbors if labels[q] <= 0)
        return True

    def fit(self, X: np.ndarray) -> List[int]:
        """Cluster the rows of ``X`` and return one label per row.

        Args:
            X: Samples, one per row. Anything ``np.asarray`` accepts works,
                including nested lists.

        Returns:
            A list of plain Python ints with ``len(X)`` entries: ``-1`` for
            noise, otherwise a cluster id starting at ``1``. An empty input
            yields an empty list.
        """
        X = np.asarray(X)
        n_points = X.shape[0]
        labels = np.zeros(n_points, dtype=int)
        cluster_id = 0
        for i in range(n_points):
            if labels[i] != 0:
                # Already assigned to a cluster or marked as noise.
                continue
            neighbors = self._get_neighbors(X, i)
            if len(neighbors) < self.min_samples:
                # Provisional: a later cluster may still claim this point as
                # a border point.
                labels[i] = -1
            else:
                cluster_id += 1
                self._expand_cluster(X, labels, i, neighbors, cluster_id)
        # tolist() yields built-in ints, matching the List[int] annotation.
        return labels.tolist()
```
使用方式:
```python
import numpy as np
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
# Build a seeded 200-sample toy dataset; the reference labels returned by
# make_moons are not needed for clustering and are discarded.
points, _ = make_moons(n_samples=200, noise=0.1, random_state=42)

# Cluster the samples with the DBSCAN class defined above.
model = DBSCAN(eps=0.3, min_samples=5)
cluster_labels = model.fit(points)

# Scatter-plot the samples, coloured by the label each one received.
xs, ys = points[:, 0], points[:, 1]
plt.scatter(xs, ys, c=cluster_labels)
plt.show()
```
这是一个简单的示例，可以根据实际需求进行修改。返回的标签中，-1 表示噪声点，簇的编号从 1 开始。
阅读全文