Code for Optimizing a BP Neural Network with a Genetic Algorithm
A genetic algorithm can be used to optimize the parameters of a BP neural network, such as its weights and biases. Below is a simple code example of a genetic algorithm optimizing a BP neural network:
```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the dataset and preprocess it
iris = load_iris()
X = iris.data
scaler = StandardScaler()
X = scaler.fit_transform(X)
# One-hot encode the labels so they match the softmax output and the cross-entropy loss
y = np.eye(3)[iris.target]

# Hold out a test set before optimization so it is not seen during evolution
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the neural network class
class NeuralNetwork:
    def __init__(self, input_size, hidden_size, output_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.W1 = np.random.randn(input_size, hidden_size)
        self.b1 = np.random.randn(hidden_size)
        self.W2 = np.random.randn(hidden_size, output_size)
        self.b2 = np.random.randn(output_size)

    def forward(self, X):
        # Hidden layer with tanh activation, output layer with (numerically stable) softmax
        self.z1 = np.dot(X, self.W1) + self.b1
        self.a1 = np.tanh(self.z1)
        self.z2 = np.dot(self.a1, self.W2) + self.b2
        exp_z = np.exp(self.z2 - np.max(self.z2, axis=1, keepdims=True))
        self.y_hat = exp_z / np.sum(exp_z, axis=1, keepdims=True)

    def compute_loss(self, y):
        # Cross-entropy loss against one-hot labels
        return -np.sum(y * np.log(self.y_hat + 1e-12))

    def predict(self, X):
        self.forward(X)
        return np.argmax(self.y_hat, axis=1)

# Define the genetic algorithm class
class GeneticAlgorithm:
    def __init__(self, population_size, mutation_rate):
        self.population_size = population_size
        self.mutation_rate = mutation_rate

    def init_population(self, input_size, hidden_size, output_size):
        # Each individual is a list of flattened weights and biases [W1, b1, W2, b2]
        population = []
        for _ in range(self.population_size):
            nn = NeuralNetwork(input_size, hidden_size, output_size)
            population.append([nn.W1.flatten(), nn.b1, nn.W2.flatten(), nn.b2])
        return population

    def fitness(self, nn, X, y):
        # Smaller loss -> higher fitness
        nn.forward(X)
        return 1 / (1 + nn.compute_loss(y))

    def select_parents(self, population, fitness_scores):
        # Tournament selection: twice, pick the fittest of 5 randomly chosen individuals
        parents = []
        for _ in range(2):
            idx = np.random.choice(len(population), size=5, replace=False)
            parent_idx = idx[np.argmax(fitness_scores[idx])]
            parents.append(population[parent_idx])
        return parents

    def crossover(self, parent1, parent2):
        # Uniform crossover over parameter blocks; copy so parents are not modified by mutation
        child = []
        for i in range(len(parent1)):
            if np.random.rand() > 0.5:
                child.append(parent1[i].copy())
            else:
                child.append(parent2[i].copy())
        return child

    def mutate(self, child):
        # With probability mutation_rate, add small Gaussian noise to a parameter block
        for i in range(len(child)):
            if np.random.rand() < self.mutation_rate:
                child[i] = child[i] + np.random.randn(*child[i].shape) * 0.1
        return child

    def evolve(self, population, fitness_scores, X, y):
        new_population = []
        new_fitness_scores = np.zeros(self.population_size)
        for i in range(self.population_size):
            parent1, parent2 = self.select_parents(population, fitness_scores)
            child = self.mutate(self.crossover(parent1, parent2))
            nn = NeuralNetwork(X.shape[1], 5, y.shape[1])
            nn.W1 = child[0].reshape(nn.W1.shape)
            nn.b1 = child[1]
            nn.W2 = child[2].reshape(nn.W2.shape)
            nn.b2 = child[3]
            new_fitness_scores[i] = self.fitness(nn, X, y)
            new_population.append(child)
        return new_population, new_fitness_scores

# Initialize the genetic algorithm
ga = GeneticAlgorithm(20, 0.2)
population = ga.init_population(X_train.shape[1], 5, y_train.shape[1])

# Evolve the population on the training set
for generation in range(50):
    fitness_scores = np.zeros(len(population))
    for j in range(len(population)):
        nn = NeuralNetwork(X_train.shape[1], 5, y_train.shape[1])
        nn.W1 = population[j][0].reshape(nn.W1.shape)
        nn.b1 = population[j][1]
        nn.W2 = population[j][2].reshape(nn.W2.shape)
        nn.b2 = population[j][3]
        fitness_scores[j] = ga.fitness(nn, X_train, y_train)
    population, fitness_scores = ga.evolve(population, fitness_scores, X_train, y_train)

# Evaluate the fittest individual on the held-out test set
best = population[int(np.argmax(fitness_scores))]
nn = NeuralNetwork(X_train.shape[1], 5, y_train.shape[1])
nn.W1 = best[0].reshape(nn.W1.shape)
nn.b1 = best[1]
nn.W2 = best[2].reshape(nn.W2.shape)
nn.b2 = best[3]
y_pred = nn.predict(X_test)
accuracy = np.mean(y_pred == np.argmax(y_test, axis=1))
print(f'Accuracy: {accuracy:.2f}')
```
In this example, we first define a simple BP neural network class `NeuralNetwork`, consisting of an input layer, one hidden layer, and an output layer. We then define a genetic algorithm class `GeneticAlgorithm` with a method to initialize the population (`init_population`), one to compute fitness (`fitness`), one to select parents (`select_parents`), one for crossover (`crossover`), one for mutation (`mutate`), and one to evolve the population (`evolve`). Finally, we use the genetic algorithm to optimize the network's weights and biases on the training data and evaluate the fittest individual on the test set.
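Because rebuilding a `NeuralNetwork` from a flattened individual is repeated three times in the listing, one possible tidy-up is a small decoding helper. The sketch below assumes the `NeuralNetwork` class defined above; the name `decode_individual` is not part of the original example:

```python
def decode_individual(params, input_size, hidden_size, output_size):
    # Rebuild a NeuralNetwork from a flat [W1, b1, W2, b2] parameter list
    nn = NeuralNetwork(input_size, hidden_size, output_size)
    nn.W1 = params[0].reshape(input_size, hidden_size)
    nn.b1 = params[1]
    nn.W2 = params[2].reshape(hidden_size, output_size)
    nn.b2 = params[3]
    return nn
```

With such a helper, the fitness evaluation in the training loop and in `evolve` reduces to a single call, e.g. `nn = decode_individual(population[j], X_train.shape[1], 5, y_train.shape[1])`.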
Note that this code is only a simple demonstration; in practice it should be adapted and tuned to the problem at hand, for example by changing the network size, population size, selection scheme, or mutation strength.
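One common adjustment is elitism: copying the fittest individuals of each generation into the next generation unchanged, so the best solution found so far is never lost. The sketch below is only a suggestion built on top of the classes and variables defined above (`GeneticAlgorithm`, `population`, `fitness_scores`); the function name `evolve_with_elitism` is not part of the original example:

```python
import numpy as np

def evolve_with_elitism(ga, population, fitness_scores, X, y, n_elite=1):
    # Run one ordinary generation first
    new_population, new_scores = ga.evolve(population, fitness_scores, X, y)
    # Then overwrite the first n_elite children with the fittest parents
    elite_idx = np.argsort(fitness_scores)[-n_elite:]
    for k, i in enumerate(elite_idx):
        new_population[k] = [p.copy() for p in population[i]]
        new_scores[k] = fitness_scores[i]
    return new_population, new_scores
```

In the training loop, the call to `ga.evolve(...)` would then be replaced by `evolve_with_elitism(ga, population, fitness_scores, X_train, y_train)`.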