Advertisement
JokoEliyanto

Untitled

Apr 4th, 2020
433
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.91 KB | None | 0 0
  1. #@title Define the functions that build and train a model
  2. def build_model(my_learning_rate):
  3.   """Create and compile a simple linear regression model."""
  4.   # Most simple tf.keras models are sequential.
  5.   # A sequential model contains one or more layers.
  6.   model = tf.keras.models.Sequential()
  7.  
  8.   # Describe the topography of the model.
  9.   # The topography of a simple linear regression model
  10.   # is a single node in a single layer.
  11.   model.add(tf.keras.layers.Dense(units=1,
  12.                                   input_shape=(1,)))
  13.  
  14.   # Compile the model topography into code that
  15.   # TensorFlow can efficiently execute. Configure
  16.   # training to minimize the model's mean squared error.
  17.   model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=my_learning_rate),
  18.                 loss="mean_squared_error",
  19.                 metrics=[tf.keras.metrics.RootMeanSquaredError()])
  20.  
  21.   return model          
  22.  
  23.  
  24. def train_model(model, feature, label, epochs, batch_size):
  25.   """Train the model by feeding it data."""
  26.  
  27.   # Feed the feature values and the label values to the
  28.   # model. The model will train for the specified number
  29.   # of epochs, gradually learning how the feature values
  30.   # relate to the label values.
  31.   history = model.fit(x=feature,
  32.                       y=label,
  33.                       batch_size=None,
  34.                       epochs=epochs)
  35.  
  36.   # Gather the trained model's weight and bias.
  37.   trained_weight = model.get_weights()[0]
  38.   trained_bias = model.get_weights()[1]
  39.  
  40.   # The list of epochs is stored separately from the
  41.   # rest of history.
  42.   epochs = history.epoch
  43.  
  44.   # Gather the history (a snapshot) of each epoch.
  45.   hist = pd.DataFrame(history.history)
  46.  
  47.   # Specifically gather the model's root mean
  48.   #squared error at each epoch.
  49.   rmse = hist["root_mean_squared_error"]
  50.  
  51.   return trained_weight, trained_bias, epochs, rmse
  52.  
  53. print("Defined create_model and train_model")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement