tf.scatter_nd转化为pytorch代码
时间: 2023-04-04 18:05:13 浏览: 96
可以使用PyTorch的scatter函数来实现tf.scatter_nd的功能。具体代码如下:
import torch
indices = torch.tensor([[0], [2]])
updates = torch.tensor([1, 2])
shape = (4,)
result = torch.zeros(shape)
result = result.scatter(0, indices, updates)
print(result)
这段代码的功能是将updates中的值按照indices中的索引位置,分别赋值给result中对应的位置。
相关问题
torch.tensor.scatter_转化为tensorflow代码
以下是将torch.tensor.scatter_转化为tensorflow代码的示例:
```python
import tensorflow as tf
# 假设有一个形状为(3, 4)的张量
x = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
# 创建一个形状为(3, 4)的零张量
y = tf.zeros([3, 4])
# 将x的第一行插入到y的第二行
y = tf.tensor_scatter_nd_update(y, [[1]], x[0:1])
# 将x的第二行插入到y的第一行和第三行
y = tf.tensor_scatter_nd_update(y, [[0], [2]], x[1:2])
# 将x的第三行插入到y的第二行
y = tf.tensor_scatter_nd_update(y, [[1]], x[2:3])
# 输出结果
print(y)
```
这段代码将torch.tensor.scatter_函数的功能转化为了tensorflow中的tensor_scatter_nd_update函数。它将一个张量的指定行插入到另一个张量的指定行。
def scatter(self): self.scatter_frame = Frame(self.init_window_name) # 创建一个竖直滚动条 scrollbar = Scrollbar(self.scatter_frame, orient=tk.VERTICAL) scrollbar.pack(side=tk.RIGHT, fill=tk.Y) self.scatter_frame.place(x=10, y=460, width=750, height=310) fig = plt.figure() ax = fig.add_subplot(111) pd.plotting.scatter_matrix(self.df,alpha = 0.3,figsize = (10,10),grid = True,ax = ax) self.scatter_view = FigureCanvasTkAgg(fig,master = self.scatter_frame) self.scatter_view.draw() self.scatter_view.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1) # 将滚动条绑定到FigureCanvasTkAgg上 self.scatter_view.configure(yscrollcommand=scrollbar.set) scrollbar.config(command=self.scatter_view.yview)我想在self.scatter_view画布上加一个滚轮,让图片不要被压缩显示,怎么改代码?
你可以在创建FigureCanvasTkAgg时,将其参数resize设置为True,这样就可以自动调整大小,避免图片被压缩显示。同时,你需要在scatter_frame上添加鼠标滚动事件的绑定,用来处理滚轮事件。
修改后的代码如下:
```python
def scatter(self):
self.scatter_frame = Frame(self.init_window_name)
# 创建一个竖直滚动条
scrollbar = Scrollbar(self.scatter_frame, orient=tk.VERTICAL)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
self.scatter_frame.place(x=10, y=460, width=750, height=310)
fig = plt.figure()
ax = fig.add_subplot(111)
pd.plotting.scatter_matrix(self.df,alpha = 0.3,figsize = (10,10),grid = True,ax = ax)
# 将resize设置为True
self.scatter_view = FigureCanvasTkAgg(fig,master = self.scatter_frame, resize=True)
self.scatter_view.draw()
self.scatter_view.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
# 将滚动条绑定到FigureCanvasTkAgg上
self.scatter_view.configure(yscrollcommand=scrollbar.set)
scrollbar.config(command=self.scatter_view.yview)
# 添加鼠标滚动事件的绑定
self.scatter_view.mpl_connect('scroll_event', self.on_scroll)
def on_scroll(self, event):
# 获取当前滚轮的方向
if event.button == 'up':
direction = 1
elif event.button == 'down':
direction = -1
else:
direction = 0
# 根据滚轮方向调整缩放比例
scale = 1.1
if direction:
x, y = event.x, event.y
ax = self.scatter_view.figure.axes[0]
if direction > 0:
# 放大
ax.set_xlim(xdata - scale * (xdata - ax.get_xlim()[0]),
xdata + scale * (ax.get_xlim()[1] - xdata))
ax.set_ylim(ydata - scale * (ydata - ax.get_ylim()[0]),
ydata + scale * (ax.get_ylim()[1] - ydata))
else:
# 缩小
ax.set_xlim(xdata - scale * (xdata - ax.get_xlim()[0]),
xdata + scale * (ax.get_xlim()[1] - xdata))
ax.set_ylim(ydata - scale * (ydata - ax.get_ylim()[0]),
ydata + scale * (ax.get_ylim()[1] - ydata))
self.scatter_view.draw()
```