ghostconv代碼
时间: 2023-12-24 13:04:38 浏览: 158
以下是GhostConv的代码示例:
```python
import torch
import torch.nn as nn
class GhostConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=False):
super(GhostConv, self).__init__()
self.primary_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias)
self.cheap_conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0, groups=groups, bias=bias)
def forward(self, x):
x1 = self.primary_conv(x)
x2 = self.cheap_conv(x)
return torch.cat([x1, x2], dim=1)
# 使用示例
input = torch.randn(1, 3, 224, 224)
conv = GhostConv(3, 64, 3)
output = conv(input)
print(output.shape)
```
请注意,这只是GhostConv的一个简单示例,你可以根据自己的需求进行修改和扩展。
阅读全文