  1. """
  2. Code for training RBMs with contrastive divergence. Tries to be as
  3. quick and memory-efficient as possible while utilizing only pure Python
  4. and NumPy.
  5. """

# Copyright (c) 2009, David Warde-Farley
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. The name of the author may not be used to endorse or promote products
#    derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
# NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
# THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys
import time

import numpy as np

class RBM(object):
    """
    Class representing a basic restricted Boltzmann machine, with
    binary stochastic visible units and binary stochastic hidden
    units.
    """
    def __init__(self, nvis, nhid, mfvis=True, mfhid=False, initvar=0.1):
        nweights = nvis * nhid
        vb_offset = nweights
        hb_offset = nweights + nvis

        # One flat parameter vector, with views onto it specified below.
        self.params = np.empty((nweights + nvis + nhid))

        # Weights between the hiddens and visibles
        self.weights = self.params[:vb_offset].reshape(nvis, nhid)

        # Biases on the visible units
        self.visbias = self.params[vb_offset:hb_offset]

        # Biases on the hidden units
        self.hidbias = self.params[hb_offset:]

        # Start the biases at zero and the weights with small Gaussian
        # noise (`initvar` is taken here to be the variance of the
        # initial weights).
        self.params[:] = 0.
        self.weights[:] = np.random.normal(scale=np.sqrt(initvar),
                                           size=(nvis, nhid))

        # Attributes for scratch arrays used during sampling.
        self._hid_states = None
        self._vis_states = None

        # Instance-specific mean field settings.
        self._mfvis = mfvis
        self._mfhid = mfhid

    @property
    def numvis(self):
        """The number of visible units (i.e. dimension of the input)."""
        return self.visbias.shape[0]

    @property
    def numhid(self):
        """The number of hidden units in this model."""
        return self.hidbias.shape[0]

    def _prepare_buffer(self, ncases, kind):
        """
        Prepare the `_hid_states` or `_vis_states` scratch buffer for
        use with a minibatch of size `ncases`, reallocating it if
        necessary. `kind` is one of 'hid', 'vis'.
        """
        if kind not in ['hid', 'vis']:
            raise ValueError('kind argument must be hid or vis')
        name = '_%s_states' % kind
        num = getattr(self, 'num%s' % kind)
        buf = getattr(self, name)
        if buf is None or buf.shape[0] < ncases:
            if buf is not None:
                # Release both references to the old buffer so that it
                # can be freed before the larger one is allocated.
                setattr(self, name, None)
                del buf
            buf = np.empty((ncases, num))
            setattr(self, name, buf)

        # Fill with NaNs so that any stale values are easy to spot.
        buf[...] = np.nan
        return buf[:ncases]

    def hid_activate(self, input, mf=False):
        """
        Activate the hidden units by sampling from their conditional
        distribution given each of the rows of `input`. If `mf` is True,
        return the deterministic, real-valued probabilities of activation
        in place of stochastic binary samples ('mean-field').
        """
        input = np.atleast_2d(input)
        ncases, ndim = input.shape
        hid = self._prepare_buffer(ncases, 'hid')
        self._update_hidden(input, hid, mf)
        return hid

    def _update_hidden(self, vis, hid, mf=False):
        """
        Update hidden units by writing new values to array `hid`.

        If `mf` is False, hidden unit values are sampled from their
        conditional distribution given the visible unit configurations
        specified in each row of `vis`. If `mf` is True, the
        deterministic, real-valued probabilities of activation are
        written instead of stochastic binary samples ('mean-field').
        """
        # Implements the logistic 1/(1 + exp(-(vis . W + hidbias)))
        # with in-place operations.
        hid[...] = np.dot(vis, self.weights)
        hid[...] += self.hidbias
        hid *= -1.
        np.exp(hid, hid)
        hid += 1.
        hid **= -1.
        if not mf:
            self.sample_hid(hid)

    def _update_visible(self, vis, hid, mf=False):
        """
        Update visible units by writing new values to array `vis`.

        If `mf` is False, visible unit values are sampled from their
        conditional distribution given the hidden unit configurations
        specified in each row of `hid`. If `mf` is True, the
        deterministic, real-valued probabilities of activation are
        written instead of stochastic binary samples ('mean-field').
        """
        # Implements the logistic 1/(1 + exp(-(hid . W.T + visbias)))
        # with in-place operations.
        vis[...] = np.dot(hid, self.weights.T)
        vis[...] += self.visbias
        vis *= -1.
        np.exp(vis, vis)
        vis += 1.
        vis **= -1.
        if not mf:
            self.sample_vis(vis)

    @classmethod
    def binary_threshold(cls, probs):
        """
        Given a set of real-valued activation probabilities,
        sample binary values with the given Bernoulli parameters,
        and update the array in place with the Bernoulli samples.
        """
        samples = np.random.uniform(size=probs.shape)

        # Simulate Bernoulli trials with p = probs[i, j] by generating
        # uniform random numbers and counting any number less than
        # probs[i, j] as a success.
        probs[samples < probs] = 1.

        # Anything not set to 1 should be 0 once floored.
        np.floor(probs, probs)

    # Binary hidden units
    sample_hid = binary_threshold

    # Binary visible units
    sample_vis = binary_threshold

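    # For example, binary_threshold maps an array of probabilities to
    # {0., 1.} samples in place:
    #     p = np.array([[0.9, 0.1, 0.5]])
    #     RBM.binary_threshold(p)   # each entry of p is now 0. or 1.
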
    def gibbs_walk(self, nsteps, hid):
        """
        Perform `nsteps` steps of alternating Gibbs sampling,
        updating all of the visible units in parallel given the
        hiddens, followed by all of the hidden units given the
        visibles.

        Depending on instantiation arguments, one or both sets of
        units may instead have "mean-field" activities computed.
        Mean-field is always used in lieu of sampling for the
        terminal hidden unit configuration.
        """
        hid = np.atleast_2d(hid)
        ncases = hid.shape[0]

        # Allocate (or reuse) a buffer in which to store
        # the states of the visible units.
        vis = self._prepare_buffer(ncases, 'vis')

        for step in range(nsteps):

            # Update the visible units conditioning on the hidden units.
            self._update_visible(vis, hid, self._mfvis)

            # Always do mean-field on the last hidden unit update to get a
            # less noisy estimate of the negative phase correlations.
            if step < nsteps - 1:
                mfhid = self._mfhid
            else:
                mfhid = True

            # Update the hidden units conditioning on the visible units.
            self._update_hidden(vis, hid, mfhid)

        return vis, hid

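# A hypothetical usage sketch (`rbm` and `batch` are illustrative names):
# draw negative phase samples by running a short Gibbs chain.
#
#     rbm = RBM(nvis=784, nhid=500)
#     hid = rbm.hid_activate(batch)       # sample h ~ P(h | v = batch)
#     vis, hid = rbm.gibbs_walk(5, hid)   # five alternating updates
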
class GaussianBinaryRBM(RBM):
    """
    An RBM with linear, unit-variance Gaussian visible units and
    binary stochastic hidden units.
    """
    def _update_visible(self, vis, hid, mf=False):
        # The conditional means of the visibles are a linear function
        # of the hidden states; no squashing nonlinearity is applied.
        vis[...] = np.dot(hid, self.weights.T)
        vis += self.visbias
        if not mf:
            self.sample_vis(vis)

    @classmethod
    def sample_vis(cls, vis):
        # Add unit-variance Gaussian noise to the conditional means.
        vis += np.random.normal(size=vis.shape)

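# With unit-variance Gaussian visibles it is conventional to standardize
# the training data (zero mean, unit variance per dimension) beforehand;
# nothing in the class enforces this.
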
class CDTrainer(object):
    """An object that trains a model using vanilla contrastive divergence."""

    def __init__(self, model, weightcost=0.0002, rates=(1e-4, 1e-4, 1e-4),
                 cachebatchsums=True):
        self._model = model
        self._visbias_rate, self._hidbias_rate, self._weight_rate = rates
        self._weightcost = weightcost
        self._cachebatchsums = cachebatchsums
        self._weightstep = np.zeros(model.weights.shape)

    def train(self, data, epochs, cdsteps=1, minibatch=50, momentum=0.9):
        """
        Train an RBM with contrastive divergence, using `cdsteps`
        steps of alternating Gibbs sampling to draw the negative
        phase samples.
        """
        data = np.atleast_2d(data)
        ncases, ndim = data.shape
        model = self._model

        if self._cachebatchsums:
            batchsums = {}

        for epoch in range(epochs):

            # An epoch is a single pass through the training data.
            epoch_start = time.time()

            # Mean squared error isn't really the right thing to measure
            # for RBMs with binary visible units, but it gives a good
            # enough indication of whether things are moving in the
            # right direction.
            mse = 0

            for offset in range(0, ncases, minibatch):

                # Select a minibatch of data.
                batch = data[offset:(offset + minibatch)]
                batchsize = batch.shape[0]

                # Mean-field pass on the hidden units.
                hid = model.hid_activate(batch, mf=True)

                # Correlations between the data and the hidden unit
                # activations ('positive phase').
                poscorr = np.dot(batch.T, hid)

                # Activities of the hidden units
                posact = hid.sum(axis=0)

                # Threshold the hidden units so that they can't convey
                # more than 1 bit of information in the subsequent
                # sampling (assuming the hidden units are binary,
                # which they most often are).
                model.sample_hid(hid)

                # Simulate Gibbs sampling for a given number of steps.
                vis, hid = model.gibbs_walk(cdsteps, hid)

                # Update the weights with the difference in correlations
                # between the positive and negative phases.
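                # In symbols, the weight step assembled below is
                #     dW = lr * ((batch.T . hid0 - vis.T . hid) / batchsize
                #                - weightcost * W),
                # where hid0 is the mean-field hidden activity from the
                # positive phase and lr is the weight learning rate.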
                thisweightstep = poscorr
                thisweightstep -= np.dot(vis.T, hid)
                thisweightstep /= batchsize
                thisweightstep -= self._weightcost * model.weights
                thisweightstep *= self._weight_rate

                self._weightstep *= momentum
                self._weightstep += thisweightstep

                model.weights += self._weightstep

                # The gradient of the visible biases is the difference in
                # summed visible activities for the minibatch. The data
                # sums never change, so they can be computed once per
                # minibatch and cached.
                if self._cachebatchsums:
                    if offset not in batchsums:
                        batchsum = batch.sum(axis=0)
                        batchsums[offset] = batchsum
                    else:
                        batchsum = batchsums[offset]
                else:
                    batchsum = batch.sum(axis=0)

                visbias_step = batchsum - vis.sum(axis=0)
                visbias_step *= self._visbias_rate / batchsize

                model.visbias += visbias_step

                # The gradient of the hidden biases is the difference in
                # summed hidden activities for the minibatch.
                hidbias_step = posact - hid.sum(axis=0)
                hidbias_step *= self._hidbias_rate / batchsize

                model.hidbias += hidbias_step

                # Compute the squared reconstruction error in-place.
                vis -= batch
                vis **= 2.

                # Add to the total epoch estimate.
                mse += vis.sum() / ncases

            print("Done epoch %d: %f seconds, MSE=%f" %
                  (epoch + 1, time.time() - epoch_start, mse))
            sys.stdout.flush()
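
if __name__ == "__main__":
    # Minimal smoke test; the toy dataset and hyperparameters below are
    # arbitrary illustrative choices rather than recommended settings.
    np.random.seed(0)
    data = (np.random.uniform(size=(500, 20)) < 0.5).astype(float)
    model = RBM(nvis=20, nhid=10)
    trainer = CDTrainer(model, rates=(1e-2, 1e-2, 1e-2))
    trainer.train(data, epochs=5, cdsteps=1, minibatch=50)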