# PGD.py

import torch
import torch.nn as nn


class PGD:
    """Projected Gradient Descent (PGD) adversarial attack, untargeted."""

    def __init__(
        self, eps=10 / 255, alpha=2 / 255, num_iter=30, rand_init=True
    ) -> None:
        self.eps = eps  # L-infinity perturbation budget
        self.alpha = alpha  # step size per iteration
        self.num_iter = num_iter  # number of gradient steps
        self.rand_init = rand_init  # start from a randomly perturbed point
        self.criterion = nn.CrossEntropyLoss()
    def attack(self, model, X, y):
        X_adv = X.clone()
        if self.rand_init:
            # Random initialization is generally recommended for PGD,
            # but omitting it does little harm in practice.
            X_adv = X_adv + torch.normal(
                mean=0, std=self.eps, size=X_adv.shape
            ).to(X.device)
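            # A common alternative (used in Madry et al.'s original PGD
            # formulation) draws uniform noise over the full eps-ball rather
            # than Gaussian noise; a minimal sketch of that variant:
            #     X_adv = X_adv + torch.empty_like(X_adv).uniform_(-self.eps, self.eps)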
        for _ in range(self.num_iter):
            X_adv.requires_grad = True
            with torch.enable_grad():
                output = model(X_adv)
                loss = self.criterion(output, y)
                loss.backward()
            # Ascend along the sign of the gradient, detaching from the graph
            X_adv = X_adv.detach() + self.alpha * torch.sign(X_adv.grad)
            # Project back into the eps-ball around the clean input ...
            X_adv = torch.clamp(X_adv, X - self.eps, X + self.eps)
            # ... and clip to the valid (normalized) input range
            X_adv = torch.clamp(X_adv, -2, 2)
        return X_adv.detach()
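
# For reference, each iteration of PGD.attack above implements the projected
# update
#     x_{t+1} = clip_{[x - eps, x + eps]}( x_t + alpha * sign( grad_x L(f(x_t), y) ) ),
# followed by a final clip to the valid input range.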


if __name__ == "__main__":
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    import numpy as np

    from CNN import CNN
    from params import ATTACK_ALPHA, ATTACK_EPS, ATTACK_ITER
    from utils import DEVICE, class_labels, get_test_dataloader

    # Model & Data
    model = CNN().to(DEVICE)
    model.load_state_dict(torch.load("./model/CNN.ckpt", map_location=DEVICE))
    model.eval()
    attacker = PGD(eps=ATTACK_EPS, alpha=ATTACK_ALPHA, num_iter=ATTACK_ITER)
    test_loader = get_test_dataloader(batch_size=100)
    # Evaluation of model and attacker
    total = 0
    correct = 0
    correct_adv = 0
    success = 0
    success_histories = []
    for images, labels in tqdm(test_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        total += labels.size(0)
        # Model evaluation on clean images
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        # Attack
        adv_images = attacker.attack(model, images, labels)
        outputs = model(adv_images)
        _, predicted_adv = torch.max(outputs, 1)
        correct_adv += (predicted_adv == labels).sum().item()
        # An attack counts as successful if it flips an originally
        # correct prediction
        success += ((predicted == labels) & (predicted_adv != labels)).sum().item()

        for i in range(images.size(0)):
            # Record successful attacks
            if (predicted[i] == labels[i]) and (predicted_adv[i] != labels[i]):
                success_histories.append(
                    (
                        images[i].detach().cpu().numpy(),
                        adv_images[i].detach().cpu().numpy(),
                        labels[i].detach().cpu().numpy(),
                        predicted[i].detach().cpu().numpy(),
                        predicted_adv[i].detach().cpu().numpy(),
                    )
                )
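
    # Note: the success rate reported below is measured over originally
    # correct predictions. With hypothetical numbers: if 9000/10000 clean
    # images are classified correctly and 8100 of those are flipped, the
    # success rate is 8100/9000 = 90%.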

    # Visualization
    vis_size = min(5, len(success_histories))
    success_histories = [
        success_histories[i]
        for i in np.random.choice(len(success_histories), vis_size, replace=False)
    ]

    fig = plt.figure(figsize=(12, 6))
    fig.suptitle(
        f"Model accuracy : {100 * correct / total:.2f}% -> {100 * correct_adv / total:.2f}%, "
        f"Success rate : {100 * success / correct:.2f}% "
        f"($\\epsilon={int(ATTACK_EPS * 255)} / 255$)"
    )
    for i, (image, image_adv, label, pred, pred_adv) in enumerate(success_histories):
        # Original image (title: true label, xlabel: clean prediction)
        ax = fig.add_subplot(2, vis_size, i + 1)
        ax.imshow(image.squeeze().transpose(1, 2, 0))
        ax.set_title(f"{class_labels[int(label)]}")
        ax.set_xlabel(f"{class_labels[int(pred)]}")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("equal")
        # Adversarial image (xlabel: prediction after the attack)
        ax = fig.add_subplot(2, vis_size, i + 1 + vis_size)
        ax.imshow(image_adv.squeeze().transpose(1, 2, 0))
        ax.set_xlabel(f"{class_labels[int(pred_adv)]}")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("equal")
    fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the suptitle
    fig.savefig("./fig/PGD.pdf")
    # plt.show()