FID 评价指标代码（PyTorch）
时间: 2023-10-22 08:10:45 浏览: 227
以下是 PyTorch 中计算 FID（Fréchet Inception Distance）评价指标的代码：
```python
import numpy as np
import torch
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d
from torchvision.models import inception_v3
def calculate_activation_statistics(images, model):
    """Return the mean vector and covariance matrix of *model*'s features for *images*.

    Args:
        images: Batch of images; presumably an (N, C, H, W) tensor already
            preprocessed the way *model* expects — confirm against the caller.
        model: Feature extractor. Its output may be a tensor, or a list/tuple whose
            first element is the feature tensor. That tensor may be (N, D) features
            or (N, C, H, W) feature maps; maps are global-average-pooled to (N, C).

    Returns:
        (mu, sigma): the per-feature mean, shape (D,), and the (D, D) covariance
        of the features across the batch (computed by ``torch_cov``).

    Note:
        Switches *model* to eval mode as a side effect.
    """
    model.eval()
    # Only statistics are needed, so skip building an autograd graph.
    with torch.no_grad():
        out = model(images)
    # pytorch-fid-style wrappers return a list of feature maps. torchvision's Inception
    # returns a (logits, aux_logits) namedtuple in training mode. Take the primary output
    # in both cases. A plain tensor (e.g. torchvision inception_v3 in eval mode) must not
    # be indexed with [0]: that would select the first image, not the feature tensor.
    if isinstance(out, (list, tuple)):
        out = out[0]
    if out.dim() == 4:
        # (N, C, H, W) feature maps -> (N, C, 1, 1) by global average pooling.
        out = adaptive_avg_pool2d(out, (1, 1))
    # Collapse everything after the batch axis: (N, C, 1, 1) or (N, D) -> (N, features).
    act = out.flatten(start_dim=1)
    mu = act.mean(dim=0)
    sigma = torch_cov(act, rowvar=False)
    return mu, sigma
def torch_cov(m, rowvar=False):
    """Estimate a covariance matrix like ``numpy.cov`` (unbiased, normalised by N - 1).

    Args:
        m: 1-D or 2-D tensor. A 1-D tensor is treated as observations of a single
            variable.
        rowvar: If True, each row is a variable and each column an observation.
            If False (the default), each column is a variable and each row an
            observation — e.g. an (N, D) batch of feature vectors.

    Returns:
        The (V, V) covariance tensor for V variables, squeezed; a single variable
        therefore yields a 0-d tensor. The input tensor is not modified.

    Raises:
        ValueError: if *m* has more than two dimensions or fewer than two
            observations.
    """
    if m.dim() > 2:
        raise ValueError(f"torch_cov expects a 1-D or 2-D tensor, got {m.dim()}-D")
    if m.dim() < 2:
        m = m.reshape(1, -1)
    elif not rowvar:
        # Work with variables along rows and observations along columns.
        m = m.t()
    n_obs = m.size(1)
    if n_obs < 2:
        raise ValueError("torch_cov needs at least two observations")
    # Centre each variable out of place so the caller's tensor is left untouched.
    m = m - m.mean(dim=1, keepdim=True)
    fact = 1.0 / (n_obs - 1)
    return fact * m.matmul(m.t()).squeeze()
def _as_float64_array(x):
    """Convert a torch tensor (any device, with or without grad) or array-like to float64 NumPy."""
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    return np.asarray(x, dtype=np.float64)


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Return the Frechet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    d^2 = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))

    Args:
        mu1, mu2: Mean vectors of equal shape (torch tensors or array-likes).
        sigma1, sigma2: Covariance matrices of equal shape (torch tensors or array-likes).
        eps: Diagonal offset added to both covariances when the matrix square root
            of their product is not finite (e.g. a singular product).

    Returns:
        The distance as a Python float. All arithmetic is done in float64 NumPy.

    Raises:
        ValueError: if the mean or covariance shapes differ, or if the matrix
            square root has a non-negligible imaginary component on its diagonal.
    """
    mu1 = np.atleast_1d(_as_float64_array(mu1))
    mu2 = np.atleast_1d(_as_float64_array(mu2))
    sigma1 = np.atleast_2d(_as_float64_array(sigma1))
    sigma2 = np.atleast_2d(_as_float64_array(sigma2))
    if mu1.shape != mu2.shape:
        raise ValueError(f"mean shapes differ: {mu1.shape} vs {mu2.shape}")
    if sigma1.shape != sigma2.shape:
        raise ValueError(f"covariance shapes differ: {sigma1.shape} vs {sigma2.shape}")

    diff = mu1 - mu2

    # sigma1 @ sigma2 is not symmetric in general, so use the general matrix square root.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Product is (near-)singular: regularise both covariances and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can leave tiny imaginary parts; tolerate them, reject large ones.
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)
    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
def calculate_fid(images_real, images_fake, model):
    """Compute the Frechet Inception Distance between two batches of images.

    Each batch is summarised by ``calculate_activation_statistics`` (real batch
    first, then fake) into a mean and covariance of *model*'s features. The two
    summaries are then compared with ``calculate_frechet_distance``.

    Returns:
        Whatever ``calculate_frechet_distance`` returns for the two summaries.
    """
    (mu_real, sigma_real), (mu_fake, sigma_fake) = (
        calculate_activation_statistics(batch, model)
        for batch in (images_real, images_fake)
    )
    return calculate_frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake)
# Example usage:
# images_real and images_fake should be PyTorch tensors of shape (N, C, H, W), where N is the
# number of images, C is the number of channels, and H and W are the image height and width.
# model is an InceptionV3 network, e.g. torchvision's inception_v3 with its classifier removed.
if __name__ == "__main__":
    # Guarded so importing this module does not run the demo. Previously the unguarded
    # statements raised NameError because their inputs were never defined.
    # weights="DEFAULT" needs torchvision >= 0.13 (older releases: pretrained=True) and
    # downloads the ImageNet weights on first use.
    model = inception_v3(weights="DEFAULT")
    # Drop the classifier so the forward pass yields the 2048-d pooled features used for FID.
    model.fc = torch.nn.Identity()
    # Placeholder data: substitute real and generated images preprocessed for Inception
    # (3 x 299 x 299). A meaningful FID needs thousands of images per set (more than the
    # 2048 feature dimensions, for a full-rank covariance). With this tiny random batch the
    # result is only a smoke test, and the matrix square root may be ill-conditioned.
    images_real = torch.rand(32, 3, 299, 299)
    images_fake = torch.rand(32, 3, 299, 299)
    fid = calculate_fid(images_real, images_fake, model)
    print("FID:", fid)
```
阅读全文