NMOSFET

train

Apr 9th, 2022 (edited)
1,792
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.89 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.optim as optim
  5. import torchvision
  6. import torchvision.transforms as transforms
  7. from model import MyNet
  8. from tqdm import tqdm
  9.  
  10. # GPU
  11. device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
  12.  
  13. # Cifar-10 data
  14. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
  15.  
  16. # Data
  17. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
  18. trainLoader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
  19.  
  20. # Data classes
  21. classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  22.  
  23. net = MyNet().to(device)
  24.  
  25. # Parameters
  26. criterion = nn.CrossEntropyLoss()
  27. lr = 0.0001
  28. epochs = 100
  29. optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
  30.  
  31. def main():
  32. # Train
  33. for epoch in range(epochs):
  34. running_loss = 0.0
  35.  
  36. for times, data in enumerate(tqdm(trainLoader), 0):
  37. inputs, labels = data
  38. inputs, labels = inputs.to(device), labels.to(device)
  39.  
  40. # Zero the parameter gradients
  41. optimizer.zero_grad()
  42.  
  43. # forward + backward + optimize
  44. outputs = net(inputs)
  45. loss = criterion(outputs, labels)
  46. loss.backward()
  47. optimizer.step()
  48.  
  49. # print statistics
  50. running_loss += loss.item()
  51.  
  52. if times % 100 == 99 or times+1 == len(trainLoader):
  53. print('[%d/%d, %d/%d] loss: %.3f' % (epoch+1, epochs, times+1, len(trainLoader), running_loss/2000))
  54.  
  55. checkpoint = {"state_dict": net.state_dict(),}
  56. print("=> Saving checkpoint")
  57. torch.save(checkpoint, "modelWeight.pth.tar")
  58.  
  59. print('Finished Training')
  60.  
# Entry-point guard: start training only when this file is run as a script,
# not when it is imported. This matters here because trainLoader uses
# num_workers=4; under the "spawn" start method each worker re-imports this
# module, and without the guard each worker would call main() again.
if __name__ == "__main__":
    main()
  63.  
  64.  
Add Comment
Please, Sign In to add comment