找不到torch.ops.script_ops.while_loop
时间: 2023-02-12 17:04:38 浏览: 159
这可能是因为您使用的 PyTorch 版本过低。while_loop 函数是 PyTorch 1.5 及更高版本中新增加的函数,如果您使用的版本低于此版本,可能会导致无法找到该函数。建议更新至最新版本。
相关问题
torch.ops.script_ops.while_loop这个怎么用
torch.ops.script_ops.while_loop是PyTorch中的一个函数,用于在脚本模式下执行循环。它接受三个参数:
1. cond: 循环条件,是一个函数,每次迭代时调用并返回一个布尔值。当返回值为True时继续循环,否则退出循环。
2. body: 循环体,是一个函数,每次迭代时调用。
3. loop_vars: 循环变量,是一个元组,代表循环中需要更新的变量。
例如:
```
import torch
def cond(i):
return i < 3
def body(i):
print(i)
return (i + 1,)
i = torch.tensor(0, dtype=torch.int32)
torch.ops.script_ops.while_loop(cond, body, (i,))
```
这个例子中,循环条件为i < 3, 循环体是打印出i, 循环变量是i, 结果是打印出0,1,2
将下面的代码用torch.ops.script_ops.while_loop 进行改造:def py_cpu_nms2(dets, thresh): x1 = dets[:, 0] y1 = dets[:, 1] x2 = dets[:, 2] y2 = dets[:, 3] scores = dets[:, 4] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = scores.argsort(descending=True) while order.item().numel() > 0: order = order
以下是用 torch.ops.script_ops.while_loop 改造后的代码:
```
import torch
def py_cpu_nms2(dets, thresh):
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort(descending=True)
def loop_body(i, order):
i = i[0]
order = order[0]
if order.numel() == 0:
return i, order
else:
# ... loop body code ...
return i+1, order
i = torch.tensor([0], dtype=torch.long)
order = torch.ops.script_ops.while_loop(lambda i, order: i < order.numel(), loop_body, (i, order))
return order
```
注意:在 while_loop 中,第一个参数是循环条件函数,第二个参数是循环体函数,第三个参数是初始状态。在循环体函数中,使用 i+1 更新循环变量 i。在循环条件函数中,使用 i < order.numel() 判断循环是否结束。
阅读全文