利用 PyTorch 实现遗传算法，用于训练能玩贪吃蛇游戏的智能体，请给出具体代码。
时间: 2024-02-19 19:57:47 浏览: 115
好的，以下是一个利用 PyTorch 和遗传算法训练贪吃蛇智能体的简单示例代码：
```python
import copy
import random

import numpy as np
import torch
# Game parameters: defaults read by SnakeGame for its board and rewards.
# Board size, in cells.
WIDTH, HEIGHT = 20, 20
# Reward for eating food, for dying, and for an ordinary move.
FOOD_REWARD, DEATH_PENALTY, MOVE_PENALTY = 10, -10, -1
class SnakeGame:
    """Snake on a ``width`` x ``height`` grid with a fixed-size numeric observation.

    Coordinates are ``(x, y)`` with ``y`` growing downwards. Actions and
    directions share one encoding: 0 = up (y - 1), 1 = down (y + 1),
    2 = left (x - 1), 3 = right (x + 1); any other action value moves right.

    Constructor arguments are optional; when omitted they fall back to the
    module-level WIDTH, HEIGHT, FOOD_REWARD, DEATH_PENALTY and MOVE_PENALTY.
    """

    # (dx, dy) step for each action/direction code.
    _MOVES = {0: (0, -1), 1: (0, 1), 2: (-1, 0), 3: (1, 0)}

    def __init__(self, width=None, height=None, food_reward=None,
                 death_penalty=None, move_penalty=None, max_idle_steps=None):
        self.width = WIDTH if width is None else width
        self.height = HEIGHT if height is None else height
        self.food_reward = FOOD_REWARD if food_reward is None else food_reward
        self.death_penalty = DEATH_PENALTY if death_penalty is None else death_penalty
        self.move_penalty = MOVE_PENALTY if move_penalty is None else move_penalty
        # An episode also ends after this many consecutive moves without food,
        # so a policy that circles forever cannot hang the caller's loop.
        self.max_idle_steps = (self.width * self.height if max_idle_steps is None
                               else max_idle_steps)
        self.reset()

    def reset(self):
        """Start a new episode: a one-cell snake at (0, 0) heading up."""
        self.snake = [(0, 0)]
        self.food = self.generate_food()
        self.direction = 0
        self.score = 0
        self.steps = 0
        self.steps_since_food = 0

    def generate_food(self):
        """Return a random board cell not occupied by the snake.

        Uses rejection sampling, so it assumes at least one free cell exists.
        """
        while True:
            food = (random.randint(0, self.width - 1), random.randint(0, self.height - 1))
            if food not in self.snake:
                return food

    def _is_blocked(self, cell):
        """Return True if moving the head into *cell* would end the game."""
        x, y = cell
        off_board = not (0 <= x < self.width and 0 <= y < self.height)
        return off_board or cell in self.snake

    def get_state(self):
        """Return the 14-element integer observation for the current board.

        Layout (all relative to the snake's head):
          [0:4]   1 if the neighbouring cell up / down / left / right is a wall
                  or part of the snake (moving there ends the game), else 0
          [4:8]   one-hot current direction: up / down / left / right
          [8:12]  1 if the food lies above / below / left / right of the head
          [12:14] food x offset and food y offset from the head
        """
        hx, hy = self.snake[0]
        fx, fy = self.food
        danger = [int(self._is_blocked((hx + dx, hy + dy)))
                  for dx, dy in (self._MOVES[d] for d in range(4))]
        heading = [int(self.direction == d) for d in range(4)]
        food_side = [int(fy < hy), int(fy > hy), int(fx < hx), int(fx > hx)]
        return np.array(danger + heading + food_side + [fx - hx, fy - hy], dtype=int)

    def play_step(self, action):
        """Apply one move and return ``(state, reward, done)``.

        ``done`` is True when the snake hits a wall or itself, or after
        ``max_idle_steps`` consecutive moves without eating. A fatal move
        leaves the snake where it was; call ``reset()`` to play again.
        """
        self.steps += 1
        self.steps_since_food += 1
        direction = int(action) if action in (0, 1, 2) else 3
        dx, dy = self._MOVES[direction]
        head_x, head_y = self.snake[0]
        new_head = (head_x + dx, head_y + dy)

        if new_head == self.food:
            # Grow: add the new head and keep the tail.
            self.snake.insert(0, new_head)
            self.score += self.food_reward
            self.food = self.generate_food()
            self.steps_since_food = 0
            reward = self.food_reward
        elif self._is_blocked(new_head):
            # Do not auto-reset: callers need done=True to end their episode loop.
            self.score += self.death_penalty
            return self.get_state(), self.death_penalty, True
        else:
            self.snake.insert(0, new_head)
            self.snake.pop()
            reward = self.move_penalty

        # Take the heading from the action itself; inferring it from head and
        # neck is impossible while the snake is a single cell.
        self.direction = direction
        done = self.steps_since_food >= self.max_idle_steps
        return self.get_state(), reward, done

    def get_direction(self):
        """Return the current heading (0 up, 1 down, 2 left, 3 right).

        Inferred from the head and the segment behind it; a one-cell snake has
        no such segment, so the stored ``self.direction`` is returned instead.
        """
        if len(self.snake) < 2:
            return self.direction
        dx = self.snake[0][0] - self.snake[1][0]
        dy = self.snake[0][1] - self.snake[1][1]
        if dx == 0:
            return 0 if dy == -1 else 1
        return 2 if dx == -1 else 3
class GeneticAlgorithm:
    """Evolve a population of torch models by playing SnakeGame.

    Each generation every model plays one game and its summed reward is its
    fitness. Parents are drawn with probability increasing in fitness,
    combined by uniform crossover and perturbed by Gaussian mutation.
    """

    def __init__(self, population_size, mutation_rate, model_fn):
        """
        Args:
            population_size: number of models kept in every generation.
            mutation_rate: probability that each individual weight is perturbed.
            model_fn: zero-argument factory returning a torch.nn.Module. It is
                called once per individual, so it should return a new,
                independently initialised model each time.
        """
        self.population_size = population_size
        self.mutation_rate = mutation_rate
        self.model_fn = model_fn
        self.population = [model_fn() for _ in range(population_size)]
        self.fitness = [0 for _ in range(population_size)]

    def select(self):
        """Return the indices of two parents chosen by fitness-proportionate selection.

        Fitness is summed game reward and is usually negative, but
        random.choices needs non-negative weights with a positive total, so
        scores are shifted so the worst individual gets a tiny positive weight.
        """
        lowest = min(self.fitness)
        weights = [f - lowest + 1e-6 for f in self.fitness]
        parent1_idx, parent2_idx = random.choices(range(len(weights)), weights=weights, k=2)
        return parent1_idx, parent2_idx

    def crossover(self, parent1, parent2):
        """Return two children produced by uniform crossover of the parents' weights.

        The parents are deep-copied and left unmodified; for every weight a
        fair coin decides whether the two children swap that value.
        """
        child1 = copy.deepcopy(parent1)
        child2 = copy.deepcopy(parent2)
        # Parameters are leaf tensors that require grad; in-place edits must
        # happen with autograd disabled.
        with torch.no_grad():
            for param1, param2 in zip(child1.parameters(), child2.parameters()):
                mask = torch.rand_like(param1) < 0.5
                swapped1 = torch.where(mask, param2, param1)
                swapped2 = torch.where(mask, param1, param2)
                param1.copy_(swapped1)
                param2.copy_(swapped2)
        return child1, child2

    def mutate(self, model, scale=0.1):
        """Add N(0, scale**2) noise, in place, to each weight of *model* with probability mutation_rate."""
        with torch.no_grad():
            for param in model.parameters():
                mask = torch.rand_like(param) < self.mutation_rate
                delta = torch.randn_like(param) * scale
                param[mask] += delta[mask]

    def evolve(self):
        """Score the current population, then replace it with a bred generation.

        After the call ``self.fitness`` holds the scores of the generation
        that was just replaced (the one ``train`` reports).
        """
        self.fitness = [self.evaluate(model) for model in self.population]
        new_population = []
        # Each crossover yields two children; stop at population_size so the
        # population (and therefore the fitness list) keeps a constant size.
        while len(new_population) < self.population_size:
            parent1_idx, parent2_idx = self.select()
            child1, child2 = self.crossover(self.population[parent1_idx],
                                            self.population[parent2_idx])
            self.mutate(child1)
            self.mutate(child2)
            new_population.extend([child1, child2])
        self.population = new_population[:self.population_size]

    def evaluate(self, model, max_steps=1000):
        """Play one SnakeGame with *model* and return the summed reward.

        At every step the model receives the state as a (1, n_features)
        float tensor and the argmax over its outputs is played. The game
        stops when it reports done, or after *max_steps* moves as a safeguard.
        """
        game = SnakeGame()
        state = game.get_state()
        fitness = 0
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(max_steps):
                inputs = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
                action = model(inputs).argmax(dim=1).item()
                state, reward, done = game.play_step(action)
                fitness += reward
                if done:
                    break
        return fitness

    def train(self, num_generations):
        """Run *num_generations* generations, printing max and mean fitness of each."""
        for generation in range(num_generations):
            self.evolve()
            fitness = self.fitness
            print(f"Generation: {generation}, Max Fitness: {max(fitness)}, Avg Fitness: {sum(fitness) / len(fitness)}")
if __name__ == "__main__":
    def build_model():
        """Return a newly initialised policy network.

        The input size 14 matches the length of SnakeGame.get_state(); the 4
        outputs score the actions 0 up, 1 down, 2 left, 3 right.
        """
        return torch.nn.Sequential(
            torch.nn.Linear(14, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 4),
        )

    # model_fn must build a fresh network per individual: nn.Module has no
    # .clone(), and one shared instance would leave the GA without diversity.
    ga = GeneticAlgorithm(population_size=50, mutation_rate=0.1, model_fn=build_model)
    # Train the genetic algorithm.
    ga.train(num_generations=100)
```
代码中定义了 `SnakeGame` 类来实现贪吃蛇游戏，并定义了 `GeneticAlgorithm` 类来实现遗传算法。`GeneticAlgorithm` 的 `model_fn` 参数是一个无参工厂函数，会为每个个体调用一次，因此每次调用都应返回一个新初始化的模型；本例使用一个简单的前馈神经网络作为模型，其输入维度需与 `get_state()` 返回的状态长度（14）一致。`train` 方法在每一代调用 `evolve`：先让每个个体玩一局游戏以计算适应度，再进行选择、交叉和变异，生成新一代个体。训练过程中会输出每一代的最大适应度和平均适应度。
阅读全文