Advertisement
Guest User

utils.py

a guest
May 18th, 2022
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.35 KB | None | 0 0
  1. import os
  2.  
  3. import numpy as np
  4. import torch
  5. import torch.utils.data as data
  6. import torchvision
  7. import torchvision.transforms as transforms
  8.  
  9. from params import SEED
  10.  
  11. # Set seed
  12. torch.manual_seed(SEED)
  13. torch.cuda.manual_seed(SEED)
  14. torch.cuda.manual_seed_all(SEED)
  15. np.random.seed(SEED)
  16.  
  17. # Check folders
  18. for folder in ["data", "model", "log", "fig"]:
  19.     if not os.path.exists(f"./{folder}"):
  20.         os.mkdir(f"./{folder}")
  21.  
  22. # Device
  23. DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  24.  
  25. # Data
  26. MEAN = (0.4942, 0.4851, 0.4504)
  27. STD = (0.2467, 0.2429, 0.2616)
  28. tf = transforms.Compose(
  29.     [
  30.         transforms.ToTensor(),
  31.         transforms.Normalize(MEAN, STD),
  32.     ]
  33. )
  34. # the following transforms are used for data augmentation
  35. # but here to keep it simple, we don't actually use them
  36. # it is reasonable to swap the default transforms with corresponding DA ones.
  37. train_tf = transforms.Compose(
  38.     [
  39.         transforms.RandomCrop(32, padding=4),
  40.         transforms.RandomHorizontalFlip(),
  41.         tf,
  42.     ]
  43. )
  44. test_tf = transforms.Compose([tf])
  45.  
  46. class_labels = [
  47.     "plane",
  48.     "car",
  49.     "bird",
  50.     "cat",
  51.     "deer",
  52.     "dog",
  53.     "frog",
  54.     "horse",
  55.     "ship",
  56.     "truck",
  57. ]
  58.  
  59. if not os.path.exists("./data/cifar-10-python.tar.gz"):
  60.     # Just download the data if not already exists and omit the notification
  61.     _ = torchvision.datasets.CIFAR10(root="./data", download=True, transform=tf)
  62. train_set = torchvision.datasets.CIFAR10(root="./data", train=True, transform=tf)
  63. test_set = torchvision.datasets.CIFAR10(root="./data", train=False, transform=tf)
  64.  
  65.  
  66. def get_train_dataloader(batch_size, shuffle=True):
  67.     global train_set
  68.     return data.DataLoader(
  69.         train_set, batch_size=batch_size, shuffle=shuffle, num_workers=0
  70.     )
  71.  
  72.  
  73. def get_test_dataloader(batch_size, shuffle=False):
  74.     global test_set
  75.     return data.DataLoader(
  76.         test_set, batch_size=batch_size, shuffle=shuffle, num_workers=0
  77.     )
  78.  
  79.  
  80. def get_train_test_dataloader(batch_size):
  81.     global train_set, test_set
  82.     return get_train_dataloader(batch_size), get_test_dataloader(batch_size)
  83.  
  84.  
  85. # Logger
  86. class Logger:
  87.     """
  88.    Takes a file path when initialized
  89.  
  90.    print and write to file for anything passed to it
  91.    """
  92.  
  93.     def __init__(self, file_path) -> None:
  94.         self.file_path = file_path
  95.         print(f"{DEVICE=}, Pytorch: {torch.__version__}, Seed: {SEED}")
  96.         with open(file_path, "w") as f:
  97.             f.write(f"{DEVICE=}, Pytorch: {torch.__version__}, Seed: {SEED}\n")
  98.  
  99.     def __call__(self, message) -> None:
  100.         print(message)
  101.         with open(self.file_path, "a") as f:
  102.             f.write(message + "\n")
  103.  
  104.  
  105. if __name__ == "__main__":
  106.     print(f"{DEVICE=}, Pytorch: {torch.__version__}, Seed: {SEED}")
  107.     print(f"{len(train_set)} train samples, {len(test_set)} test samples")
  108.     # Only used to check if data transformation is correct
  109.     print(
  110.         f"Test mean: {next(iter(get_test_dataloader(len(test_set))))[0].mean(dim=(0,2,3))}"
  111.     )
  112.     print(
  113.         f"Test std: {next(iter(get_test_dataloader(len(test_set))))[0].std(dim=(0,2,3))}"
  114.     )
  115.     print(
  116.         f"Train mean: {next(iter(get_test_dataloader(len(train_set))))[0].mean(dim=(0,2,3))}"
  117.     )
  118.     print(
  119.         f"Train std: {next(iter(get_test_dataloader(len(train_set))))[0].std(dim=(0,2,3))}"
  120.     )
  121.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement