Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import numpy as np
import mxnet as mx
from mxnet import nd, autograd, gluon

# Run on the first GPU; switch to mx.cpu() when no GPU is available.
model_ctx = mx.gpu(0)
# model_ctx = mx.cpu()

batch_size = 128  # Number of samples per mini-batch
epochs = 100      # Number of passes over the training data
- # Yuklenen verinin X ve y kismina uygulanir.
def transform_func(data, label):
    """Standard MNIST preprocessing applied to each (X, y) pair on load.

    Casts both arrays to float32 and rescales the pixel values from the
    raw 0-255 range into [0, 1] by dividing by 255.
    """
    scaled = data.astype(np.float32) / 255.0
    return scaled, label.astype(np.float32)
# Wrap the MNIST train/test splits in DataLoaders that yield shuffled
# mini-batches, applying transform_func to every sample as it is read.
train_data = gluon.data.DataLoader(
    gluon.data.vision.MNIST(train=True, transform=transform_func),
    batch_size, shuffle=True)
test_data = gluon.data.DataLoader(
    gluon.data.vision.MNIST(train=False, transform=transform_func),
    batch_size, shuffle=True)
m = gluon.nn.Sequential()  # Model container (Keras-like)
with m.name_scope():  # TF-style name scope for parameter naming
    # Keras-style Sequential.add(): a 784 -> 100 -> 256 -> 100 -> 10 MLP
    m.add(gluon.nn.Dense(100, activation="relu"))
    m.add(gluon.nn.Dense(256, activation="relu"))
    m.add(gluon.nn.Dense(100, activation="relu"))
    m.add(gluon.nn.Dense(10))  # Raw logits for the 10 MNIST classes

# Draw all initial parameter values from a Normal distribution with
# standard deviation 0.2, placing them on the chosen device.
m.collect_params().initialize(mx.init.Normal(sigma=0.2), ctx=model_ctx)

loss = gluon.loss.SoftmaxCrossEntropyLoss()  # Classification loss
# Optimize the parameters with SGD at learning rate 0.01.
trainer = gluon.Trainer(m.collect_params(), "sgd", {"learning_rate": 0.01})

# BUG FIX: `print m` is Python 2 statement syntax (a SyntaxError under
# Python 3); use the call form, consistent with the rest of the file.
print(m)  # Print the model summary
# Training loop
for e in range(epochs):  # one full pass over the data per epoch
    for i, (data, label) in enumerate(train_data):  # per mini-batch
        # Move X onto the model's device and flatten each 28x28 image
        # into a 784-dimensional vector.
        data = data.as_in_context(model_ctx).reshape([-1, 784])
        # Move y onto the model's device as well.
        label = label.as_in_context(model_ctx)
        with autograd.record():  # record ops for autodiff
            output = m(data)
            l = loss(output, label)  # loss between y and net(X)
        l.backward()  # PyTorch-style backward pass
        trainer.step(batch_size)  # SGD update, scaled by batch size
    # NOTE(review): original indentation was lost in the paste; this
    # prints the model once per epoch as a progress marker — confirm.
    print(m)
# Evaluation loop
acc = mx.metric.Accuracy()  # Accuracy metric
for i, (data, label) in enumerate(test_data):  # per test mini-batch
    data = data.as_in_context(model_ctx).reshape([-1, 784])
    label = label.as_in_context(model_ctx)
    output = m(data)  # forward pass: net(X)
    # Take the index of the highest score per row; e.g. with three
    # classes and scores [0.6, 0.4, 0.3], max = 0.6 so argmax = 0.
    preds = nd.argmax(output, axis=1)
    acc.update(preds=preds, labels=label)  # fold the batch into the metric
# BUG FIX: `print acc` is Python 2 statement syntax (a SyntaxError
# under Python 3); use the call form, consistent with line 45.
print(acc)  # Print the final accuracy
Add Comment
Please, Sign In to add comment