nordlaender

theano mlp for and

Jan 2nd, 2016
309
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.24 KB | None | 0 0
  1. __docformat__ = 'restructedtext en'
  2.  
  3. import gzip
  4. import os
  5. import sys
  6. import timeit
  7.  
  8. import numpy
  9.  
  10. import theano
  11. import theano.tensor as T
  12.  
  13.  
  14. class Layer():
  15.     """
  16.    this is a layer in the mlp
  17.    it's not meant to predict the outcome hence it does not compute a loss.
  18.    apply the functions for negative log likelihood = cost on the output of the last layer
  19.    """
  20.  
  21.     def __init__(self, input, n_in, n_out):
  22.         self.W = theano.shared(
  23.                 value=numpy.zeros(
  24.                         (n_in, n_out),
  25.                         dtype=theano.config.floatX
  26.                 ),
  27.                 name="W",
  28.                 borrow=True
  29.         )
  30.         self.b = theano.shared(
  31.                 value=numpy.zeros((n_in
  32.                                    , n_out),
  33.                                   dtype=theano.config.floatX),
  34.                 name="b",
  35.                 borrow=True
  36.         )
  37.  
  38.         self.output = T.nnet.softmax(T.dot(input, self.W) + self.b)
  39.         self.params = (self.W, self.b)
  40.         self.input = input
  41.  
  42.  
  43. def y_pred(output):
  44.     return T.argmax(output, axis=1)
  45.  
  46.  
  47. def negative_log_likelihood(output, y):
  48.     return -T.mean(T.log(output)[T.arange(y.shape[0]), y])
  49.  
  50.  
  51. def errors(output, y):
  52.     # check if y has same dimension of y_pred
  53.     if y.ndim != y_pred(output).ndim:
  54.         raise TypeError(
  55.                 'y should have the same shape as self.y_pred',
  56.                 ('y', y.type, 'y_pred', y_pred(output).type)
  57.         )
  58.     # check if y is of the correct datatype
  59.     if y.dtype.startswith('int'):
  60.         # the T.neq operator returns a vector of 0s and 1s, where 1
  61.         # represents a mistake in prediction
  62.         return T.mean(T.neq(y_pred(output), y))
  63.     else:
  64.         raise NotImplementedError()
  65.  
  66.  
  67. data_x = numpy.matrix([[0, 0],
  68.                        [1, 0],
  69.                        [0, 1],
  70.                        [1, 1]])
  71.  
  72. data_y = numpy.array([0,
  73.                       0,
  74.                       0,
  75.                       1])
  76.  
  77. train_set_x = theano.shared(numpy.asarray(data_x,
  78.                          dtype=theano.config.floatX),
  79.                          borrow=True)
  80.  
  81. train_set_y = T.cast(theano.shared(numpy.asarray(data_y,
  82.                          dtype=theano.config.floatX),
  83.                          borrow=True),"int32")
  84.  
  85. x = T.vector("x",theano.config.floatX)  # data
  86. y = T.ivector("y")  # labels
  87.  
  88. classifier = Layer(input=x, n_in=2, n_out=1)
  89.  
  90. cost = negative_log_likelihood(classifier.output, y)
  91.  
  92. g_W = T.grad(cost=cost, wrt=classifier.W)
  93. g_b = T.grad(cost=cost, wrt=classifier.b)
  94. index = T.lscalar()
  95.  
  96. learning_rate = 0.15
  97.  
  98. updates = [
  99.     (classifier.W, classifier.W - learning_rate * g_W),
  100.     (classifier.b, classifier.b - learning_rate * g_b)
  101. ]
  102.  
  103. train_model = theano.function(
  104.         inputs=[index],
  105.         outputs=cost,
  106.         updates=updates,
  107.         givens={
  108.             x: train_set_x[index],
  109.             y: train_set_y[index]
  110.         }
  111. )
  112. validate_model = theano.function(
  113.         inputs=[index],
  114.         outputs=classifier.errors(y),
  115.         givens={
  116.             x: train_set_x[index],
  117.             y: train_set_y[index]
  118.         }
  119. )
  120.  
  121. print(validate_model(3))
  122. for i in range(3):
  123.     train_model(i)
  124.  
  125. print(validate_model(3))
Add Comment
Please, Sign In to add comment