JupyterLab写一个测量一张jpg格式的彩色图像FID的代码
时间: 2024-02-22 12:59:43 浏览: 172
好的，以下是一段可在 JupyterLab 中运行的 FID 计算代码。需要说明的是，FID 衡量的是两组图像（例如真实图像与生成图像）特征分布之间的距离，因此代码比较的是两个图像文件夹，而不是单独的一张 jpg 图像：
```python
import os

import numpy as np
import scipy.linalg
import torch
import torchvision.transforms as transforms
from scipy.stats import entropy
from torch.nn.functional import adaptive_avg_pool2d
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.datasets import CelebA
from torchvision.datasets import ImageFolder
from torchvision.datasets import ImageNet
from torchvision.datasets import LSUN
from torchvision.datasets import STL10
from torchvision.models import inception_v3
def calculate_activation_statistics(images, model, batch_size=50, dims=2048,
                                    cuda=False):
    """Return the mean and covariance of *model*'s pooled activations.

    Args:
        images: Either a 4-D tensor ``(N, C, H, W)`` or a sequence of 3-D
            image tensors ``(C, H, W)``.
        model: Module such that ``model(batch)[0]`` is a 4-D feature map
            ``(B, dims, H', W')``; it is average-pooled to ``(B, dims)``.
        batch_size: Number of images per forward pass.
        dims: Feature dimensionality; must match the model's channel count.
        cuda: If true, move the model and every batch to the GPU.

    Returns:
        ``(mu, sigma)``: the mean vector of shape ``(dims,)`` and the sample
        covariance matrix of shape ``(dims, dims)``.

    Raises:
        ValueError: if fewer than 2 images are given (the sample covariance
            is undefined; ``np.cov`` would silently return NaN).
    """
    n_images = len(images)
    if n_images < 2:
        raise ValueError(
            f'At least 2 images are required to estimate a covariance, '
            f'got {n_images}')
    model.eval()
    if cuda:
        model.cuda()
    # Slicing a 4-D tensor already yields a batch; torch.stack only accepts
    # a sequence of tensors and raises TypeError when given a single tensor.
    is_tensor = isinstance(images, torch.Tensor)
    act = np.empty((n_images, dims))
    for start in range(0, n_images, batch_size):
        chunk = images[start:start + batch_size]
        batch = chunk if is_tensor else torch.stack(list(chunk))
        if cuda:
            batch = batch.cuda()
        with torch.no_grad():
            pred = model(batch)[0]
        # Global spatial average pool -> (B, dims).
        act[start:start + batch_size] = (
            adaptive_avg_pool2d(pred, (1, 1)).squeeze(3).squeeze(2).cpu().numpy()
        )
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Return the Frechet distance between two Gaussians N(mu1, sigma1), N(mu2, sigma2).

    Computes ``||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))``.

    Args:
        mu1, mu2: Mean vectors (scalars are promoted to 1-D).
        sigma1, sigma2: Covariance matrices (promoted to at least 2-D).
        eps: Value added to both diagonals if the matrix square root of the
            product is not finite.

    Returns:
        The distance as a NumPy scalar.

    Raises:
        AssertionError: if the means or covariances have mismatched shapes.
        ValueError: if the matrix square root keeps a non-negligible
            imaginary component on its diagonal.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)
    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'
    diff = mu1 - mu2
    # The product may be nearly singular, in which case sqrtm is not finite.
    covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = f'fid calculation produces singular product; adding {eps} to diagonal of cov estimates'
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # Numerical error can leave a small imaginary part; drop it only if tiny.
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError(f'Imaginary component {m}')
        covmean = covmean.real
    tr_covmean = np.trace(covmean)
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
class _InceptionPool3(torch.nn.Module):
    """Wrap torchvision's pretrained Inception-v3 to emit 2048-d pool3 features.

    The output follows the convention ``calculate_activation_statistics``
    relies on: ``model(x)[0]`` is a feature map of shape ``(N, 2048, 1, 1)``.
    Inputs are expected to be RGB float tensors in [0, 1] (what
    ``transforms.ToTensor`` produces); they are resized to 299x299 and
    ImageNet-normalized here.
    """

    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super().__init__()
        try:
            net = inception_v3(weights='IMAGENET1K_V1')
        except TypeError:
            # torchvision < 0.13 has no `weights` argument.
            net = inception_v3(pretrained=True)
        # Replace the 1000-way classifier so the forward pass returns the
        # globally pooled 2048-d activations used by FID.
        net.fc = torch.nn.Identity()
        self.net = net
        self.register_buffer('mean', torch.tensor(self._MEAN).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor(self._STD).view(1, 3, 1, 1))
        self.eval()

    def forward(self, x):
        x = torch.nn.functional.interpolate(
            x, size=(299, 299), mode='bilinear', align_corners=False)
        # ImageNet normalization; the pretrained torchvision model
        # (transform_input=True) rescales this to what its weights expect.
        x = (x - self.mean) / self.std
        feats = self.net(x)
        return [feats[:, :, None, None]]


def calculate_fid_given_paths(paths, batch_size, cuda, dims):
    """Return the FID between the image folders ``paths[0]`` and ``paths[1]``.

    Features come from torchvision's ImageNet-pretrained Inception-v3. These
    weights differ from the FID-specific Inception weights used by reference
    implementations (e.g. pytorch-fid), so scores are meaningful for
    comparing runs with this code but not directly comparable to published
    FID numbers.

    Args:
        paths: Two dataset root directories, each loadable by ImageFolder.
        batch_size: Images per forward pass.
        cuda: Run the network on the GPU.
        dims: Feature dimensionality; only 2048 (pool3) is supported.

    Returns:
        The Frechet distance between the two feature distributions.

    Raises:
        ValueError: if *dims* is not 2048.
    """
    if dims != 2048:
        raise ValueError(f'Only dims=2048 (Inception pool3 features) is supported, got {dims}')
    device = torch.device('cuda' if cuda else 'cpu')
    model = _InceptionPool3().to(device)
    m1, s1 = calculate_activation_statistics_from_paths(paths[0], model, batch_size, dims, cuda)
    m2, s2 = calculate_activation_statistics_from_paths(paths[1], model, batch_size, dims, cuda)
    return calculate_frechet_distance(m1, s1, m2, s2)
def calculate_activation_statistics_from_paths(path, model, batch_size, dims, cuda=False,
                                               num_workers=4):
    """Load every image under *path* and return its activation statistics.

    Images are resized (shorter side 256), center-cropped to 256x256 and
    converted to tensors before being passed to
    ``calculate_activation_statistics``.

    Args:
        path: Dataset root passed to ``torchvision.datasets.ImageFolder``.
        model: Feature extractor forwarded to ``calculate_activation_statistics``.
        batch_size: Batch size for both loading and feature extraction.
        dims: Feature dimensionality.
        cuda: Run feature extraction on the GPU.
        num_workers: DataLoader worker processes; set to 0 if worker
            processes cause problems in your notebook environment.

    Returns:
        ``(mu, sigma)`` as returned by ``calculate_activation_statistics``.
    """
    dataset = ImageFolder(path, transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ]))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    # A list of per-image (C, H, W) tensors is the form that
    # calculate_activation_statistics can batch with torch.stack; a single
    # concatenated 4-D tensor is rejected by torch.stack.
    images = [img for batch, _ in dataloader for img in batch]
    return calculate_activation_statistics(images, model, batch_size, dims, cuda)
def calculate_fid(img_real, img_fake, batch_size=50, cuda=False, dims=2048):
    """Compute the FID score between two image folders.

    Args:
        img_real: Folder of reference (real) images.
        img_fake: Folder of generated (fake) images.
        batch_size: Images per forward pass.
        cuda: Run feature extraction on the GPU.
        dims: Feature dimensionality forwarded to the FID pipeline.

    Returns:
        The value returned by ``calculate_fid_given_paths`` for the pair.
    """
    return calculate_fid_given_paths([img_real, img_fake], batch_size, cuda, dims)
# Placeholder locations: replace both with real dataset roots before running.
# Each is loaded with torchvision's ImageFolder, so each should be a directory
# that contains at least one subdirectory of images.
img_real = 'path/to/real/image'
img_fake = 'path/to/fake/image'
fid = calculate_fid(img_real=img_real, img_fake=img_fake)
print(f'FID score: {fid}')
```
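注意：这个代码依赖 torch、torchvision、numpy 和 scipy 库，需要提前安装。另外，计算 FID 需要使用预训练的 Inception-v3 网络提取 2048 维特征；torchvision.models.inception_v3() 可以下载 ImageNet 分类权重，但它与 pytorch-fid 等参考实现使用的 FID 专用权重不同，因此得到的分数不能与论文中报告的 FID 直接比较。FID 比较的是两组图像特征的均值和协方差，每组至少需要 2 张图像（为了得到稳定的估计，实际中通常需要上千张），对单张图像无法计算有意义的 FID。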
阅读全文