Advertisement
Jim421616

Regression test for checking

Oct 21st, 2020 (edited)
2,229
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.36 KB | None | 0 0
  1. import seaborn as sns
  2. import tensorflow as tf
  3. from tensorflow import keras
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. import pandas as pd
  7. import tensorflow_docs as tfdocs
  8. import tensorflow_docs.plots
  9. import tensorflow_docs.modeling
  10. from scipy.stats import gaussian_kde
  11.  
  12. ################################################################################
  13.    
  14. def build_model(n):
  15. # Define the prediction model. The NN takes 5 input features (magnitudes), and outputs
  16. # the redshift.
  17.   model = keras.Sequential([
  18.     keras.layers.Dense(5, activation='relu', input_shape=(n,)), # 5 inputs (mags)
  19.     keras.layers.Dense(4, activation='relu'), # do I need this layer?
  20.     keras.layers.Dense(1) # 1 output (z)
  21.     ])
  22.   model.compile(loss='mae',
  23.                 optimizer = tf.keras.optimizers.RMSprop(0.001),
  24.                 metrics = ['mae', 'mse'])
  25.   return model
  26. ################################################################################
  27. # These are convenience functions for plotting the results of the model's predictions.
  28. def plot_mae():
  29.     plotter = tfdocs.plots.HistoryPlotter(smoothing_std=2)
  30.     plotter.plot({'Basic': history}, metric = "mae")
  31.     plt.title('Mean Absolute Error evolution for %s'%datasetname)
  32.     plt.ylabel(r'$\Delta z$')
  33.     plt.show()
  34.  
  35. def plot_mse():
  36.     plotter = tfdocs.plots.HistoryPlotter(smoothing_std=2)
  37.     plotter.plot({'Basic': history}, metric = "mse")
  38.     plt.title('Mean Standard Error evolution for %s'%datasetname)
  39.     plt.ylabel(r'$\Delta z$')
  40.     plt.show()
  41.  
  42. def plot_z():
  43.     x, y = valid_dataset['z'], valid_dataset['Predicted z']
  44.     xy = np.vstack([x,y])
  45.     z = gaussian_kde(xy)(xy)
  46.     # Sort the points by density, so that the densest points are plotted last
  47.     # idx = z.argsort()
  48.     # x, y, z = x[idx], y[idx], z[idx]
  49.     plt.scatter(valid_dataset['z'], valid_dataset['Predicted z'],
  50.                 s = 50,
  51.                 alpha = 0.5,
  52.                 marker = '.',
  53.                 # edgecolor = '',
  54.                 c = z)
  55.     plt.colorbar()
  56.     plt.title('Redshift predictions for %s\ncompared to spectroscopic redshift'%datasetname)
  57.     plt.xlabel(r'$z_{spec}$')
  58.     plt.ylabel('Predicted z')
  59.     plt.show()
  60.  
  61. def plot_deltaz():
  62.     valid_dataset['Delta z'] = valid_dataset['Predicted z'] - valid_dataset['z']
  63.     plt.scatter(valid_dataset['z'], valid_dataset['Delta z'],
  64.                 s = 2, alpha = 0.5, marker = '.')
  65.     plt.title('Deviation of redshift predictions for %s\nfrom spectroscopic redshift'%datasetname)
  66.     plt.xlabel(r'$z_{spec}$')
  67.     plt.ylabel(r'$\Delta z$')
  68.     # plt.xlim([0, 1.5])
  69.     plt.show()
  70.  
  71. def plot_z_boxplot(outliers = False):
  72.     columns = [valid_dataset['Predicted z'], valid_dataset['z']]
  73.     df = pd.DataFrame(data = valid_dataset,
  74.                 columns = columns)
  75.     sns.boxplot(data = pd.melt(df),
  76.         x = None, y = None, # This screws it up
  77.         linewidth = 0.5,
  78.         flierprops = dict(markerfacecolor = '0.1', markersize = 0.2),
  79.         showfliers = outliers
  80.         )
  81.     # valid_dataset.boxplot(column = ['Predicted z', 'z'])
  82.     plt.title('Distribution of statistical parameters\nfor %s'%datasetname)
  83.     plt.tight_layout()
  84.     plt.show()
  85.  
  86. def plot_delta_z_hist():
  87.     stats = valid_dataset['Delta z'].describe()
  88.     valid_dataset['Delta z'] = valid_dataset['Predicted z'] - valid_dataset['z']
  89.     valid_dataset['Delta z'].hist(label = ' mean = %.3f\n std dev = %.3f'%
  90.                                   (stats[1], stats[2]),
  91.                                   bins = 100)
  92.     plt.title(r'Distribution of $\Delta z$ for %s'%datasetname)
  93.     plt.legend()
  94.     plt.xlabel(r'$\Delta z$')
  95.     plt.ylabel('Count')
  96.     plt.show()
  97. ################################################################################
  98. dataset = pd.read_csv('sdss12.csv')
  99. datasetname = 'SDSS DR12 QSOs'
  100. mags = ['umag', 'gmag', 'rmag', 'imag', 'zmag']
  101. columns = mags + ['z']
  102.  
  103. num_features = len(columns)
  104.  
  105. dataset = dataset[columns]      
  106. train_dataset = dataset.sample(frac = 0.8, random_state = 1)
  107. valid_dataset = dataset.drop(train_dataset.index)    
  108. train_labels = train_dataset.iloc[:, len(mags):].values # all but the last column
  109. valid_labels = valid_dataset.iloc[:, len(mags):].values # all but the last column
  110.  
  111. print('Training set: \n',train_dataset)
  112. print('Validation set: \n', valid_dataset)
  113.  
  114. model = build_model(num_features)
  115. model.summary()
  116.    
  117. N = 50
  118. early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=500)
  119. history = model.fit(train_dataset.iloc[:,:num_features], train_labels, epochs = N,
  120.                     validation_split = 0.2, verbose = 0,
  121.                     callbacks = [
  122.                         early_stop,
  123.                         tfdocs.modeling.EpochDots()
  124.                         ])
  125. model.save("model2.h5")
  126.  
  127. # Model testing
  128. valid_predictions = model.predict(valid_dataset.iloc[:,:num_features])
  129. valid_dataset['Predicted z'] = valid_predictions
  130. print("\nPredicted\n")
  131. print(valid_dataset)
  132.  
  133. # Visualise the model's training progress using the stats in the history object
  134. hist = pd.DataFrame(history.history)
  135. hist['epoch'] = history.epoch
  136. print(hist.tail)
  137.  
  138. print('Predicted z stats: \n', valid_dataset['Predicted z'].describe())
  139. print('Spectroscopic z stats: \n', valid_dataset['z'].describe())
  140.  
  141. plot_mse()
  142. plot_mae()
  143. plot_z()
  144. plot_deltaz()
  145. plot_delta_z_hist()
  146. plot_z_boxplot(True)
  147. plot_z_boxplot(False)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement