np.array 打乱数据
时间: 2023-09-25 13:16:45 浏览: 98
要打乱一个NumPy数组的数据,你可以使用`np.random.shuffle()`函数。这个函数会原地修改数组,将其元素随机重新排列。
下面是一个示例:
```python
import numpy as np
# 创建一个数组
arr = np.array([1, 2, 3, 4, 5])
# 打乱数组的顺序
np.random.shuffle(arr)
print(arr)
```
输出可能是:
```
[3 1 4 5 2]
```
注意,`np.random.shuffle()`函数只能打乱数组的第一个维度,如果是多维数组,只会打乱每个一维子数组的顺序。如果你想打乱多维数组的所有元素的顺序,可以先使用`np.reshape()`将其转换成一维数组,然后再使用`np.random.shuffle()`进行打乱。
相关问题
class PPOMemory: def __init__(self, batch_size): self.states = [] self.probs = [] self.vals = [] self.actions = [] self.rewards = [] self.dones = [] self.batch_size = batch_size def sample(self): batch_step = np.arange(0, len(self.states), self.batch_size) indices = np.arange(len(self.states), dtype=np.int64) np.random.shuffle(indices) batches = [indices[i:i+self.batch_size] for i in batch_step] return np.array(self.states),np.array(self.actions),np.array(self.probs),\ np.array(self.vals),np.array(self.rewards),np.array(self.dones),batches def push(self, state, action, probs, vals, reward, done): self.states.append(state) self.actions.append(action) self.probs.append(probs) self.vals.append(vals) self.rewards.append(reward) self.dones.append(done) def clear(self): self.states = [] self.probs = [] self.actions = [] self.rewards = [] self.dones = [] self.vals = []
这段代码是一个 PPO 算法中的经验回放(experience replay)类,用于存储和采样交互数据。具体来说,这个类有以下几个成员函数:
1. 构造函数 __init__:初始化经验回放缓存的大小 batch_size,并创建空的列表来存储状态、动作、概率、价值、奖励和完成标志等信息。
2. 采样函数 sample:将经验回放缓存中的数据按照 batch_size 分成若干个 batch,然后随机打乱数据索引,最后返回一个包含所有 batch 的列表。每个 batch 包含一个状态状态列表、一个动作列表、一个概率列表、一个价值列表、一个奖励列表和一个完成标志列表。
3. 存储函数 push:将交互数据(即一个状态 state、一个动作 action、一个概率 probs、一个价值 vals、一个奖励 reward 和一个完成标志 done)存储到经验回放缓存中。
4. 清空函数 clear:清空经验回放缓存,以便下一次使用。
整个经验回放类的作用是存储和采样交互数据,以便训练 PPO 算法时能够从多个交互轮次中有效地学习。其中,采样函数 sample 会将数据随机打乱,以避免过于相关的数据干扰训练。
def generate_arrays_from_file(lines,batch_size): n = len(lines) i = 0 while 1: X_train = [] Y_train = [] for _ in range(batch_size): if i==0: np.random.shuffle(lines) #-------------------------------------# # 读取输入图片并进行归一化和resize #-------------------------------------# name = lines[i].split(';')[0] img = Image.open("./dataset2/jpg/" + name) img = img.resize((WIDTH,HEIGHT), Image.BICUBIC) img = np.array(img)/255 X_train.append(img) #-------------------------------------# # 读取标签图片并进行归一化和resize #-------------------------------------# name = lines[i].split(';')[1].split()[0] label = Image.open("./dataset2/png/" + name) label = label.resize((int(WIDTH/2),int(HEIGHT/2)), Image.NEAREST) if len(np.shape(label)) == 3: label = np.array(label)[:,:,0] label = np.reshape(np.array(label), [-1]) one_hot_label = np.eye(NCLASSES)[np.array(label, np.int32)] Y_train.append(one_hot_label) i = (i+1) % n yield (np.array(X_train), np.array(Y_train))
这个函数是一个用于生成训练数据的生成器函数。它从一个包含了文件路径和标签的文本文件中读取数据,并将其处理成模型可以使用的格式。
函数的参数包括 `lines` 和 `batch_size`。`lines` 是一个包含了文件路径和标签的文本文件的行列表,`batch_size` 是每个批次中的样本数量。
函数使用了一个无限循环 `while 1`,这意味着它会不断地生成数据。在每个循环迭代中,函数会重新初始化 `X_train` 和 `Y_train` 空列表,用于存储当前批次的训练数据。
然后,它使用一个 `for` 循环来迭代 `batch_size` 次,从 `lines` 中读取训练样本。如果 `i` 的值为 0,表示已经遍历完了所有行,这时需要重新打乱 `lines` 的顺序,以保证随机性。
接下来,代码根据行中的内容读取输入图片和标签图片。它首先从行中提取出输入图片的文件名,并使用 `Image.open` 方法打开图片文件。然后,使用 `resize` 方法将图片调整为指定的宽度和高度,并使用 `Image.BICUBIC` 进行插值。之后,将图片转换为 NumPy 数组,并进行归一化处理(除以 255)。最后,将处理后的图片添加到 `X_train` 列表中。
类似地,代码从行中提取出标签图片的文件名,并使用类似的方式处理标签图片。不过,这里还有一个附加步骤,它将标签图片进行了降采样(resize)操作,并将其转换为 NumPy 数组。如果标签图片是 RGB 彩色图像,代码会将其转换为灰度图像。然后,将标签图像重新形状为一维数组,并使用 one-hot 编码的方式将其转换为 one-hot 标签向量。最后,将处理后的标签向量添加到 `Y_train` 列表中。
在每次迭代结束时,更新 `i` 的值,以便下一次迭代可以读取下一行的数据。通过 `yield` 关键字,函数将生成当前批次的训练数据,并返回给调用者。这意味着在使用该函数生成数据时,可以使用 `next()` 函数来获取下一个批次的训练数据。
总之,该函数通过迭代读取文本文件中的数据,并将其处理成模型可以使用的格式,然后以生成器的方式返回给调用者,用于训练模型。