import matplotlib.pyplot as plt similarity = [0.3940893515407449, 1.0116279478167544] # similarity of action divergence = [8.761, 24.455] # js diversity labels = ['w. cu', 'wo. cu'] plt.rcParams['axes.labelsize'] = 16 # xy轴label的size plt.rcParams['xtick.labelsize'] = 12 # x轴ticks的size plt.rcParams['ytick.labelsize'] = 14 # y轴ticks的size plt.rcParams['legend.fontsize'] = 12 # 图例的size # 设置柱形的间隔 width = 0.3 # 柱形的宽度 x1_list = [] x2_list = [] for i in range(len(similarity)): x1_list.append(i) x2_list.append(i + width) # 创建图层 fig, ax1 = plt.subplots() # 设置左侧Y轴对应的figure ax1.set_ylabel('Time per iteration (s)') ax1.set_ylim(0, 1.1) ax1.bar(x1_list, similarity, width=width, color='tab:red', align='edge') ax1.set_xticklabels(ax1.get_xticklabels()) # 设置共用的x轴 # 设置右侧Y轴对应的figure ax2 = ax1.twinx() ax2.set_ylabel('GPU Usage (GB)') ax2.set_ylim(0, 28) ax2.bar(x2_list, divergence, width=width, color='tab:blue', align='edge', tick_label=labels) plt.tight_layout() plt.savefig("similarity.png") plt.show()我这个代码怎么加图例
时间: 2024-02-02 16:03:43 浏览: 122
要在这个代码中添加图例,需要使用plt.legend()函数来创建图例。首先,在ax1.bar()和ax2.bar()函数中添加label参数,将每个柱形的标签设置为相应的名称。然后,在创建图例之前,需要将ax1和ax2的对象存储在变量中,以便在plt.legend()函数中使用。最后,使用plt.legend()函数并设置loc参数来指定图例的位置。
以下是修改后的代码:
```
import matplotlib.pyplot as plt
# 数据
similarity = [0.3940893515407449, 1.0116279478167544]
divergence = [8.761, 24.455]
labels = ['w. cu', 'wo. cu']
# 设置字体大小
plt.rcParams['axes.labelsize'] = 16
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 14
plt.rcParams['legend.fontsize'] = 12
# 柱形宽度和间隔
width = 0.3
x1_list = []
x2_list = []
for i in range(len(similarity)):
x1_list.append(i)
x2_list.append(i + width)
# 创建图层
fig, ax1 = plt.subplots()
# 设置左侧Y轴对应的figure
ax1.set_ylabel('Time per iteration (s)')
ax1.set_ylim(0, 1.1)
ax1.bar(x1_list, similarity, width=width, color='tab:red', align='edge', label='Similarity')
ax1.set_xticklabels(ax1.get_xticklabels())
# 设置右侧Y轴对应的figure
ax2 = ax1.twinx()
ax2.set_ylabel('GPU Usage (GB)')
ax2.set_ylim(0, 28)
ax2.bar(x2_list, divergence, width=width, color='tab:blue', align='edge', tick_label=labels, label='Divergence')
# 添加图例
ax1_legend = ax1.legend(loc='upper left')
ax2_legend = ax2.legend(loc='upper right')
# 显示图表
plt.tight_layout()
plt.savefig("similarity.png")
plt.show()
```
这将在图表中添加两个图例,分别显示“Similarity”和“Divergence”标签,并将它们与相应的柱形颜色匹配。注意,在添加图例之前,需要将ax1和ax2的对象存储在变量中,并在plt.legend()函数中分别使用它们来创建图例。
阅读全文