Advertisement
SVXX

CIFAR10 Rotated Features for Multiclass

Jan 19th, 2023
1,373
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.18 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torch.nn.functional as F
  5. import torch.backends.cudnn as cudnn
  6.  
  7. import torchvision
  8. import torchvision.models as models
  9. import torchvision.transforms as transforms
  10.  
  11. import copy
  12.  
  13. import numpy as np
  14. import pandas as pd
  15. import matplotlib.pyplot as plt
  16. from sklearn.cluster import KMeans
  17. from sklearn.decomposition import PCA
  18. from sklearn.preprocessing import StandardScaler
  19. from sklearn.utils import shuffle
  20.  
  21. from helperFunctionsJoint import seed_torch
  22.  
  23. import os
  24. import argparse
  25.  
  26. from cifar10_models.resnet import resnet18
  27.  
  28.  
  29. if __name__ == "__main__":
  30.  
  31.     device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
  32.     print('device : ',device)
  33.  
  34.     seed_torch(2)
  35.  
  36.     transform_normal = transforms.Compose([
  37.  
  38.         transforms.CenterCrop(24),
  39.         torchvision.transforms.Resize((32,32)),
  40.         # torchvision.transforms.GaussianBlur(kernel_size = 5),
  41.         transforms.ToTensor(),
  42.         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.262)),
  43.     ])
  44.  
  45.     transform_rotate = transforms.Compose([
  46.  
  47.         transforms.RandomRotation((30,30.1)),
  48.         transforms.CenterCrop(24),
  49.         torchvision.transforms.Resize((32,32)),
  50.         # torchvision.transforms.GaussianBlur(kernel_size = 5),
  51.         transforms.ToTensor(),
  52.         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.262)),
  53.     ])
  54.  
  55.     trainset_normal = torchvision.datasets.CIFAR10(
  56.         root='./data_cifar', train=True, download=True, transform=transform_normal)
  57.     trainset_rotate = torchvision.datasets.CIFAR10(
  58.         root='./data_cifar', train=True, download=True, transform=transform_rotate)
  59.  
  60.     trainloader_normal = torch.utils.data.DataLoader(
  61.         trainset_normal, batch_size=100, shuffle=True, num_workers=2)
  62.     trainloader_rotate = torch.utils.data.DataLoader(
  63.         trainset_rotate, batch_size=100, shuffle=True, num_workers=2)
  64.  
  65.     testset_normal = torchvision.datasets.CIFAR10(
  66.         root='./data_cifar', train=False, download=True, transform=transform_normal)
  67.     testset_rotate = torchvision.datasets.CIFAR10(
  68.         root='./data_cifar', train=False, download=True, transform=transform_rotate)
  69.        
  70.     testloader_normal = torch.utils.data.DataLoader(
  71.         testset_normal, batch_size=100, shuffle=True, num_workers=2)
  72.     testloader_rotate = torch.utils.data.DataLoader(
  73.         testset_rotate, batch_size=100, shuffle=True, num_workers=2)
  74.  
  75.     classes = ('plane', 'car', 'bird', 'cat', 'deer',
  76.                'dog', 'frog', 'horse', 'ship', 'truck')
  77.  
  78.  
  79.     net = resnet18(pretrained=True)
  80.     # net = models.resnet18(pretrained=True)
  81.  
  82.     for param in net.parameters():
  83.         param.requires_grad = False
  84.  
  85.     net = net.to(device)
  86.     hidden_size = 128
  87.     dim2 = 8
  88.     n_components = 8192
  89.  
  90.     feature_extractor = torch.nn.Sequential(*list(net.children())[:-4])
  91.  
  92.     d = {}
  93.  
  94.     train_features = np.zeros((50000,hidden_size,dim2,dim2))
  95.     train_labels = np.zeros((50000))
  96.     test_features = np.zeros((10000,hidden_size,dim2,dim2))
  97.     test_labels = np.zeros((10000))
  98.  
  99.     for batch_idx, (inputs, targets) in enumerate(testloader_normal):
  100.         if(batch_idx >= 50): break    #__________10C...50%examples case.
  101.         inputs = inputs.to(device)
  102.         features = feature_extractor(inputs).squeeze()
  103.         test_features[batch_idx*200:batch_idx*200+100] = features.cpu().numpy()
  104.         test_labels[batch_idx*200:batch_idx*200+100] = targets.squeeze().cpu().numpy()
  105.  
  106.     for batch_idx, (inputs, targets) in enumerate(testloader_rotate):
  107.         if(batch_idx >= 50): break   #__________10C...50%examples case.
  108.         inputs = inputs.to(device)
  109.         features = feature_extractor(inputs).squeeze()
  110.         test_features[batch_idx*200+100:batch_idx*200+200] = features.cpu().numpy()
  111.         test_labels[batch_idx*200+100:batch_idx*200+200] = targets.squeeze().cpu().numpy()
  112.  
  113.     test_features,test_labels = shuffle(test_features,test_labels,random_state=0)
  114.  
  115.     d['resnet18_test_features'] = test_features
  116.     d['test_labels'] = test_labels
  117.  
  118.     for batch_idx, (inputs, targets) in enumerate(trainloader_normal):
  119.         if(batch_idx >= 125): break   #________10C 50%examples case.
  120.         inputs = inputs.to(device)
  121.         features = feature_extractor(inputs).squeeze()
  122.         train_features[batch_idx*200:batch_idx*200+100] = features.cpu().numpy()
  123.         train_labels[batch_idx*200:batch_idx*200+100] = targets.squeeze().cpu().numpy()
  124.  
  125.     for batch_idx, (inputs, targets) in enumerate(trainloader_rotate):
  126.         if(batch_idx >= 125): break    #________10C 50%examples case.
  127.         inputs = inputs.to(device)
  128.         features = feature_extractor(inputs).squeeze()
  129.         train_features[batch_idx*200+100:batch_idx*200+200] = features.cpu().numpy()
  130.         train_labels[batch_idx*200+100:batch_idx*200+200] = targets.squeeze().cpu().numpy()
  131.  
  132.     train_features,train_labels = shuffle(train_features,train_labels,random_state=0)
  133.  
  134.     d['resnet18_train_features'] = train_features
  135.     d['train_labels'] = train_labels
  136.  
  137.     for k,v in d.items():
  138.         print(k,v.shape)
  139.        
  140.     torch.save(d,'cifar-rotate_objclassify_without_pca_l4_10C_half_2.pth')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement