(train_images, train_labels), (test_images, test_labels) = datasets.fashion_mnist.load_data() f = plt.figure(figsize=(12, 7)) f.suptitle('Label Counts for a Sample of Clients') client_data = collections.OrderedDict() for i in range(6): client_data[f'client_{i}'] = ( train_images[i*1000:(i+1)*1000], train_labels[i*1000:(i+1)*1000]) plot_data = collections.defaultdict(list) for example in client_data[f'client_{i}']: label = example[0].numpy() #images, labels = example[] #label = labels.numpy() plot_data[label].append(label) for i in range(6): plt.subplot(2, 3, i+1) plt.title('Client {}'.format(i)) for j in range(10): plt.hist( plot_data[j], density=False, bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
时间: 2023-06-08 19:06:50 浏览: 133
fashion_mnist数据集
这段代码使用了TensorFlow中的fashion_mnist数据集,将训练集和测试集分别赋值给了(train_images, train_labels)和(test_images,_labels)。接着,使用matplotlib库绘制一个12*7的图像,标题为“一组客户的标签计数”。然后用OrderedDict创建了一个名为client_data的字典,其中包含6个键值对,每个键值对表示一个客户的图片和标签。最后用defaultdict创建了一个名为plot_data的字典,用于收集每位客户的标签计数。
阅读全文