Advertisement
SVXX

Three Rotation Classes

Jan 17th, 2023
917
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.17 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torch.nn.functional as F
  5. import torch.backends.cudnn as cudnn
  6.  
  7. import torchvision
  8. import torchvision.models as models
  9. import torchvision.transforms as transforms
  10.  
  11. import copy
  12.  
  13. import numpy as np
  14. import pandas as pd
  15. import matplotlib.pyplot as plt
  16. from sklearn.cluster import KMeans
  17. from sklearn.decomposition import PCA
  18. from sklearn.preprocessing import StandardScaler
  19. from sklearn.utils import shuffle
  20.  
  21. from helperFunctionsJoint import seed_torch
  22.  
  23. import os
  24. import argparse
  25. import random
  26.  
  27. from cifar10_models.resnet import resnet18
  28.  
  29.  
  30. if __name__ == '__main__':
  31.     device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
  32.     print("device : ", device)
  33.  
  34.     seed_torch(2)
  35.  
  36.     transform_normal = transforms.Compose([
  37.  
  38.         transforms.CenterCrop(24),
  39.         torchvision.transforms.Resize((32,32)),
  40.         # torchvision.transforms.GaussianBlur(kernel_size = 5),
  41.         transforms.ToTensor(),
  42.         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.262)),
  43.     ])
  44.  
  45.     transform_rotate = transforms.Compose([
  46.  
  47.         transforms.RandomRotation((30,30.1)),
  48.         transforms.CenterCrop(24),
  49.         torchvision.transforms.Resize((32,32)),
  50.         # torchvision.transforms.GaussianBlur(kernel_size = 5),
  51.         transforms.ToTensor(),
  52.         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.262)),
  53.     ])
  54.  
  55.     transform_rotate_two = transforms.Compose([
  56.  
  57.         transforms.RandomRotation((45,45.1)),
  58.         transforms.CenterCrop(24),
  59.         torchvision.transforms.Resize((32,32)),
  60.         # torchvision.transforms.GaussianBlur(kernel_size = 5),
  61.         transforms.ToTensor(),
  62.         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.262)),
  63.     ])
  64.  
  65.     trainset_normal = torchvision.datasets.CIFAR10(
  66.         root='./data_cifar', train=True, download=True, transform=transform_normal)
  67.     trainset_rotate = torchvision.datasets.CIFAR10(
  68.         root='./data_cifar', train=True, download=True, transform=transform_rotate)
  69.     trainset_rotate_two = torchvision.datasets.CIFAR10(
  70.         root='./data_cifar', train=True, download=True, transform=transform_rotate_two)
  71.  
  72.     trainloader_normal = torch.utils.data.DataLoader(
  73.         trainset_normal, batch_size=100, shuffle=True, num_workers=2)
  74.     trainloader_rotate = torch.utils.data.DataLoader(
  75.         trainset_rotate, batch_size=100, shuffle=True, num_workers=2)
  76.     trainloader_rotate_two = torch.utils.data.DataLoader(
  77.         trainset_rotate_two, batch_size=100, shuffle=True, num_workers=2)
  78.  
  79.     # #print(trainset_normal.data.shape)
  80.     # data = trainset_normal.data / 255 # data is numpy array
  81.     # mean = data.mean(axis = (0,1,2))
  82.     # std = data.std(axis = (0,1,2))
  83.     # print(f"Mean : {mean}   STD: {std}")
  84.     # #Mean : [0.491 0.482 0.447]   STD: [0.247 0.243 0.262]
  85.  
  86.  
  87.     testset_normal = torchvision.datasets.CIFAR10(
  88.         root='./data_cifar', train=False, download=True, transform=transform_normal)
  89.     testset_rotate = torchvision.datasets.CIFAR10(
  90.         root='./data_cifar', train=False, download=True, transform=transform_rotate)
  91.     testset_rotate_two = torchvision.datasets.CIFAR10(
  92.         root='./data_cifar', train=False, download=True, transform=transform_rotate_two)
  93.        
  94.     testloader_normal = torch.utils.data.DataLoader(
  95.         testset_normal, batch_size=100, shuffle=True, num_workers=2)
  96.     testloader_rotate = torch.utils.data.DataLoader(
  97.         testset_rotate, batch_size=100, shuffle=True, num_workers=2)
  98.     testloader_rotate_two = torch.utils.data.DataLoader(
  99.         testset_rotate_two, batch_size=100, shuffle=True, num_workers=2)
  100.  
  101.  
  102.     classes = ('unrotated','rotated')
  103.  
  104.  
  105.     net = resnet18(pretrained=True)
  106.     # net = models.resnet18(pretrained=True)
  107.  
  108.     for param in net.parameters():
  109.         param.requires_grad = False
  110.  
  111.     net = net.to(device)
  112.     hidden_size = 128
  113.     dim2 = 8
  114.     n_components = 8192
  115.  
  116.     feature_extractor = torch.nn.Sequential(*list(net.children())[:-4])
  117.  
  118.     d = {}
  119.  
  120.     # #_______________-Only 5classes case____________
  121.     # train_features = np.zeros((50000*2,hidden_size,dim2,dim2))
  122.     # train_labels = np.zeros((50000*2))
  123.     # train_original_labels = np.zeros((50000*2))
  124.  
  125.     # test_features = np.zeros((10000*2,hidden_size,dim2,dim2))
  126.     # test_labels = np.zeros((10000*2))
  127.     # test_original_labels = np.zeros((10000*2))
  128.  
  129.     #__________10classes....50% random examplescase________
  130.     train_features = np.zeros((50000,hidden_size,dim2,dim2))
  131.     train_labels = np.zeros((50000))
  132.  
  133.     test_features = np.zeros((10000,hidden_size,dim2,dim2))
  134.     test_labels = np.zeros((10000))
  135.  
  136.     for batch_idx, (inputs, targets) in enumerate(testloader_normal):
  137.         if(batch_idx >= 33): break    #__________10C...50%examples case.
  138.         inputs = inputs.to(device)
  139.         features = feature_extractor(inputs).squeeze()
  140.         test_features[batch_idx*300:batch_idx*300+100] = features.cpu().numpy()
  141.         test_labels[batch_idx*300:batch_idx*300+100] = np.zeros(100)
  142.         # test_original_labels[batch_idx*200:batch_idx*200+100] = targets.squeeze().cpu().numpy() #.....5C case
  143.     for batch_idx, (inputs, targets) in enumerate(testloader_rotate):
  144.         if(batch_idx >= 33): break   #__________10C...50%examples case.
  145.         inputs = inputs.to(device)
  146.         features = feature_extractor(inputs).squeeze()
  147.         test_features[batch_idx*300+100:batch_idx*300+200] = features.cpu().numpy()
  148.         test_labels[batch_idx*300+100:batch_idx*300+200] = np.ones(100)
  149.         # test_original_labels[batch_idx*200+100:batch_idx*200+200] = targets.squeeze().cpu().numpy() #...5C case.
  150.  
  151.     for batch_idx, (inputs, targets) in enumerate(testloader_rotate_two):
  152.         if(batch_idx >= 32): break   #__________10C...50%examples case.
  153.         inputs = inputs.to(device)
  154.         features = feature_extractor(inputs).squeeze()
  155.         test_features[batch_idx*300+200:batch_idx*300+300] = features.cpu().numpy()
  156.         test_labels[batch_idx*300+200:batch_idx*300+300] = (np.ones(100) + 1)
  157.  
  158.     # #________5C case ________
  159.     # test_features = test_features[test_original_labels < 5]
  160.     # test_labels = test_labels[test_original_labels < 5]
  161.  
  162.     test_features,test_labels = shuffle(test_features,test_labels,random_state=0)
  163.  
  164.     d['resnet18_test_features'] = test_features
  165.     d['test_labels'] = test_labels
  166.  
  167.     for batch_idx, (inputs, targets) in enumerate(trainloader_normal):
  168.         if(batch_idx >= 125): break   #________10C 50%examples case.
  169.         inputs = inputs.to(device)
  170.         features = feature_extractor(inputs).squeeze()
  171.         train_features[batch_idx*300:batch_idx*300+100] = features.cpu().numpy()
  172.         train_labels[batch_idx*300:batch_idx*300+100] = np.zeros(100)
  173.         # train_original_labels[batch_idx*200:batch_idx*200+100] = targets.squeeze().cpu().numpy()
  174.     for batch_idx, (inputs, targets) in enumerate(trainloader_rotate):
  175.         if(batch_idx >= 125): break    #________10C 50%examples case.
  176.         inputs = inputs.to(device)
  177.         features = feature_extractor(inputs).squeeze()
  178.         train_features[batch_idx*300+100:batch_idx*300+200] = features.cpu().numpy()
  179.         train_labels[batch_idx*300+100:batch_idx*300+200] = np.ones(100)
  180.         # train_original_labels[batch_idx*200+100:batch_idx*200+200] = targets.squeeze().cpu().numpy()
  181.  
  182.     for batch_idx, (inputs, targets) in enumerate(trainloader_rotate_two):
  183.         if(batch_idx >= 125): break    #________10C 50%examples case.
  184.         inputs = inputs.to(device)
  185.         features = feature_extractor(inputs).squeeze()
  186.         train_features[batch_idx*300+200:batch_idx*300+300] = features.cpu().numpy()
  187.         train_labels[batch_idx*300+200:batch_idx*300+300] = (np.ones(100) + 1)
  188.  
  189.  
  190.     # train_features = train_features[train_original_labels < 5]
  191.     # train_labels = train_labels[train_original_labels < 5]
  192.  
  193.     train_features,train_labels = shuffle(train_features,train_labels,random_state=0)
  194.  
  195.     d['resnet18_train_features'] = train_features
  196.     d['train_labels'] = train_labels
  197.  
  198.  
  199.     for k,v in d.items():
  200.         print(k,v.shape)
  201.  
  202.        
  203.     torch.save(d,'cifar-rotate_threeway_without_pca_l4_10C_half_2.pth')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement