import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
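
# Synthetic MNIST-shaped benchmark: trains the classic example CNN on random
# data and reports the average forward-pass time per epoch. Timing uses CUDA
# events rather than a host-side clock, since GPU kernels launch
# asynchronously and time.time() around a forward pass would mostly measure
# launch overhead.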


class CudaEventTimer(object):
    """Pairs a start and an end CUDA event and reports the GPU time between them."""

    def __init__(self, start_event: torch.cuda.Event, end_event: torch.cuda.Event):
        self.start_event = start_event
        self.end_event = end_event

    def get_elapsed_msec(self):
        # Block until the end event has completed on the device, then return
        # the device-side elapsed time in milliseconds.
        torch.cuda.current_stream().wait_event(self.end_event)
        self.end_event.synchronize()
        return self.start_event.elapsed_time(self.end_event)


class Timer:
    """Accumulating timer built from pairs of CUDA events."""

    def __init__(self):
        self.started_ = False
        self.event_timers = []
        self.start_event = None
        self.elapsed_records = None

    def start(self):
        """Start the timer."""
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        self.started_ = True

    def stop(self):
        """Stop the timer and queue the start/end pair for later readout."""
        assert self.started_, "timer is not started"
        end_event = torch.cuda.Event(enable_timing=True)
        end_event.record()
        self.event_timers.append(CudaEventTimer(self.start_event, end_event))
        self.start_event = None
        self.started_ = False

    def _get_elapsed_msec(self):
        # Reading the events synchronizes with the GPU, so the readout is
        # deferred until the caller actually asks for the elapsed time.
        self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers]
        self.event_timers.clear()
        return sum(self.elapsed_records)

    def reset(self):
        """Reset the timer."""
        self.started_ = False
        self.start_event = None
        self.elapsed_records = None
        self.event_timers.clear()

    def elapsed(self, reset=True):
        """Return the accumulated elapsed time in milliseconds."""
        started_ = self.started_
        # If timing is in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self._get_elapsed_msec()
        # Reset the accumulated records.
        if reset:
            self.reset()
        # If timing was in progress, restart it.
        if started_:
            self.start()
        return elapsed_
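
# Minimal usage sketch for Timer (assumes a CUDA device; measures GPU work only):
#
#     timer = Timer()
#     timer.start()
#     y = model(x)          # queue some GPU work
#     timer.stop()
#     ms = timer.elapsed()  # milliseconds of GPU time; resets the timer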

args = None
epoch_fwd_times = []

class Net(nn.Module):
    """The standard MNIST CNN from the PyTorch examples, with a forward-pass timer."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.timer = Timer()

    def forward(self, x):
        global epoch_fwd_times
        self.timer.start()
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        self.timer.stop()
        # elapsed() synchronizes with the GPU once per forward pass; that adds
        # overhead but keeps each per-batch timing accurate.
        fwd_time = self.timer.elapsed(reset=True)
        epoch_fwd_times.append(fwd_time)
        return output
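
# Shape check for fc1's input size: a 28x28 input passes through conv1 (3x3,
# stride 1) -> 26x26, conv2 -> 24x24, then max_pool2d(2) -> 12x12 with 64
# channels, so torch.flatten yields 64 * 12 * 12 = 9216 features per image.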


def train(args, model, device, train_loader, optimizer, epoch):
    global epoch_fwd_times
    model.train()
    # Random labels, generated once and reused for every batch: this is a
    # forward/backward throughput benchmark, not a real training run.
    target = torch.LongTensor(args.batch_size).random_(10).to(device)
    for batch_idx, images in enumerate(train_loader):
        data = images.to(device, non_blocking=True)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break
    print(f"Epoch avg fwd_time: {sum(epoch_fwd_times) / len(epoch_fwd_times):.3f} ms")
    epoch_fwd_times = []


def test(model, device, test_loader):
    # Kept from the original MNIST example; unused here because the real test
    # loader is stubbed out in main().
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    global args
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--iterations', type=int, default=1000,
                        help='number of batches of random data per epoch (default: 1000)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    # The real MNIST pipeline from the original example, replaced below by
    # random data so the benchmark has no disk or transform overhead:
    # transform = transforms.Compose([
    #     transforms.ToTensor(),
    #     transforms.Normalize((0.1307,), (0.3081,))
    #     ])
    # dataset1 = datasets.MNIST('../data', train=True, download=True,
    #                           transform=transform)
    # dataset2 = datasets.MNIST('../data', train=False,
    #                           transform=transform)
    # train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    # test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    class RandomDataset(Dataset):
        """Fixed pool of random 1x28x28 'images'; __getitem__ slices one out."""

        def __init__(self, length):
            self.len = length
            self.data = torch.randn(1, 28, 28, length)

        def __getitem__(self, index):
            return self.data[:, :, :, index]

        def __len__(self):
            return self.len

    train_dataset = RandomDataset(args.batch_size * args.iterations)
    train_loader = DataLoader(train_dataset, **train_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        # test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == '__main__':
    main()
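
# Example invocation (assuming this file is saved as, say, `mnist_fwd_timing.py`):
#   python mnist_fwd_timing.py --batch-size 64 --iterations 1000 --epochs 3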