Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import tensorflow as tf
- import numpy as np
- import data_helper
class TLU(object):
    """Thresholding Logic Unit: a single two-input neuron trained with Adam.

    The unit computes ``sigmoid(10 * matmul(x, w))`` for inputs ``x`` fed to a
    ``[None, 2]`` placeholder and a ``[2, 1]`` weight variable ``w``. It is
    trained to minimise half the summed squared error against targets fed to
    a ``[None, 1]`` placeholder. All tensors are float16, and the model lives
    in its own ``tf.Graph`` (TensorFlow 1.x graph/session API).
    """

    # Maximum number of training iterations; set per instance in __init__.
    epoch = None
    # The tf.Graph holding the model; set per instance in __init__.
    graph = None
    # Hand-set example weight vector of shape [2, 1]. It is NOT used unless it
    # is passed explicitly as ``init_weight``. After __init__, the instance
    # attribute ``weight`` refers to the trainable tf.Variable instead.
    weight = np.asarray(
        [[1.0],
         [-1.0]]
    )

    def __init__(self, epoch=5000, init_weight=None, learning_rate=0.01):
        """Build the model graph.

        Args:
            epoch: maximum number of training iterations run by ``fit``.
            init_weight: optional initial weights, anything reshapeable to
                ``[2, 1]``. When None (the default), the weights are drawn
                from a normal distribution with mean 0.5 and stddev 0.1.
            learning_rate: Adam learning rate.
        """
        self.epoch = epoch
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Trainable weight. Previously a variable was also built from the
            # class-level ``weight`` and immediately overwritten, which left
            # an unused variable in the graph. That is now opt-in.
            if init_weight is None:
                initial = tf.random_normal(
                    [2, 1], mean=0.5, stddev=0.1, dtype=tf.float16)
            else:
                initial = np.asarray(
                    init_weight, dtype=np.float16).reshape([2, 1])
            self.weight = tf.Variable(initial, dtype=tf.float16)
            self._x = tf.placeholder(tf.float16, [None, 2], name='input')
            # Training targets. They are named 'target' so that 'output'
            # belongs to the prediction tensor. When both were called
            # 'output', TF silently renamed the sigmoid to 'output_1'.
            self._y = tf.placeholder(tf.float16, [None, 1], name='target')
            # Weighted sum, scaled by 10 to steepen the sigmoid so that it
            # behaves more like a hard threshold.
            mul = tf.matmul(self._x, self.weight)
            shift = 10 * mul
            # Activation function.
            self._output = tf.nn.sigmoid(shift, name='output')
            # Loss: half the sum of squared errors over the batch.
            self.loss = tf.reduce_sum(0.5 * tf.square(self._y - self._output))
            # NOTE(review): Adam's default epsilon (1e-8) is below the smallest
            # float16 subnormal and likely rounds to 0 for float16 variables.
            # If NaNs appear, use a larger epsilon or train in float32.
            self.optimizer = tf.train.AdamOptimizer(
                learning_rate).minimize(self.loss)

    def getGraph(self):
        """Return the tf.Graph that holds this model (same as ``self.graph``)."""
        return self.graph

    @staticmethod
    def _resolve_session(sess):
        """Return ``sess`` if given, otherwise the current default tf.Session.

        Raises:
            RuntimeError: if no session was passed and no default is active.
        """
        if sess is None:
            sess = tf.get_default_session()
        if sess is None:
            raise RuntimeError(
                'TLU needs a tf.Session: pass sess=... or call inside a '
                '"with tf.Session(graph=model.getGraph()):" block')
        return sess

    def fit(self, x, y, early_stopping_value=0.0001, sess=None):
        """Train for up to ``self.epoch`` full-batch iterations.

        Args:
            x: inputs fed to the ``[None, 2]`` input placeholder.
            y: targets fed to the ``[None, 1]`` target placeholder.
            early_stopping_value: stop as soon as an iteration's loss drops
                below this value.
            sess: session to run in. Defaults to the current default session,
                e.g. one opened with ``with tf.Session(...) as sess``. Earlier
                versions silently relied on a module-level global ``sess``.
        """
        sess = self._resolve_session(sess)
        # Full-batch training: the feed is identical on every iteration.
        feed_dict = {self._x: x, self._y: y}
        for i in range(self.epoch):
            log_now = i % 500 == 0
            if log_now:
                # Only fetch the weight when it is actually logged. It is
                # fetched before the training step, as before.
                current_weight = sess.run(self.weight)
            loss_value, _ = sess.run(
                [self.loss, self.optimizer], feed_dict=feed_dict)
            if log_now:
                print('epoch: %d  weight: %s\tloss: %s'
                      % (i, np.reshape(current_weight, [-1]), loss_value))
            if float(loss_value) < early_stopping_value:
                break

    def predict(self, x, sess=None):
        """Run the model forward on ``x``.

        Args:
            x: inputs fed to the ``[None, 2]`` input placeholder.
            sess: session to run in. Defaults to the current default session.

        Returns:
            A one-element list holding the sigmoid outputs, exactly as
            returned by ``sess.run([self._output], ...)``. Index it with
            ``[0]`` to get the predictions.
        """
        sess = self._resolve_session(sess)
        return sess.run([self._output], feed_dict={self._x: x})
if __name__ == '__main__':
    # Load data and build the model.
    # NOTE(review): data_helper.load() is assumed to return
    # (train_x, train_y, test_x) shaped for TLU's [None, 2] input and
    # [None, 1] target placeholders. Confirm against data_helper.
    train_x, train_y, test_x = data_helper.load()
    cell = TLU()
    graph = cell.getGraph()
    # Entering the session as a context manager also makes it the default
    # session. `sess` is kept as a module-level name (not moved into a
    # function) so that code looking up a global `sess` still finds it.
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        # Train. Single-argument print(...) works on both Python 2 and 3.
        print("<< trainning >>")
        cell.fit(train_x, train_y)
        # Predict: predict() returns a one-element list of outputs.
        print("<< testing >>")
        result = cell.predict(test_x)
        for i, prediction in enumerate(result[0]):
            print('test index: %d\tresult: %s' % (i, prediction))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement