Guest User

Untitled

a guest
Sep 20th, 2018
92
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 1.17 KB | None | 0 0
  1. #import classes
  2. from torchfusion.learners import *
  3. from torchfusion.layers import *
  4. from torchfusion.datasets import fashionmnist_loader
  5. import torch.nn as nn
  6. from torchfusion.metrics import Accuracy
  7. from torch.optim import Adam
  8. from torchfusion.utils import VisdomLogger
  9.  
  10. #define the classifier
  11. network = nn.Sequential(
  12. Flatten(),
  13. nn.Linear(784,out_features=100),
  14. Swish(),
  15. Linear(100,100),
  16. Swish(),
  17. Linear(100,100),
  18. Swish(),
  19. Linear(100,10)
  20. )
  21.  
  22. #load the dataset
  23. train_set = fashionmnist_loader(size=28,batch_size=64)
  24. test_set = fashionmnist_loader(size=28,train=False,batch_size=64)
  25.  
  26. #setup the optimizer, loss function and learner
  27. loss_fn = nn.CrossEntropyLoss()
  28. optimizer = Adam(network.parameters())
  29.  
  30. learner = StandardLearner(network)
  31.  
  32. #create an instance of the visdom logger
  33. vis_logger = VisdomLogger()
  34.  
  35. #print a summary of the network
  36. print(learner.summary((1,28,28)))
  37.  
  38. #if not using tensorboard, omit the tensoboard_log arg
  39. if __name__ == "__main__":
  40. learner.train(train_loader=train_set,tensorboard_log="./tboard-logs",visdom_log=vis_logger,loss_fn=loss_fn,optimizer=optimizer,train_metrics=[Accuracy()],test_loader=test_set,test_metrics=[Accuracy()])
Add Comment
Please sign in to add a comment.