Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Visual sanity check for the trained moth-landmark model: load the saved
# weights, run one batch through the network and plot predicted (red) vs.
# ground-truth (green) landmarks on top of each image.
#
# Relies on names defined earlier in the notebook/script (not visible here):
# `time`, `torch`, `transforms` (torchvision), `plt` (matplotlib), `Network`,
# `train_loader` and `test_dataset`.

# Pixel size the raw landmark coordinates are expressed in. Images arrive as
# [batch, 600, 800] (height 600, width 800), so x is scaled by 800, y by 600.
IMAGE_WIDTH = 800
IMAGE_HEIGHT = 600
# Mean / std passed to transforms.Normalize for the single grey channel.
NORM_MEAN = 0.3812
NORM_STD = 0.1123

start_time = time.time()
with torch.no_grad():
    best_network = Network()
    best_network.cuda()
    best_network.load_state_dict(torch.load('../moth_landmarks.pth'))
    best_network.eval()

    # NOTE(review): this visualises a *training* batch, while the summary at
    # the end reports the size of `test_dataset` -- confirm which split was meant.
    batch = next(iter(train_loader))
    raw_images, landmarks = batch['image'], batch['landmarks']

    # Flatten to [B, 2 * n_landmarks] laid out as x0, y0, x1, y1, ... and
    # scale to [0, 1]: even columns are x (/ width), odd columns are y (/ height).
    # .float() makes the in-place division safe even if the loader yields ints.
    landmarks = landmarks.reshape(landmarks.size(0), -1).float().cuda()
    landmarks[:, 0::2] /= IMAGE_WIDTH
    landmarks[:, 1::2] /= IMAGE_HEIGHT
    # NaN entries (presumably missing landmarks) are zeroed.
    landmarks[torch.isnan(landmarks)] = 0

    raw_images = raw_images.cuda()
    print('images shape: ', raw_images.shape)

    # Normalise the whole batch at once. The previous per-image loop rebound
    # its loop variable and discarded the result, so normalisation was never
    # applied to what the network saw.
    # NOTE(review): this must match the preprocessing used during training --
    # confirm before trusting the predictions.
    norm_image = transforms.Normalize(NORM_MEAN, NORM_STD)
    # [B, 600, 800] --> [B, 1, 600, 800] (channel axis first, as Normalize expects)
    images = norm_image(raw_images.unsqueeze(1).float())
    # Replicate the grey channel: [B, 1, H, W] --> [B, 3, H, W]
    images = torch.cat((images, images, images), 1)

    # Network output and ground truth as [B, n_landmarks, 2] (x, y) pairs.
    predictions = best_network(images).cpu()
    predictions = predictions.view(predictions.size(0), -1, 2)
    landmarks = landmarks.cpu().view(landmarks.size(0), -1, 2)

# Map normalised coordinates back to image pixels by inverting the
# / IMAGE_WIDTH, / IMAGE_HEIGHT scaling above. (The old `(x + 0.5) * 224` did
# not invert it, so points did not line up with the 600x800 image.)
# NOTE(review): assumes the network was trained on landmarks normalised the
# same way as `landmarks` above -- confirm against the training code.
pixel_scale = torch.tensor([IMAGE_WIDTH, IMAGE_HEIGHT], dtype=torch.float32)
predictions = predictions * pixel_scale
landmarks = landmarks * pixel_scale
print(type(landmarks), landmarks.shape)

# One row per image in the batch (not a hard-coded 8, so a short final batch
# does not raise IndexError); 5 inches of height per image as before.
n_images = raw_images.size(0)
plt.figure(figsize=(10, 5 * n_images))
for img_num in range(n_images):
    plt.subplot(n_images, 1, img_num + 1)
    # Show the un-normalised single-channel image so cmap='gray' applies.
    plt.imshow(raw_images[img_num].cpu().numpy(), cmap='gray')
    plt.scatter(predictions[img_num, :, 0], predictions[img_num, :, 1], c='r')
    plt.scatter(landmarks[img_num, :, 0], landmarks[img_num, :, 1], c='g')
plt.show()

print('Total number of test images: {}'.format(len(test_dataset)))
end_time = time.time()
print("Elapsed Time : {}".format(end_time - start_time))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement