Advertisement
Guest User

Untitled

a guest
Jul 20th, 2017
60
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.28 KB | None | 0 0
  1. import chainer.functions as F
  2. import chainer.links as L
  3. import numpy as np
  4. from chainer import Chain
  5. from chainer import cuda, training, Variable
  6. from chainer import datasets, iterators, optimizers, initializers
  7. from chainer.training import extensions
  8. from sklearn.datasets import load_iris
  9.  
  10.  
  11. class MLP(Chain):
  12. def __init__(self, in_size: int, hidden_size: int, nb_class: int,
  13. initialW=initializers.GlorotNormal(), initial_bias=initializers.Constant(0.0)):
  14. super(MLP, self).__init__()
  15.  
  16. self.in_size = in_size
  17. self.hidden_size = hidden_size
  18. self.nb_class = nb_class
  19.  
  20. with self.init_scope():
  21. self.l1 = L.Linear(in_size=None, out_size=hidden_size,
  22. initialW=initialW, initial_bias=initial_bias)
  23. self.l2 = L.Linear(in_size=None, out_size=hidden_size,
  24. initialW=initialW, initial_bias=initial_bias)
  25. self.l3 = L.Linear(in_size=None, out_size=nb_class,
  26. initialW=initialW, initial_bias=initial_bias)
  27.  
  28. def __call__(self, x: Variable) -> Variable:
  29. with cuda.get_device_from_array(x):
  30. h1 = F.relu(self.l1(x))
  31. h2 = F.relu(self.l2(h1))
  32. return self.l3(h2)
  33.  
  34.  
  35. iris = load_iris()
  36. data, target = iris['data'].astype(np.float32), iris['target'].astype(np.int32)
  37.  
  38. iris_dataset = datasets.TupleDataset(data, target)
  39.  
  40. train_iter = iterators.SerialIterator(iris_dataset, batch_size=5, repeat=True, shuffle=True)
  41. test_iter = iterators.SerialIterator(iris_dataset, batch_size=1, repeat=False, shuffle=False)
  42.  
  43. model = L.Classifier(MLP(4, 50, 3))
  44. optimizer = optimizers.Adam()
  45. optimizer.setup(model)
  46.  
  47. updater = training.StandardUpdater(train_iter, optimizer)
  48. trainer = training.Trainer(updater, stop_trigger=(30, 'epoch'))
  49.  
  50. entries = ['epoch', 'elapsed_time',
  51. 'main/loss', 'main/accuracy',
  52. 'dev/main/loss', 'dev/main/accuracy',
  53. 'test/main/loss', 'test/main/accuracy', ]
  54.  
  55. trainer.extend(extensions.Evaluator(test_iter, model, device=-1), name='test', trigger=(1, 'epoch'))
  56. trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))
  57. trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
  58. trainer.extend(extensions.PrintReport(entries=entries))
  59.  
  60. if __name__ == '__main__':
  61. trainer.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement