yolov8载入预训练权重
时间: 2023-10-09 15:12:01 浏览: 289
yolov8的入门级(训练+预测)的代码demo
问题中提到的方法可以通过在yolov8模型的代码中插入一段代码来实现载入预训练权重。具体代码及插入位置如下所示:
```
def _new(self, m):
    """Return a freshly built copy of a basic layer.

    ``nn.Conv2d``, ``nn.BatchNorm2d`` and ``nn.Linear`` inputs are rebuilt
    with the same hyper-parameters and cloned tensors, so the result shares
    no storage with ``m``. Any other module is returned unchanged.

    Args:
        m: The module to copy. If it carries a ``names`` dict, every
            ``nn.BatchNorm2d`` value in that dict is replaced in place by
            its ``SyncBatchNorm`` conversion.

    Returns:
        A new ``nn.Conv2d`` / ``nn.BatchNorm2d`` / ``nn.Linear`` holding
        cloned parameters, or ``m`` itself for every other module type.
    """
    # Plain torch layers have no ``names`` attribute, so look it up
    # defensively instead of raising AttributeError.
    names = getattr(m, "names", None)
    if isinstance(names, dict):
        for key in list(names):
            if isinstance(names[key], nn.BatchNorm2d):
                names[key] = nn.SyncBatchNorm.convert_sync_batchnorm(names[key])

    if isinstance(m, nn.SyncBatchNorm):
        m = nn.SyncBatchNorm.convert_sync_batchnorm(m)

    if isinstance(m, nn.Conv2d):
        n = nn.Conv2d(
            m.in_channels,
            m.out_channels,
            m.kernel_size,
            m.stride,
            m.padding,
            m.dilation,
            m.groups,
            bias=m.bias is not None,
            padding_mode=m.padding_mode,
        )
        n.weight.data = m.weight.data.clone()
        if m.bias is not None:
            n.bias.data = m.bias.data.clone()
        # ``scale`` is not a standard Conv2d attribute; copy it only when
        # the source layer actually defines one.
        if hasattr(m, "scale"):
            n.scale = m.scale.clone()
        return n

    if isinstance(m, nn.BatchNorm2d):
        n = nn.BatchNorm2d(m.num_features, m.eps, m.momentum, m.affine, m.track_running_stats)
        # weight/bias are None when affine=False.
        if m.weight is not None:
            n.weight.data = m.weight.data.clone()
        if m.bias is not None:
            n.bias.data = m.bias.data.clone()
        # Running statistics are None when track_running_stats=False.
        if m.running_mean is not None:
            n.running_mean.data = m.running_mean.data.clone()
        if m.running_var is not None:
            n.running_var.data = m.running_var.data.clone()
        if m.num_batches_tracked is not None:
            n.num_batches_tracked.data = m.num_batches_tracked.data.clone()
        return n

    if isinstance(m, nn.Linear):
        n = nn.Linear(m.in_features, m.out_features, m.bias is not None)
        n.weight.data = m.weight.data.clone()
        if m.bias is not None:
            n.bias.data = m.bias.data.clone()
        return n

    return m
def load_pretrained_weights(self, file, include=None, *, allow_missing_keys=False, nclude=None):
    """Load weights from a checkpoint file into ``self.model``.

    The checkpoint is read with ``torch.load`` onto ``self.device`` and its
    ``'model'`` entry is loaded non-strictly, so keys that do not exist in
    ``self.model`` are ignored. A one-line transfer summary is printed.

    Args:
        file: Path (or file-like object) accepted by ``torch.load``.
        include: Optional container of state-dict keys; when truthy, only
            these keys are loaded.
        allow_missing_keys: If True, a checkpoint lacking the expected key
            only prints a warning instead of raising ``KeyError``.
        nclude: Misspelled legacy alias of ``include``; used only when
            ``include`` is None.

    Raises:
        KeyError: The checkpoint has no ``'model'`` entry and
            ``allow_missing_keys`` is False.
        Exception: Any other loading error, re-raised as the same type with
            the transfer summary prepended to its message.
    """
    if include is None:
        include = nclude

    n_items = len(self.model.state_dict())
    # Until the load succeeds nothing has been transferred; the error paths
    # below report this zero-count summary.
    msg = f"Transferred 0/{n_items} items from {file}\n"
    try:
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        ckpt = torch.load(file, map_location=self.device)
        state = ckpt['model']
        # A checkpoint may store the whole module rather than its
        # state dict; reduce it to an FP32 state dict in that case.
        if hasattr(state, 'state_dict'):
            state = state.float().state_dict()
        # keep only the requested keys
        if include:
            state = {k: v for k, v in state.items() if k in include}
        # load
        result = self.model.load_state_dict(state, strict=False)
        # report: with strict=False, unexpected keys are skipped, not loaded
        n = len(state) - len(result.unexpected_keys)
        msg = f"Transferred {n}/{n_items} items from {file}\n"
        print(msg, end='')
    except KeyError as e:
        s = f'missing keys in pretrained weights: {e}\n'
        if allow_missing_keys:
            print(f"{msg}warning: {s}")
        else:
            raise KeyError(f"{msg}error: {s}") from e
    except Exception as e:
        raise type(e)(f"{msg}{e}") from e
# Insert the following code into the `_new` function in
# yolov8->ultralytics->yolo->engine->model.py.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
# map_location="cpu" lets a checkpoint that was saved on a GPU be read on a
# CPU-only machine; load_state_dict then copies the values into the model.
ckpt = torch.load("yolov8.pt", map_location="cpu")
# The checkpoint stores the model object; take its weights as FP32.
csd = ckpt['model'].float().state_dict()
# presumably keeps only entries that also exist in the target model - confirm
# against the intersect_dicts helper.
csd = intersect_dicts(csd, self.model.state_dict())
self.model.load_state_dict(csd, strict=False)
print(f'Transferred {len(csd)}/{len(self.model.state_dict())} items')
```
阅读全文