Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
- import chainer.functions as F
- import chainer.links as L
- import numpy as np
- from chainer import Chain
- from chainer import cuda, training, Variable
- from chainer import datasets, iterators, optimizers, initializers
- from chainer.training import extensions
- from sklearn.datasets import load_iris
class MLP(Chain):
    """Three-layer perceptron: Linear -> ReLU -> Linear -> ReLU -> Linear.

    The final layer has no activation, so ``__call__`` returns unnormalized
    class scores (logits).

    Args:
        in_size: Number of input features fed to the first layer.
        hidden_size: Width of both hidden layers.
        nb_class: Number of output classes (width of the last layer).
        initialW: Weight initializer used for all three layers.
        initial_bias: Bias initializer used for all three layers.
    """

    # NOTE(review): the default initializer instances below are created once,
    # when the class is defined, and are shared by every MLP built with the
    # defaults. That is only safe if the initializers keep no per-use state —
    # confirm against the Chainer version in use.
    def __init__(self, in_size: int, hidden_size: int, nb_class: int,
                 initialW=initializers.GlorotNormal(),
                 initial_bias=initializers.Constant(0.0)):
        super(MLP, self).__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size
        self.nb_class = nb_class
        with self.init_scope():
            # Use the declared sizes instead of None. The parameters are then
            # allocated when the model is constructed, not lazily on the first
            # forward pass. Previously ``in_size`` was stored but never used.
            self.l1 = L.Linear(in_size=in_size, out_size=hidden_size,
                               initialW=initialW, initial_bias=initial_bias)
            self.l2 = L.Linear(in_size=hidden_size, out_size=hidden_size,
                               initialW=initialW, initial_bias=initial_bias)
            self.l3 = L.Linear(in_size=hidden_size, out_size=nb_class,
                               initialW=initialW, initial_bias=initial_bias)

    def __call__(self, x) -> Variable:
        """Return class scores for the input batch ``x``.

        Args:
            x: Input batch, either a raw array or a ``Variable``.

        Returns:
            Output of the last linear layer, with no softmax applied.
        """
        # get_device_from_array inspects raw arrays only. Unwrap a Variable so
        # that its device is used for the scope instead of being ignored.
        array = x.array if isinstance(x, Variable) else x
        with cuda.get_device_from_array(array):
            h1 = F.relu(self.l1(x))
            h2 = F.relu(self.l2(h1))
            return self.l3(h2)
iris = load_iris()
data, target = iris['data'].astype(np.float32), iris['target'].astype(np.int32)
iris_dataset = datasets.TupleDataset(data, target)

# Split the data into disjoint train / dev / test subsets (seeded, so runs are
# reproducible). Previously the model was trained and "tested" on the same
# samples, so the reported test metrics were really training metrics.
n_total = len(iris_dataset)
n_train = int(n_total * 0.7)
n_dev = (n_total - n_train) // 2
train_dataset, holdout_dataset = datasets.split_dataset_random(
    iris_dataset, n_train, seed=0)
dev_dataset, test_dataset = datasets.split_dataset_random(
    holdout_dataset, n_dev, seed=1)

train_iter = iterators.SerialIterator(train_dataset, batch_size=5, repeat=True, shuffle=True)
dev_iter = iterators.SerialIterator(dev_dataset, batch_size=1, repeat=False, shuffle=False)
test_iter = iterators.SerialIterator(test_dataset, batch_size=1, repeat=False, shuffle=False)

# Derive the input width and class count from the data rather than
# hard-coding them. For iris this is still MLP(4, 50, 3).
n_features = data.shape[1]
n_classes = int(target.max()) + 1
model = L.Classifier(MLP(n_features, 50, n_classes))

optimizer = optimizers.Adam()
optimizer.setup(model)
updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, stop_trigger=(30, 'epoch'))

entries = ['epoch', 'elapsed_time',
           'main/loss', 'main/accuracy',
           'dev/main/loss', 'dev/main/accuracy',
           'test/main/loss', 'test/main/accuracy', ]
# Each Evaluator reports under its ``name`` prefix. The 'dev' evaluator is
# what fills the dev/main/* columns; these were previously always blank.
trainer.extend(extensions.Evaluator(dev_iter, model, device=-1), name='dev', trigger=(1, 'epoch'))
trainer.extend(extensions.Evaluator(test_iter, model, device=-1), name='test', trigger=(1, 'epoch'))
trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))
trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
trainer.extend(extensions.PrintReport(entries=entries))

if __name__ == '__main__':
    trainer.run()
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement