class Memory(): def __init__(self, capacity, dims): self.capacity = capacity self.mem = np.zeros((capacity, dims)) self.memory_counter = 0 '''存储记忆''' def store_transition(self, s, a, r, s_): tran = np.hstack((s, [a.squeeze(0), r], s_)) # 把s,a,r,s_困在一起,水平拼接 index = self.memory_counter % self.capacity # 除余得索引 self.mem[index, :] = tran # 给索引存值,第index行所有列都为其中一次的s,a,r,s_;mem会是一个capacity行,(s+a+r+s_)列的数组 self.memory_counter += 1 '''随机从记忆库里抽取''' def sample(self, n): assert self.memory_counter >= self.capacity, '记忆库没有存满记忆' sample_index = np.random.choice(self.capacity, n) # 从capacity个记忆里随机抽取n个为一批,可得到抽样后的索引号 new_mem = self.mem[sample_index, :] # 由抽样得到的索引号在所有的capacity个记忆中 得到记忆s,a,r,s_ return new_mem
时间: 2024-02-26 19:55:52 浏览: 75
WEB_Query.zip_WEB_Query.zip_vba url
5星 · 资源好评率100%
这是一个用于存储和随机抽样记忆的类,通常用于深度强化学习中的经验回放(Experience Replay)技术。
具体来说,这个类包含了以下几个方法:
- `__init__`方法:初始化记忆库的容量和每个记忆的维度,同时创建一个空的记忆库`mem`,并将`memory_counter`设为0。
- `store_transition`方法:将一次完整的记忆(包含当前状态`s`、动作`a`、奖励`r`和下一个状态`s_`)存储在记忆库`mem`中,同时将`memory_counter`加1。如果记忆库已经存满了,那么新的记忆将覆盖最早的记忆。
- `sample`方法:从记忆库`mem`中随机抽样n个记忆,其中`n`为抽样的批次大小。如果记忆库还没有存满`capacity`个记忆,那么会抛出异常。抽样后得到的记忆是一个n行,(s+a+r+s_)列的数组。
这个类的作用是为深度强化学习算法提供一个回放记忆的机制,使得算法可以重复利用之前的经验,提高学习效率和稳定性。
阅读全文