Not a member of Pastebin yet? Sign up — it unlocks many cool features!
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from model import MyNet
from tqdm import tqdm

# Select the first GPU when available, otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Input pipeline: PIL image -> tensor, then per-channel normalization.
# NOTE(review): these mean/std values are the ImageNet statistics, not
# CIFAR-10's (~(0.4914, 0.4822, 0.4465) / (0.2470, 0.2435, 0.2616)) —
# confirm this matches what MyNet was designed/pretrained for.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# CIFAR-10 training split; download=False assumes ./data is already populated.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainLoader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=4)

# Human-readable class names; index order matches the dataset's labels.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model and training hyperparameters.
net = MyNet().to(device)
criterion = nn.CrossEntropyLoss()
lr = 0.0001
epochs = 100
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
def main():
    """Train `net` on CIFAR-10 for `epochs` epochs, checkpointing each epoch.

    Reads the module-level globals (net, trainLoader, criterion, optimizer,
    device, epochs) and writes the latest weights to modelWeight.pth.tar.
    """
    for epoch in range(epochs):
        running_loss = 0.0
        seen = 0  # batches accumulated since the last loss report
        for times, data in enumerate(tqdm(trainLoader), 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients.
            optimizer.zero_grad()

            # Forward + backward + optimize.
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Print statistics every 100 batches and at the end of the epoch.
            # BUG FIX: the original divided by a hard-coded 2000 and never
            # reset running_loss, so the printed value was neither a batch
            # mean nor comparable between reports. Track the batch count
            # since the last report and reset both counters after printing.
            running_loss += loss.item()
            seen += 1
            if times % 100 == 99 or times + 1 == len(trainLoader):
                print('[%d/%d, %d/%d] loss: %.3f' %
                      (epoch + 1, epochs, times + 1, len(trainLoader),
                       running_loss / seen))
                running_loss = 0.0
                seen = 0

        # Save after every epoch; the same file is overwritten, so only the
        # most recent weights are kept.
        checkpoint = {"state_dict": net.state_dict()}
        print("=> Saving checkpoint")
        torch.save(checkpoint, "modelWeight.pth.tar")

    print('Finished Training')
# Script entry point: only train when run directly, not when imported.
if __name__ == "__main__":
    main()
Add Comment
Please sign in to add a comment.