Advertisement
Guest User

Untitled

a guest
Jun 28th, 2017
65
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.56 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. import data_helper
  4.  
  5. class TLU(object):
  6. """
  7. Thresholding Logic Unit class definition
  8. """
  9. epoch = None
  10. graph = None
  11. weight = np.asarray(
  12. [[1.0],
  13. [-1.0]]
  14. )
  15.  
  16. def __init__(self, epoch=5000):
  17. self.epoch = epoch
  18. self.graph = tf.Graph()
  19. with self.graph.as_default():
  20. # Define input and weight
  21. self.weight = tf.Variable(self.weight, dtype=tf.float16)
  22. self.weight = tf.Variable(tf.random_normal([2, 1], mean=0.5, stddev=0.1, dtype=tf.float16))
  23. self._x = tf.placeholder(tf.float16, [None, 2], name='input')
  24. self._y = tf.placeholder(tf.float16, [None, 1], name='output')
  25.  
  26. # multiplication and linear shiftting
  27. mul = tf.matmul(self._x, self.weight)
  28. shift = 10 * mul
  29.  
  30. # Activation function
  31. self._output = tf.nn.sigmoid(shift, name='output')
  32.  
  33. # Loss function
  34. self.loss = tf.reduce_sum( 0.5 * tf.square(self._y - self._output))
  35.  
  36. # Optimizer
  37. self.optimizer = tf.train.AdamOptimizer(0.01).minimize(self.loss)
  38.  
  39. def getGraph(self):
  40. """
  41. Return the graph object
  42. Ret: The graph object
  43. """
  44. return self.graph
  45.  
  46. def fit(self, x, y, early_stopping_value=0.0001):
  47. """
  48. Train the TLU model for specific training epoch
  49. """
  50. for i in range(self.epoch):
  51. feed_dict = {
  52. self._x: x,
  53. self._y: y
  54. }
  55. _weight = sess.run(self.weight)
  56. _loss, _ = sess.run([self.loss, self.optimizer], feed_dict=feed_dict)
  57.  
  58. if i % 500 == 0:
  59. print "epoch: ", i, " weight: ", np.reshape(_weight, [-1]), "\tloss: ", _loss
  60. if _loss.sum() < early_stopping_value:
  61. break
  62.  
  63. def predict(self, x):
  64. """
  65. Predict the result
  66. """
  67. return sess.run([self._output,], feed_dict={self._x: x})
  68.  
  69. if __name__ == '__main__':
  70. # Load data and build model
  71. train_x, train_y, test_x = data_helper.load()
  72. cell = TLU()
  73. graph = cell.getGraph()
  74.  
  75. with tf.Session(graph=graph) as sess:
  76. sess.run(tf.global_variables_initializer())
  77.  
  78. # Train
  79. print "<< trainning >>"
  80. cell.fit(train_x, train_y)
  81.  
  82. # Predict
  83. print "<< testing >>"
  84. result = cell.predict(test_x)
  85. for i in range(len(test_x)):
  86. print "test index: ", i, '\tresult: ', result[0][i]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement