som神经网络python源代码
时间: 2023-09-05 10:02:21 浏览: 200
som神经网络,即自组织映射(Self-Organizing Map),是一种无监督学习的神经网络算法,用于聚类和数据降维。下面是一个基于Python的简单som神经网络的代码示例:
```python
import numpy as np
class SOM:
def __init__(self, input_dim, output_dim, learning_rate=0.2, num_iterations=100):
self.input_dim = input_dim
self.output_dim = output_dim
self.learning_rate = learning_rate
self.num_iterations = num_iterations
self.weights = np.random.rand(output_dim[0], output_dim[1], input_dim)
def train(self, input_data):
for iteration in range(self.num_iterations):
for data_point in input_data:
bmu_idx = self.find_best_matching_unit(data_point)
self.update_weights(data_point, bmu_idx, iteration)
def find_best_matching_unit(self, data_point):
min_dist = np.inf
bmu_idx = (0, 0)
for i in range(self.output_dim[0]):
for j in range(self.output_dim[1]):
dist = np.linalg.norm(data_point - self.weights[i, j])
if dist < min_dist:
min_dist = dist
bmu_idx = (i, j)
return bmu_idx
def update_weights(self, data_point, bmu_idx, iteration):
for i in range(self.output_dim[0]):
for j in range(self.output_dim[1]):
dist_to_bmu = np.linalg.norm(np.array(bmu_idx) - np.array([i, j]))
lr = self.learning_rate * (1 - iteration / self.num_iterations)
influence = np.exp(-dist_to_bmu / (lr * iteration))
self.weights[i, j] += lr * influence * (data_point - self.weights[i, j])
def predict(self, input_data):
output = np.zeros(input_data.shape[0])
for i, data_point in enumerate(input_data):
bmu_idx = self.find_best_matching_unit(data_point)
output[i] = bmu_idx[0] * self.output_dim[1] + bmu_idx[1]
return output
# 示例用法
input_data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
som = SOM(input_dim=2, output_dim=(10, 10))
som.train(input_data)
test_data = np.array([[2.0, 3.0], [4.0, 5.0]])
predicted = som.predict(test_data)
print(predicted)
```
这里的SOM类有几个关键的方法和属性:
- `__init__`:初始化方法,定义了输入维度、输出维度、学习率和迭代次数,并初始化权重矩阵。
- `train`:训练方法,通过迭代更新权重矩阵。
- `find_best_matching_unit`:寻找最佳匹配单元方法,在输出层中寻找与输入数据点最相似的神经元。
- `update_weights`:更新权重方法,根据输入数据点和最佳匹配单元更新权重矩阵。
- `predict`:预测方法,根据输入数据点,返回最佳匹配单元的位置。
示例中使用一个简单的二维数据集进行训练,并通过`predict`方法预测测试数据点的最佳匹配单元位置。
这是一个简单的som神经网络python源代码示例,可以根据需要进行修改和扩展。
阅读全文