4ever_bored

updates.py

Oct 14th, 2016
import theano
import numpy as np
from theano import tensor as T


class SGD():
    '''Stochastic gradient descent, with support for momentum,
    learning rate decay, and Nesterov momentum.

    # Arguments
        lr: float >= 0. Learning rate.
        momentum: float >= 0. Parameter updates momentum.
        decay: float >= 0. Learning rate decay over each update.
        nesterov: boolean. Whether to apply Nesterov momentum.
    '''
    def __init__(self, lr=0.01, momentum=0., decay=0.,
                 nesterov=False, **kwargs):
        self.iterations = theano.shared(np.asarray(0., dtype=theano.config.floatX))
        self.lr = theano.shared(np.asarray(lr, dtype=theano.config.floatX))
        self.momentum = theano.shared(np.asarray(momentum, dtype=theano.config.floatX))
        self.decay = theano.shared(np.asarray(decay, dtype=theano.config.floatX))
        self.initial_decay = decay
        self.nesterov = nesterov

    def get_updates(self, params, loss, grads):
        self.updates = []

        lr = self.lr
        if self.initial_decay > 0:
            lr *= (1. / (1. + self.decay * self.iterations))
            self.updates.append((self.iterations, self.iterations + 1))

        # momentum
        shapes = [p.get_value(borrow=True, return_internal_type=True).shape
                  for p in params]
        moments = [theano.shared(np.zeros(shape, dtype=theano.config.floatX))
                   for shape in shapes]
        self.weights = [self.iterations] + moments
        for p, g, m in zip(params, grads, moments):
            v = self.momentum * m - lr * g  # velocity
            self.updates.append((m, v))

            if self.nesterov:
                new_p = p + self.momentum * v - lr * g
            else:
                new_p = p + v

            self.updates.append((p, new_p))
        return self.updates

    def get_config(self):
        # There is no Optimizer base class here, so there is no
        # super().get_config() to merge with; return the config directly.
        config = {'lr': float(self.lr.get_value()),
                  'momentum': float(self.momentum.get_value()),
                  'decay': float(self.decay.get_value()),
                  'nesterov': self.nesterov}
        return config


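# Illustrative sketch (not part of the original paste): how the SGD class above
# can be wired into a Theano training step. The helper name `_sgd_example`, the
# toy parameter `w` and the quadratic loss are assumptions for this example only.
def _sgd_example():
    w = theano.shared(np.asarray([1., -2.], dtype=theano.config.floatX))
    loss = T.sum(T.sqr(w))            # minimised at w = 0
    grads = T.grad(loss, [w])
    sgd = SGD(lr=0.1, momentum=0.9, nesterov=True)
    step = theano.function([], loss,
                           updates=sgd.get_updates([w], loss, grads))
    for _ in range(50):
        step()                        # each call moves w closer to zero
    return w.get_value()

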
class Adam():
    '''Adam optimizer.

    Default parameters follow those provided in the original paper.

    # Arguments
        lr: float >= 0. Learning rate.
        beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
        epsilon: float >= 0. Fuzz factor.

    # References
        - [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
    '''
    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=1e-8, decay=0., **kwargs):
        self.epsilon = epsilon
        self.iterations = theano.shared(np.asarray(0., dtype=theano.config.floatX))
        self.lr = theano.shared(np.asarray(lr, dtype=theano.config.floatX))
        self.beta_1 = theano.shared(np.asarray(beta_1, dtype=theano.config.floatX))
        self.beta_2 = theano.shared(np.asarray(beta_2, dtype=theano.config.floatX))
        self.decay = theano.shared(np.asarray(decay, dtype=theano.config.floatX))
        self.initial_decay = decay

    def get_updates(self, params, loss, grads):
        self.updates = [(self.iterations, self.iterations + 1)]

        lr = self.lr
        if self.initial_decay > 0:
            lr *= (1. / (1. + self.decay * self.iterations))

        t = self.iterations + 1
        # bias-corrected learning rate for step t
        lr_t = lr * sqrt(1. - T.pow(self.beta_2, t)) / (1. - T.pow(self.beta_1, t))

        shapes = [p.get_value(borrow=True, return_internal_type=True).shape
                  for p in params]
        ms = [theano.shared(np.zeros(shape, dtype=theano.config.floatX))
              for shape in shapes]
        vs = [theano.shared(np.zeros(shape, dtype=theano.config.floatX))
              for shape in shapes]
        self.weights = [self.iterations] + ms + vs

        for p, g, m, v in zip(params, grads, ms, vs):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * T.square(g)
            p_t = p - lr_t * m_t / (sqrt(v_t) + self.epsilon)

            self.updates.append((m, m_t))
            self.updates.append((v, v_t))

            new_p = p_t

            self.updates.append((p, new_p))
        return self.updates

    def get_config(self):
        # As with SGD above, there is no base class to merge config with,
        # so return the config dict directly.
        config = {'lr': float(self.lr.get_value()),
                  'beta_1': float(self.beta_1.get_value()),
                  'beta_2': float(self.beta_2.get_value()),
                  'epsilon': self.epsilon}
        return config


def sqrt(x):
    # clip away small negative values (floating-point error) before the sqrt
    x = T.clip(x, 0., np.inf)
    return T.sqrt(x)


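# Illustrative usage sketch (not part of the original paste): fitting a tiny
# linear model with the Adam optimizer above. The data, the variable names
# (W, b, x, y) and the squared-error loss are assumptions made for this demo only.
if __name__ == '__main__':
    x = T.matrix('x')
    y = T.vector('y')
    W = theano.shared(np.zeros(3, dtype=theano.config.floatX), name='W')
    b = theano.shared(np.asarray(0., dtype=theano.config.floatX), name='b')

    pred = T.dot(x, W) + b
    loss = T.mean(T.sqr(pred - y))
    params = [W, b]
    grads = T.grad(loss, params)

    adam = Adam(lr=0.01)
    train = theano.function([x, y], loss,
                            updates=adam.get_updates(params, loss, grads))

    rng = np.random.RandomState(0)
    data_x = rng.randn(64, 3).astype(theano.config.floatX)
    data_y = (data_x.dot([0.5, -1., 2.]) + 0.1).astype(theano.config.floatX)
    for _ in range(200):
        cost = train(data_x, data_y)
    print('final loss: %.6f' % cost)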