som神经网络源代码 python
时间: 2023-11-07 14:03:27 浏览: 99
som神经网络是一种自组织映射网络,用于数据聚类和可视化。下面是一个使用Python编写的SOM神经网络的示例源代码:
```python
import numpy as np
import matplotlib.pyplot as plt
def get_initial_weights(input_size, output_size):
return np.random.rand(input_size, output_size)
def calculate_distance(input_vector, weight_vector):
return np.linalg.norm(input_vector - weight_vector)
def update_weights(input_vector, learning_rate, radius, weights, winner):
for i in range(weights.shape[1]):
distance = np.linalg.norm(np.array([i]) - winner)
if distance <= radius:
weights[:, i] += learning_rate * (input_vector - weights[:, i])
def train_som(input_data, output_size, learning_rate, num_epochs):
input_size = input_data.shape[1]
weights = get_initial_weights(input_size, output_size)
for epoch in range(num_epochs):
for input_vector in input_data:
distances = [calculate_distance(input_vector, weight_vector) for weight_vector in weights.T]
winner = np.argmin(distances)
update_weights(input_vector, learning_rate, radius, weights, winner)
return weights
# 示例使用
input_data = np.array([[0.2, 0.5], [0.8, 0.7], [0.4, 0.3], [0.9, 0.1]])
output_size = 5
learning_rate = 0.5
num_epochs = 10
weights = train_som(input_data, output_size, learning_rate, num_epochs)
plt.scatter(input_data[:, 0], input_data[:, 1])
for weight_vector in weights.T:
plt.scatter(weight_vector[0], weight_vector[1], c='r')
plt.show()
```
以上代码包括了SOM神经网络的训练过程和可视化结果的绘制。在示例中,输入数据是一个2维的数组,输出大小为5,学习率为0.5,进行10个周期的训练。训练后,可以看到输入数据点被归类到最近的神经元,并且神经元的位置相对合理。
阅读全文