Advertisement
Guest User

CNN.py

a guest
May 18th, 2022
181
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.32 KB | None | 0 0
  1. import os
  2.  
  3. import torch
  4. import torch.nn as nn
  5.  
  6.  
  7. # LeNet-5 style model
  8. class CNN(nn.Module):
  9.     def __init__(self, num_classes=10):
  10.         super().__init__()
  11.         # Batch: N
  12.         # CIFAR-10: 3 * 32 * 32
  13.         self.conv1 = nn.Sequential(
  14.             # N, 3, 32, 32 => N, 6, 32, 32
  15.             nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=2),
  16.             nn.BatchNorm2d(6),
  17.             nn.ReLU(inplace=True),
  18.             # N, 6, 32, 32 => N, 6, 16, 16
  19.             nn.AvgPool2d(kernel_size=2, stride=2),
  20.         )
  21.         self.conv2 = nn.Sequential(
  22.             # N, 6, 16, 16 => N, 16, 12, 12
  23.             nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
  24.             nn.BatchNorm2d(16),
  25.             nn.ReLU(inplace=True),
  26.             # N, 16, 12, 12 => N, 16, 6, 6
  27.             nn.AvgPool2d(kernel_size=2, stride=2),
  28.         )
  29.         self.prediction = nn.Sequential(
  30.             # N, 16, 6, 6 => N, 16 * 6 * 6
  31.             nn.Flatten(),
  32.             # N, 16 * 6 * 6 => N, 120
  33.             nn.Linear(16 * 6 * 6, 120),
  34.             nn.ReLU(inplace=True),
  35.             # N, 120 => N, 84
  36.             nn.Linear(120, 84),
  37.             nn.ReLU(inplace=True),
  38.             # N, 84 => N, 10
  39.             nn.Linear(84, num_classes),
  40.         )
  41.  
  42.     def forward(self, X, *args):
  43.         out = self.conv1(X)
  44.         out = self.conv2(out)
  45.         out = self.prediction(out)
  46.         return out
  47.  
  48.  
  49. if __name__ == "__main__":
  50.     from time import time
  51.  
  52.     import pandas as pd
  53.     from tqdm import tqdm
  54.  
  55.     from params import (
  56.         ATTACK_ALPHA,
  57.         ATTACK_EPS,
  58.         ATTACK_ITER,
  59.         BATCH_SIZE,
  60.         LEARNING_RATE,
  61.         NUM_EPOCHS,
  62.     )
  63.     from PGD import PGD
  64.     from utils import Logger, get_train_test_dataloader
  65.  
  66.     DEVICE = torch.device("mps")
  67.  
  68.     # Hyperparameters
  69.     num_epochs = NUM_EPOCHS
  70.     learning_rate = LEARNING_RATE
  71.     batch_size = BATCH_SIZE
  72.  
  73.     # Data
  74.     train_loader, test_loader = get_train_test_dataloader(batch_size=batch_size)
  75.  
  76.     # Model
  77.     model = CNN().to(DEVICE)
  78.     optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
  79.     criterion = nn.CrossEntropyLoss()
  80.     logger = Logger("./log/CNN.log")
  81.  
  82.     # Training
  83.     logger(
  84.         f"Epochs: {num_epochs}, Batch size: {batch_size}, Learning rate: {learning_rate}"
  85.     )
  86.     loss_history = []
  87.     model.train()
  88.     for epoch in range(num_epochs):
  89.         start = time()
  90.         for i, (images, labels) in enumerate(train_loader):
  91.             images = images.to(DEVICE)
  92.             labels = labels.to(DEVICE)
  93.             optimizer.zero_grad()
  94.             outputs = model(images)
  95.             loss = criterion(outputs, labels)
  96.             loss.backward()
  97.             optimizer.step()
  98.             if (i + 1) % 100 == 0:
  99.                 loss_history.append(loss.item())
  100.                 logger(
  101.                     f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}"
  102.                 )
  103.         now = time()
  104.         logger(f"Epoch time elapsed: {now - start:.3f} sec")
  105.  
  106.     # Testing
  107.     total = 0
  108.     correct = 0
  109.     correct_adv = 0
  110.     success = 0
  111.     attacker = PGD(eps=ATTACK_EPS, alpha=ATTACK_ALPHA, num_iter=ATTACK_ITER)
  112.     model.eval()
  113.  
  114.     for images, labels in tqdm(test_loader):
  115.         images = images.to(DEVICE)
  116.         labels = labels.to(DEVICE)
  117.         total += labels.size(0)
  118.         # Model evaluation
  119.         outputs = model(images)
  120.         _, predicted = torch.max(outputs.data, 1)
  121.         correct += (predicted == labels).sum().item()
  122.         # Attack
  123.         images_adv = attacker.attack(model, images, labels)
  124.         outputs = model(images_adv)
  125.         _, predicted_adv = torch.max(outputs.data, 1)
  126.         correct_adv += (predicted_adv == labels).sum().item()
  127.         success -= (
  128.             torch.where(predicted == labels, predicted_adv == labels, False)
  129.             .sum()
  130.             .item()
  131.         )
  132.     success += correct
  133.     logger(f"Tested on {total} samples.")
  134.     logger(
  135.         f"Attack amplitude: {ATTACK_EPS:.3f}, step size: {ATTACK_ALPHA:.3f}, iterations: {ATTACK_ITER:d}"
  136.     )
  137.     logger(f"Original accuracy : {100 * correct / total:.2f}%")
  138.     logger(f"Adversarial accuracy : {100 * correct_adv / total:.2f}%")
  139.     logger(f"Adversarial success rate : {100 * success / correct:.2f}%")
  140.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement