SHARE
TWEET

Untitled

a guest Jul 19th, 2019 69 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Sun Jul  7 15:02:20 2019
  4. """
  5.  
  6. import pandas as pd
  7. import numpy as np
  8. import tensorflow as tf
  9. import matplotlib.pyplot as plt
  10.  
  11. # ---------------
  12. # Prepation of Data
  13. # ---------------
  14.  
  15. # Import in the required files
  16. xexxar = pd.read_csv("Xexxar.csv")
  17.  
  18. # We want to analyze Speed and Speed Rating since they are likely to be
  19. # linearly related.
  20. # It is grouped by Speed (Index) and Aggregated with .mean()
  21. xexxar_filtered = xexxar[['Speed', 'Speed Rating']].groupby("Speed").mean()
  22.  
  23. # Extract x and y (x in nested in the index)
  24. data_x = np.array(xexxar_filtered.index.tolist(), dtype = 'float')
  25. data_y = xexxar_filtered['Speed Rating']
  26.  
  27. # ---------------
  28. # TensorFlow Preparation
  29. # ---------------
  30.  
  31. # Define Feeding placeholders, placeholders are slots where data is fed during
  32. # the training process.
  33. feed_x = tf.placeholder(tf.float32)
  34. feed_y = tf.placeholder(tf.float32)
  35.  
  36. # We define the model as a linear relationship where, y = f(x) = w * x
  37. def model(x, w):
  38.     return(tf.multiply(x, w))
  39.    
  40. # Initialize w as 1.0, named weight
  41. var_w = tf.Variable(1.0, 'weight')
  42.    
  43. # Create model with placeholder x and weight variable
  44. model_y = model(feed_x, var_w)
  45.  
  46. # Define the cost. Square is required to make sure the value is always positive
  47. cost = tf.square(model_y - feed_y)
  48.  
  49. # We define the optimizer (ADAM) with learning rate = 0.001 and the cost.
  50. cost_op = tf.train.AdamOptimizer(0.001).minimize(cost)
  51.  
  52. # These are Tensorboard writers
  53.  
  54. # This defines var_w (weight) to be tracked
  55. summary_w = tf.summary.scalar("Weight", var_w)
  56. # This defines the Tensorboard summary writer
  57. summary_writer = tf.summary.FileWriter("logs_speed")
  58. # This merges all summaries (Not required since there's only 1 summary
  59. # excluding summary_writer)
  60. summary_merged = tf.summary.merge_all()
  61.  
  62. # Automatically initalize all variables
  63. init = tf.global_variables_initializer()
  64.  
  65. # Define number of epochs
  66. epochs = 50
  67.  
  68. # ---------------
  69. # TensorFlow Session
  70. # ---------------
  71.  
  72. with tf.Session() as sess:
  73.     # Initialize all variables
  74.     sess.run(init)
  75.    
  76.     # Step is used for Tensorboard Summary
  77.     step = 0
  78.     for epoch in range(epochs):
  79.         for x, y in zip(data_x, data_y):
  80.             step += 1
  81.            
  82.             # We extract summary_str from the sess.run
  83.             # summary_merged outputs to summary_str
  84.             # cost_op outputs to _ (placeholder)
  85.             summary_str, _ = sess.run([summary_merged, cost_op],
  86.                                       feed_dict={
  87.                                               feed_x: x,
  88.                                               feed_y: y
  89.                                               })
  90.            
  91.             # Append to Tensorboard Summary
  92.             summary_writer.add_summary(summary_str, step)
  93.        
  94.         print("Epoch {} \t Weight {}".format(epoch, var_w.eval(sess)))
  95.        
  96.     # Extract calculated w
  97.     trained_w = var_w.eval(sess)
  98.  
  99. # ---------------
  100. # Evaluate using matplotlib.pyplot
  101. # ---------------
  102.  
  103. # Get new trained_y with new trained_y
  104. trained_y = trained_w * data_x
  105.  
  106. # Plot original
  107. plt.scatter(data_x, data_y)
  108. # Plot new
  109. plt.scatter(data_x, trained_y, c = 'r')
  110.  
  111. # Show
  112. plt.show()
  113.  
  114. # Clears Graph when sessions don't close properly sometimes
  115. tf.reset_default_graph()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top