Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
# Import classes.
# The two wildcard imports are expected to supply StandardLearner (learners)
# and Flatten, Swish and Linear (layers).
# NOTE(review): replace them with explicit imports once the exact export lists
# are confirmed.
from torchfusion.learners import *
from torchfusion.layers import *
from torchfusion.datasets import fashionmnist_loader
import torch.nn as nn
from torchfusion.metrics import Accuracy
from torch.optim import Adam
from torchfusion.utils import VisdomLogger

# Define the classifier: a flatten step followed by four linear layers with
# Swish activations, producing 10 outputs.
# Building the model has no I/O, so it stays at module level and is importable.
network = nn.Sequential(
    Flatten(),
    # 784 = 28 * 28, the flattened size of the 28x28 images loaded below.
    # NOTE(review): this layer uses torch's nn.Linear while the others use the
    # star-imported Linear; confirm the mix (and any init difference) is intended.
    nn.Linear(784, 100),
    Swish(),
    Linear(100, 100),
    Swish(),
    Linear(100, 100),
    Swish(),
    Linear(100, 10),
)


def main(tensorboard_log="./tboard-logs", batch_size=64):
    """Load Fashion-MNIST, print a network summary, and train the classifier.

    All runtime setup lives here rather than at module level, so that merely
    importing this module does not repeat it. Re-imports happen, for example,
    in DataLoader worker processes started with the "spawn" method, if the
    loaders use workers.

    Args:
        tensorboard_log: Directory passed to ``learner.train`` as
            ``tensorboard_log``. Pass ``None`` when not using TensorBoard; the
            argument is then omitted from the call entirely.
        batch_size: Mini-batch size for both the training and test loaders.
    """
    # Load the dataset.
    train_set = fashionmnist_loader(size=28, batch_size=batch_size)
    test_set = fashionmnist_loader(size=28, train=False, batch_size=batch_size)

    # Set up the loss function, the optimizer, and the learner.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = Adam(network.parameters())
    learner = StandardLearner(network)

    # Create an instance of the Visdom logger.
    vis_logger = VisdomLogger()

    # Print a summary of the network for a single input.
    # Presumably the shape is (channels, height, width) — confirm.
    print(learner.summary((1, 28, 28)))

    train_kwargs = dict(
        train_loader=train_set,
        visdom_log=vis_logger,
        loss_fn=loss_fn,
        optimizer=optimizer,
        train_metrics=[Accuracy()],
        test_loader=test_set,
        test_metrics=[Accuracy()],
    )
    # When TensorBoard is not in use, leave tensorboard_log out of the call
    # instead of passing None to train().
    if tensorboard_log is not None:
        train_kwargs["tensorboard_log"] = tensorboard_log
    learner.train(**train_kwargs)


if __name__ == "__main__":
    main()
Add Comment
Please sign in to add a comment.