has_torch_function_variadic 函数
时间: 2023-10-11 17:08:43 浏览: 439
`has_torch_function_variadic` 函数是 PyTorch 中一个用于判断是否存在可变参数版本的 `torch.Function` 的函数。具体来说, `torch.Function` 是一种可以被自动微分的操作(比如加法、乘法等),而 `has_torch_function_variadic` 函数则用于判断该操作是否存在可变参数版本。
该函数的定义如下:
```cpp
template <typename F, typename... Args>
constexpr bool has_torch_function_variadic_v =
detail::has_torch_function_variadic_impl<F, Args...>::value;
```
其中,`F` 表示要判断的操作,`Args...` 则表示该操作的参数列表。函数返回值为 `bool` 类型,表示该操作是否存在可变参数版本。
该函数的使用方法如下:
```cpp
if constexpr (torch::has_torch_function_variadic_v<torch::addmm_out, Tensor&, const Tensor&, const Tensor&, const Tensor&, const Scalar&>) {
torch::autograd::profiler::RecordFunction record("addmm_out");
return torch::addmm_out(result, mat, mat2, vec, beta, alpha);
} else {
torch::autograd::profiler::RecordFunction record("addmm");
result = beta * result + alpha * mat.mm(mat2);
result.add_(vec);
return result;
}
```
上述代码中,首先使用 `has_torch_function_variadic_v` 函数判断 `addmm_out` 函数是否存在可变参数版本,如果存在,则调用该函数并返回结果;否则,使用普通的方式计算结果。
阅读全文