将下面的代码用torch.ops.script_ops.while_loop 进行改造:def py_cpu_nms2(dets, thresh): x1 = dets[:, 0] y1 = dets[:, 1] x2 = dets[:, 2] y2 = dets[:, 3] scores = dets[:, 4] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = scores.argsort(descending=True) while order.item().numel() > 0: order = order
时间: 2023-02-12 11:43:36 浏览: 499
以下是用 torch.ops.script_ops.while_loop 改造后的代码:
```
import torch
def py_cpu_nms2(dets, thresh):
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort(descending=True)
def loop_body(i, order):
i = i[0]
order = order[0]
if order.numel() == 0:
return i, order
else:
# ... loop body code ...
return i+1, order
i = torch.tensor([0], dtype=torch.long)
order = torch.ops.script_ops.while_loop(lambda i, order: i < order.numel(), loop_body, (i, order))
return order
```
注意:在 while_loop 中,第一个参数是循环条件函数,第二个参数是循环体函数,第三个参数是初始状态。在循环体函数中,使用 i+1 更新循环变量 i。在循环条件函数中,使用 i < order.numel() 判断循环是否结束。
阅读全文