Advertisement
Guest User

DBNSampler.py

a guest
Jul 9th, 2013
193
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.75 KB | None | 0 0
  1. from pylearn2.models.dbm import DBM
  2. from theano.sandbox.rng_mrg import MRG_RandomStreams
  3. from pylearn2.base import Block
  4.  
  5. class DBNSampler(Block):
  6.     """
  7.    A Block used to sample from the last layer of a list of pre-trained RBMs.
  8.    Here, we assume the RBM is an instance of the DBM class with only one hidden layer.
  9.  
  10.    It does so by computing the expected activation of the N-th RBM's hidden
  11.    layer given the state of its visible layer, and use that expected activation
  12.    as the state of the N+1-th RBM's visible layer.
  13.    """
  14.     def __init__(self, rbm_list):
  15.         super(DBNSampler, self).__init__()
  16.         self.theano_rng = MRG_RandomStreams(2012 + 10 + 14)
  17.         self.rbm_list = rbm_list
  18.  
  19.     def __call__(self, inputs):
  20.         visible_state = inputs
  21.         for rbm in self.rbm_list:
  22.             # What the hidden layer sees from the visible layer
  23.             visible_state = rbm.visible_layer.upward_state(visible_state)
  24.             # The hidden layer's expected activation
  25.             total_state = rbm.hidden_layers[0].mf_update(visible_state, None)
  26.             # The expected activation is used as the next visible layer's state
  27.             visible_state = rbm.hidden_layers[0].upward_state(total_state)
  28.  
  29.         # This is the last layer's expected activation
  30.         expected_activation = visible_state
  31.  
  32.         rval = self.theano_rng.binomial(size=expected_activation.shape,
  33.                                         p=expected_activation,
  34.                                         dtype=expected_activation.dtype, n=1)
  35.         return rval
  36.  
  37.     def get_input_space(self):
  38.         return self.rbm_list[-1].visible_layer.space
  39.  
  40.     def get_output_space(self):
  41.         return self.rbm_list[-1].hidden_layers[-1].get_output_space()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement