Advertisement
andreasoie

Calculating FID

Mar 22nd, 2023
708
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.63 KB | Source Code | 0 0
  1. """
  2. Copyright (c) 2023, Andreas Øie
  3. Copyright (c) 2020-present NAVER Corp.
  4.  
  5. This work is licensed under the Creative Commons Attribution-NonCommercial
  6. 4.0 International License. To view a copy of this license, visit
  7. http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
  8. Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
  9. """
  10.  
  11. import argparse
  12. import warnings
  13. from itertools import chain
  14. from pathlib import Path
  15.  
  16. import numpy as np
  17. import torch
  18. import torch.nn as nn
  19. from scipy import linalg
  20. from torchvision import models
  21.  
  22. warnings.filterwarnings('ignore')
  23.  
  24. from PIL import Image
  25. from torch.utils import data
  26. from torchvision import transforms
  27. from tqdm import tqdm
  28.  
  29.  
  30. def list_directory(dirname: str): return list(chain(*[list(Path(dirname).rglob('*.' + ext)) for ext in ['png', 'jpg', 'jpeg', 'JPG']]))
  31.  
  32. class DefaultDataset(data.Dataset):
  33.     def __init__(self, dirname: str, transform: transforms.Compose = None) -> None:
  34.         self.samples = list_directory(dirname)
  35.         self.samples.sort()
  36.         self.transform = transform
  37.         self.targets = None
  38.  
  39.     def __getitem__(self, index: int) -> torch.Tensor:
  40.         filename = self.samples[index]
  41.         img = Image.open(filename).convert('RGB')
  42.         if self.transform is not None:
  43.             img = self.transform(img)
  44.         return img
  45.  
  46.     def __len__(self):
  47.         return len(self.samples)
  48.    
  49. def get_eval_loader(img_folder: str, img_size: int = 256, batch_size: int = 32, shuffle: bool = True, num_workers: int = 4, drop_last: bool = False) -> data.DataLoader:
  50.     mean = [0.5, 0.5, 0.5]
  51.     std = [0.5, 0.5, 0.5]
  52.    
  53.     transform = transforms.Compose([
  54.         transforms.Resize([img_size, img_size]),
  55.         transforms.ToTensor(),
  56.         transforms.Normalize(mean=mean, std=std)
  57.     ])
  58.    
  59.     return data.DataLoader(dataset=DefaultDataset(img_folder, transform=transform),
  60.                            batch_size=batch_size,
  61.                            shuffle=shuffle,
  62.                            num_workers=num_workers,
  63.                            pin_memory=True,
  64.                            drop_last=drop_last)
  65.  
  66. class InceptionV3(nn.Module):
  67.     def __init__(self) -> None:
  68.         super().__init__()
  69.         inception = models.inception_v3(pretrained=True)
  70.         self.block1 = nn.Sequential(
  71.             inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
  72.             inception.Conv2d_2b_3x3,
  73.             nn.MaxPool2d(kernel_size=3, stride=2))
  74.         self.block2 = nn.Sequential(
  75.             inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
  76.             nn.MaxPool2d(kernel_size=3, stride=2))
  77.         self.block3 = nn.Sequential(
  78.             inception.Mixed_5b, inception.Mixed_5c,
  79.             inception.Mixed_5d, inception.Mixed_6a,
  80.             inception.Mixed_6b, inception.Mixed_6c,
  81.             inception.Mixed_6d, inception.Mixed_6e)
  82.         self.block4 = nn.Sequential(
  83.             inception.Mixed_7a, inception.Mixed_7b,
  84.             inception.Mixed_7c,
  85.             nn.AdaptiveAvgPool2d(output_size=(1, 1)))
  86.  
  87.     def forward(self, x: torch.Tensor) -> torch.Tensor:
  88.         x = self.block1(x)
  89.         x = self.block2(x)
  90.         x = self.block3(x)
  91.         x = self.block4(x)
  92.         return x.view(x.size(0), -1)
  93.  
  94.  
  95. def frechet_distance(mu, cov, mu2, cov2):
  96.     cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
  97.     dist = np.sum((mu -mu2)**2) + np.trace(cov + cov2 - 2*cc)
  98.     return np.real(dist)
  99.  
  100. @torch.no_grad()
  101. def calculate_fid_given_paths(paths, img_size=256, batch_size=50):
  102.     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  103.     inception = InceptionV3().eval().to(device)
  104.     loaders = [get_eval_loader(path, img_size, batch_size) for path in paths]
  105.     mu, cov = [], []
  106.     for loader in loaders:
  107.         actvs = []
  108.         for x in tqdm(loader, total=len(loader), desc='Inference'):
  109.             actv = inception(x.to(device))
  110.             actvs.append(actv)
  111.         actvs = torch.cat(actvs, dim=0).cpu().detach().numpy()
  112.         mu.append(np.mean(actvs, axis=0))
  113.         cov.append(np.cov(actvs, rowvar=False))
  114.     fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1])
  115.     return fid_value
  116.  
  117.  
  118. if __name__ == '__main__':
  119.     parser = argparse.ArgumentParser()
  120.     parser.add_argument('--paths', type=str, nargs=2, help='paths to real and fake images')
  121.     parser.add_argument('--img_size', type=int, default=256, help='image resolution')
  122.     parser.add_argument('--batch_size', type=int, default=9, help='batch size to use')
  123.     args = parser.parse_args()
  124.     fid_value = calculate_fid_given_paths(args.paths, args.img_size, args.batch_size)
  125.     print(f"FID: {fid_value}")
Tags: FID
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement