Advertisement
Guest User

Untitled

a guest
May 13th, 2016
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.50 KB | None | 0 0
  1. from __future__ import print_function
  2. from __future__ import unicode_literals
  3.  
  4. from multiprocessing import Process, Queue
  5.  
  6. import six
  7. import numpy as np
  8. import tensorflow as tf
  9.  
  10. import arimo
  11.  
  12.  
  13. class DDFDataFetcher(Process):
  14.  
  15. def __init__(self):
  16. super(DDFDataFetcher, self).__init__()
  17.  
  18. def next_batch(self):
  19. return None
  20.  
  21.  
  22. class DDFRandomDataFetcher(DDFDataFetcher):
  23. def __init__(self, queue, server='', port=0,
  24. username='', password='', ddf_uri='',
  25. sample_size=20, batch_size=5, df_to_np_fn=None):
  26. super(DDFRandomDataFetcher, self).__init__()
  27. self.queue = queue
  28. self.server = server
  29. self.port = port
  30. self.username = username
  31. self.password = password
  32. self.ddf_uri = ddf_uri
  33. self.sample_size = sample_size
  34. self.batch_size = batch_size
  35. self.df_to_np_fn = df_to_np_fn
  36.  
  37. if self.df_to_np_fn is None:
  38. def p(df):
  39. # take the last column to be the label
  40. return df.iloc[:, :-1].values, df.iloc[:, -1].values
  41.  
  42. self.df_to_np_fn = p
  43.  
  44. def run(self):
  45. s = arimo.connect(self.server, self.port, self.username, self.password)
  46. ddf = s.get_ddf(self.ddf_uri)
  47. frac = self.sample_size / float(len(ddf))
  48. remain_inp, remain_target = None, None
  49.  
  50. while True:
  51. # df = ddf.sample(size=self.sample_size, replace=True)
  52. # sample() doesn't work on test-pe because of new PR merged
  53. ddfs = ddf.sample2ddf(fraction=frac, replace=True)
  54. inp, target = self.df_to_np_fn(ddfs.head(len(ddfs)))
  55.  
  56. if remain_inp is not None:
  57. inp = np.vstack([remain_inp, inp])
  58. if remain_target is not None:
  59. target = np.vstack([remain_target, target])
  60.  
  61. assert inp.shape[0] == target.shape[0]
  62.  
  63. for i in range(0, inp.shape[0], self.batch_size):
  64. # will block if queue is full
  65. self.queue.put((inp[i:(i + self.batch_size), :], target[i:(i + self.batch_size), :]))
  66.  
  67. if inp.shape[0] % self.batch_size != 0:
  68. idx = (inp.shape[0] / self.batch_size) * self.batch_size
  69. remain_inp = inp[idx:, :]
  70. remain_target = target[idx:, :]
  71.  
  72. def next_batch(self):
  73. return self.queue.get()
  74.  
  75.  
  76. class Trainer(object):
  77.  
  78. def __init__(self, fetcher, input_size, n_classes):
  79. self.fetcher = fetcher
  80.  
  81. self.x = tf.placeholder(tf.float32, [None, input_size])
  82. W = tf.Variable(tf.zeros([input_size, n_classes]))
  83. b = tf.Variable(tf.zeros([n_classes]))
  84. self.y = tf.nn.softmax(tf.matmul(self.x, W) + b)
  85.  
  86. # Define loss and optimizer
  87. self.y_ = tf.placeholder(tf.float32, [None, n_classes])
  88. cross_entropy = tf.reduce_mean(-tf.reduce_sum(self.y_ * tf.log(self.y), reduction_indices=[1]))
  89. self.train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
  90.  
  91. def train(self, iterations):
  92. tf.initialize_all_variables().run()
  93. self.fetcher.start()
  94. for i in range(iterations):
  95. batch_xs, batch_ys = self.fetcher.next_batch()
  96. self.train_step.run({self.x: batch_xs, self.y_: batch_ys})
  97. if i % 2 == 0:
  98. print('Iteration {}'.format(i))
  99.  
  100. correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(self.y_, 1))
  101. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  102. batch_xs, batch_ys = self.fetcher.next_batch()
  103. print('Accuracy on random sample: {}'.format(accuracy.eval({self.x: batch_xs, self.y_: batch_ys})))
  104.  
  105.  
  106. if __name__ == '__main__':
  107. try:
  108. input = raw_input
  109. except NameError:
  110. pass
  111.  
  112. q = Queue(20)
  113. server = six.text_type(input('Server address: '))
  114. port = int(six.text_type(input('Server port: ')))
  115. username = six.text_type(input('User name: '))
  116. passwd = six.text_type(input('Password: '))
  117.  
  118. ddf_uri = 'ddf://adatao/mtcars_tf'
  119.  
  120. def get_batch(df):
  121. inp = df[['mpg', 'cyl', 'disp', 'hp', 'drat', 'wt', 'qesc', 'vs', 'gear', 'carb']].values
  122. target_idx = df['am'].values
  123. n = target_idx.shape[0]
  124. b = np.zeros((n, 2))
  125. b[np.arange(n), target_idx] = 1
  126. return inp, b
  127.  
  128. ddf_fetcher = DDFRandomDataFetcher(q, server, port, username, passwd, ddf_uri=ddf_uri,
  129. df_to_np_fn=get_batch)
  130. sess = tf.InteractiveSession()
  131. trainer = Trainer(ddf_fetcher, 10, 2)
  132. trainer.train(20)
  133. ddf_fetcher.terminate()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement