Advertisement
lamiastella

Untitled

Nov 10th, 2020
1,051
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.88 KB | None | 0 0
  1. start_time = time.time()
  2.  
  3. with torch.no_grad():
  4.  
  5.     best_network = Network()
  6.     best_network.cuda()
  7.     best_network.load_state_dict(torch.load('../moth_landmarks.pth'))
  8.     best_network.eval()
  9.    
  10.     batch = next(iter(train_loader))
  11.     images, landmarks = batch['image'], batch['landmarks']
  12.     landmarks = landmarks.view(landmarks.size(0),-1).cuda()
  13.  
  14.            
  15.     for i in range(8):
  16.         if(i%2==0):
  17.             landmarks[:,i] = landmarks[:,i]/800
  18.         else:
  19.             landmarks[:,i] = landmarks[:,i]/600
  20.     landmarks [landmarks != landmarks] = 0
  21.     #landmarks = landmarks.unsqueeze_(0)
  22.  
  23.     images = images.cuda()
  24.  
  25.     norm_image = transforms.Normalize(0.3812, 0.1123)
  26.     print('images shape: ', images.shape)
  27.     for image in images:
  28.         image = image.unsqueeze_(1)
  29.  
  30.         #images = torch.cat((images,images,images),1)
  31.         image = image.float()
  32.         ##image = to_tensor(image) #TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>
  33.         image = norm_image(image)
  34.     landmarks = (landmarks + 0.5) * 224 #??
  35.    
  36.     ##[8, 600, 800] --> [8,3,600,800]
  37.     images = images.unsqueeze(1)
  38.     images = torch.cat((images, images, images), 1)
  39.  
  40.     predictions = (best_network(images).cpu() + 0.5) * 224
  41.     predictions = predictions.view(-1,4,2)
  42.    
  43.     plt.figure(figsize=(10,40))
  44.     landmarks = landmarks.cpu()
  45.     print(type(landmarks), landmarks.shape)
  46.     for img_num in range(8):
  47.         plt.subplot(8,1,img_num+1)
  48.         plt.imshow(images[img_num].cpu().numpy().transpose(1,2,0).squeeze(), cmap='gray')
  49.         plt.scatter(predictions[img_num,:,0], predictions[img_num,:,1], c = 'r')
  50.         plt.scatter(landmarks[img_num,:,0], landmarks[img_num,:,1], c = 'g')
  51.  
  52. print('Total number of test images: {}'.format(len(test_dataset)))
  53.  
  54. end_time = time.time()
  55. print("Elapsed Time : {}".format(end_time - start_time))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement