Guest User

Untitled

a guest
Dec 14th, 2017
63
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.81 KB | None | 0 0
  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4.  
  5. tf.reset_default_graph()
  6.  
  7. samples = 1000
  8. times = [1e-2*i for i in range(samples+1)]
  9. sin = np.sin(times[:-1])
  10. sin_next = np.sin(times[1:])
  11.  
  12. time_step = 10
  13. sin = np.reshape(sin, [-1, time_step, 1])
  14. sin_next = np.reshape(sin_next, [-1, 1])
  15.  
  16. signal = tf.placeholder(tf.float32,
  17. shape=[None, time_step, 1])
  18. signal_next = tf.placeholder(tf.float32,
  19. [None, 1])
  20.  
  21. unstacked_signal = tf.unstack(signal, axis=1)
  22. # print('inputs: ')
  23. # for t in unstacked_signal:
  24. # print(t)
  25.  
  26. state_size = 30
  27. rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=state_size)
  28. states, state = tf.nn.static_rnn(cell=rnn_cell,
  29. inputs=unstacked_signal,
  30. dtype=tf.float32)
  31.  
  32. # print('outputs: ')
  33. # for s in states:
  34. # print(s)
  35. # print(state.shape)
  36.  
  37. states = tf.stack(states, axis=1)
  38. reshaped_states = tf.reshape(states, [-1, state_size])
  39. print(reshaped_states.shape)
  40.  
  41. output = tf.layers.dense(reshaped_states, 1, use_bias=False)
  42. print(output.shape)
  43.  
  44. loss = tf.losses.mean_squared_error(signal_next, output)
  45. train_op = tf.train.GradientDescentOptimizer(1e-2).minimize(loss)
  46.  
  47. accuracy = tf.contrib.metrics.streaming_pearson_correlation(output, signal_next)
  48.  
  49. with tf.Session() as sess:
  50. sess.run(tf.global_variables_initializer())
  51. sess.run(tf.local_variables_initializer())
  52. for i in range(3000):
  53. _, _loss, _acc = sess.run([train_op, loss, accuracy],
  54. feed_dict={signal: sin, signal_next: sin_next})
  55. if i%100 == 0:
  56. print('step: {}, loss: {}, acc: {}'.format(i, _loss, _acc[0]))
  57.  
  58. _pred = sess.run(output, feed_dict={signal: sin})
  59. plt.figure(1)
  60. plt.plot(sin_next)
  61. plt.figure(2)
  62. plt.plot(_pred)
  63. plt.show()
Add Comment
Please sign in to add a comment.