for i in range(5): client_dataset = dataset.shuffle(buffer_size=len(train_images)) client_dataset = dataset.batch(batch_size=50) plot_data = collections.defaultdict(list) for example in client_dataset: plot_data[example['label'].numpy()].append(example['pixels'].numpy()) f = plt.figure(i, figsize=(12, 5)) f.suptitle("Client #{}'s Mean Image Per Label".format(i)) for j in range(10): mean_img = np.mean(plot_data[j], 0) plt.subplot(2, 5, j+1) plt.imshow(mean_img.reshape((28, 28))) plt.axis('off')运行该行代码时出现unhashable type: 'numpy.ndarray'错误该如何修改
时间: 2023-06-07 09:11:08 浏览: 95
数据集 {}".format(i)) for j in range(10): ax = f.add_subplot(2, 5, j+1) pixels = plot_data[j][0][0] pixels = np.array(pixels, dtype='uint8') reshaped_pixels = pixels.reshape((28, 28)) ax.imshow(reshaped_pixels, cmap='gray')
这段代码是一个循环,用于生成5个客户端数据集,每个数据集包含打乱顺序的训练图像,以及一个批次大小为50的图像批次。对于每个数据集,该代码会将类别标签和像素值存储在一个字典中,然后将这些值用于绘制一组10个子图像,分别表示10个不同的手写数字。在生成的每个图像中,每个数字将呈现为灰度图像。
相关问题
f = plt.figure(figsize=(12, 7)) f.suptitle('Label Counts for a Sample of Clients') for i in range(6): client_dataset = dataset.shuffle(buffer_size=len(train_images)) client_dataset = dataset.batch(batch_size=10) example = next(iter(client_dataset)) label = example['label'].numpy() unique_values, value_counts = np.unique(label, return_counts=True) plt.bar(unique_values, value_counts) plt.title('Client {}'.format(i)) plt.show()该段代码中如何修改实现使得输出的直方图呈现2*3的排列
把for循环中的i变量改为两个变量row和col,表示行和列的数量,然后用嵌套的for循环输出直方图,并在外层循环前面添加subplot方法,修改后的代码如下:
```
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
row = 2
col = 3
for i in range(row * col):
client_dataset = dataset.shuffle(buffer_size=len(train_images))
client_dataset = dataset.batch(batch_size=10)
example = next(iter(client_dataset))
label = example['label'].numpy()
unique_values, value_counts = np.unique(label, return_counts=True)
plt.subplot(row, col, i+1)
plt.bar(unique_values, value_counts)
plt.title('Client {}'.format(i))
plt.show()
```
f = plt.figure(figsize=(12, 7)) f.suptitle('Label Counts for a Sample of Clients') row = 2 col = 3 for i in range(row * col): client_dataset = dataset.shuffle(buffer_size=len(train_images)) client_dataset = dataset.batch(batch_size=10) example = next(iter(client_dataset)) label = example['label'].numpy() unique_values, value_counts = np.unique(label, return_counts=True) plt.subplot(row, col, i+1) plt.bar(unique_values, value_counts) plt.title('Client {}'.format(i)) plt.show()该段代码中如何修改实现直方图中每列数据颜色不一样
您可以按照如下方法修改代码,实现直方图中每列数据颜色不一样:
1. 导入 colormap 模块:`from matplotlib import cm`
2. 定义颜色序列:`colors = cm.rainbow(np.linspace(0, 1, row * col))`
3. 在 for 循环中,为每个子图指定一种颜色:`plt.bar(unique_values, value_counts, color=colors[i])`
修改后的代码如下:
```
from matplotlib import pyplot as plt
from matplotlib import cm
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
row = 2
col = 3
colors = cm.rainbow(np.linspace(0, 1, row * col))
for i in range(row * col):
client_dataset = dataset.shuffle(buffer_size=len(train_images))
client_dataset = dataset.batch(batch_size=10)
example = next(iter(client_dataset))
label = example['label'].numpy()
unique_values, value_counts = np.unique(label, return_counts=True)
plt.subplot(row, col, i + 1)
plt.bar(unique_values, value_counts, color=colors[i])
plt.title('Client {}'.format(i))
plt.show()
```
阅读全文