Advertisement
Guest User

Untitled

a guest
Nov 29th, 2020
64
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.93 KB | None | 0 0
  1.  
  2. import os
  3.  
  4.  
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torchvision
  10. from numpy.random import normal
  11. import ray
  12. from ray import tune
  13. from ray.tune.schedulers import ASHAScheduler
  14. import random
  15. from torchvision.transforms import transforms
  16.  
  17. random.seed(100)
  18.  
  19. transform = transforms.Compose([transforms.ToTensor(),
  20. transforms.Normalize((0.5,), (0.5,)),
  21. ])
  22.  
  23. xy_trainPT = torchvision.datasets.MNIST(
  24. root="./",
  25. train=True,
  26. download=True
  27. )
  28.  
  29. trainset = torchvision.datasets.MNIST(root="./",
  30. train=True,
  31. download=True,
  32. transform=transform)
  33.  
  34. originalSet = torchvision.datasets.MNIST(root="./",
  35. train=True,
  36. download=True,
  37. transform=transform)
  38.  
  39. noisyArr = []
  40. originalArr = []
  41.  
  42. for index, shape in enumerate(trainset):
  43. noisyArr.append(shape[0].squeeze(dim=0).numpy())
  44. originalArr.append(originalSet[0][0].squeeze(dim=0).numpy())
  45. if index == 30000:
  46. break
  47.  
  48. noisyArr = np.array(noisyArr)
  49. originalArr = np.array(originalArr)
  50. print('done loading data')
  51.  
  52. original = originalArr / 255
  53.  
  54. X_2 = noisyArr / 255
  55.  
  56. for i in range(len(X_2)):
  57. norm = abs(np.random.normal(0, 0.3, size=(28, 28)))
  58. X_2[i] = X_2[i] + norm
  59.  
  60. pixels = int(784)
  61.  
  62.  
  63. class autoencoder(nn.Module):
  64. def __init__(self, config):
  65. super(autoencoder, self).__init__()
  66.  
  67. size = 28
  68. kernel = config['convK']
  69. # print(f"kernal: {kernel}")
  70. stride = config['convS']
  71. # print(f"stride: {stride}")
  72. padding = config['convP']
  73. # print(f"padding: {padding}")
  74. poolK = config['poolK']
  75. poolS = config['poolS']
  76. finalOutput = config['actMap']
  77. self.conv1 = torch.nn.Conv2d(1, finalOutput, kernel_size=kernel, stride=stride, padding=padding)
  78. self.bn1 = torch.nn.BatchNorm2d(finalOutput)
  79. self.pool1 = torch.nn.MaxPool2d(stride=poolS, kernel_size=poolK)
  80.  
  81. def poolAdjust(originalSize, kernel=poolK, stride=poolS, dilation=1):
  82. return ((originalSize - (dilation * (kernel - 1)) - 1) // stride) + 1
  83.  
  84. def conv2d_size_out(size, kernel_size=kernel, stride=stride, padding=padding):
  85. return ((size + (padding * 2) - (kernel_size - 1) - 1) // stride) + 1
  86.  
  87. convw = poolAdjust(conv2d_size_out(size))
  88. convh = poolAdjust(conv2d_size_out(size))
  89.  
  90. self.linear_input_size = convw * convh * finalOutput
  91.  
  92. self.head = torch.nn.Linear(self.linear_input_size, pixels)
  93. self.flatten = torch.nn.Linear(self.linear_input_size, self.linear_input_size)
  94. self.func = torch.nn.Hardtanh()
  95. self.softMax2d = torch.nn.Softmax2d()
  96.  
  97. def forward(self, x):
  98. x = self.bn1(self.conv1(x))
  99. x = self.pool1(x)
  100. x = torch.nn.functional.relu(x)
  101. return self.head(x.view(x.size(0), -1))
  102.  
  103.  
  104. def train(config, checkpoint_dir=None, data=None):
  105.  
  106. # data = (X_2, original)
  107. loss_fn = torch.nn.MSELoss()
  108. model = autoencoder(config)
  109. optimizer = torch.optim.SGD(model.parameters(), lr=config['lr'], momentum=0.9)
  110. maxIter = 50000
  111. batchAmount = config['batchSize']
  112.  
  113. if checkpoint_dir:
  114. checkpoint = os.path.join(checkpoint_dir, "checkpoint")
  115. model_state, optimizer_state = torch.load(checkpoint)
  116. model.load_state_dict(model_state)
  117. optimizer.load_state_dict(optimizer_state)
  118.  
  119. for t in range(maxIter):
  120. epoch_loss = 0
  121.  
  122. optimizer.zero_grad()
  123. idx = np.random.randint(data[0].shape[0], size=batchAmount) # bootstrapping a subset of the total samples
  124.  
  125. X_scaled = torch.unsqueeze(torch.from_numpy(data[0][idx, :]).float(), dim=1) # creating tensor for convultion
  126.  
  127. testValues = torch.from_numpy(
  128. np.reshape(data[1][idx, :],
  129. (batchAmount, -1))
  130. ).float() # creating a flattened array for testing
  131.  
  132. y_pred = model(X_scaled) # predict on the subset
  133. # print(y_pred.size())
  134. loss = loss_fn(testValues, y_pred) # get loss on subset
  135.  
  136. if t % (maxIter / 10) == 0:
  137. # print(t, loss.item())
  138. tune.report(score=loss.item())
  139. with tune.checkpoint_dir(step=t) as checkpoint_dir:
  140. path = os.path.join(checkpoint_dir, "checkpoint")
  141. torch.save(
  142. (model.state_dict(), optimizer.state_dict()), path)
  143.  
  144. loss.backward() # get gradient stuff
  145. optimizer.step() # optimize
  146. epoch_loss += loss.item()
  147.  
  148. def tunerTrain():
  149. ray.init(_memory=8000000000, object_store_memory=4000000000, _redis_max_memory=8000000000, num_cpus=5),
  150. resources_per_trial = {"cpu": 4, "extra_cpu": 1},
  151.  
  152. searchSpace = {
  153. 'lr': tune.loguniform(1e-4, 9e-1),
  154. 'actMap': tune.grid_search([1, 2]),
  155. 'convK': tune.choice([3, 5, 7, 9]),
  156. 'convS': tune.grid_search([1, 2]),
  157. 'convP': tune.choice([0, 1, 2, 3]),
  158. 'poolK': tune.choice([3, 5, 7, 9]),
  159. 'poolS': tune.grid_search([1, 2]),
  160. 'batchSize': tune.choice([2, 4, 8, 16, 32, 64, 128, 256]),
  161. }
  162.  
  163. analysis = tune.run(tune.with_parameters(train, data=[X_2, original]), num_samples=10, metric='score', mode='min',
  164. scheduler=ASHAScheduler(),
  165. config=searchSpace)
  166. dfs = analysis.trial_dataframes
  167. print(f"Best Config: {analysis.get_best_config('score', mode='min')}")
  168. df = analysis.results_df
  169. logdir = analysis.get_best_logdir("score", mode="min")
  170. print(f"dir of best: {logdir}")
  171. print(analysis.best_result)
  172. print(f"Best trial final score: {analysis.get_best_trial('score', mode='min')}")
  173.  
  174. tunerTrain()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement