Advertisement
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
# Number of full passes over the training data; presumably forwarded to
# train() at the (not shown) call site — TODO confirm.
epochs = 60
def loss(actual, predicted):
    """Return the mean categorical cross-entropy between labels and predictions.

    Args:
        actual: ground-truth labels (one-hot / probability distributions).
        predicted: model output distributions, same shape as `actual`.

    Returns:
        A scalar tensor: the cross-entropy averaged over the batch.
    """
    per_example = tf.losses.categorical_crossentropy(actual, predicted)
    return tf.reduce_mean(per_example)
def train(train_dataset, validation_dataset, epochs, learning_rate=1e-1, momentum=9e-1, decay=1e-6):
    """Train the module-level `model` with SGD, printing per-epoch metrics.

    Args:
        train_dataset: iterable of (features, labels) training batches.
        validation_dataset: iterable of (features, labels) validation batches.
        epochs: number of passes over the datasets.
        learning_rate: SGD step size.
        momentum: SGD momentum coefficient.
        decay: learning-rate decay passed to the optimizer.

    NOTE(review): relies on a module-level `model` and on `loss()` defined
    elsewhere in this file.
    """
    # Fixed: original read `learning_rate=learning_rate=,` — a syntax error.
    optimizer = tf.optimizers.SGD(learning_rate=learning_rate,
                                  momentum=momentum,
                                  decay=decay)
    for epoch in range(epochs):
        epoch_loss = []
        train_accuracy = []
        validation_accuracy = []
        # zip() truncates to the shorter dataset; validation batches are
        # evaluated alongside training batches rather than after the epoch.
        for train_batch, validation_batch in zip(train_dataset, validation_dataset):
            train_batch_features, train_batch_labels = train_batch
            validation_batch_features, validation_batch_labels = validation_batch
            with tf.GradientTape() as tape:
                predictions = model(train_batch_features)
                train_loss = loss(train_batch_labels, predictions)
            gradients = tape.gradient(train_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            epoch_loss.append(train_loss)
            # Fresh metric per batch, so result() is this batch's accuracy only.
            accuracy = tf.metrics.Accuracy()
            accuracy(tf.argmax(train_batch_labels, 1),
                     tf.argmax(predictions, 1))
            train_accuracy.append(accuracy.result())
            validation_predictions = model(validation_batch_features)
            accuracy = tf.metrics.Accuracy()
            accuracy(tf.argmax(validation_batch_labels, 1),
                     tf.argmax(validation_predictions, 1))
            validation_accuracy.append(accuracy.result())
        epoch_loss = tf.reduce_mean(epoch_loss)
        train_accuracy = tf.reduce_mean(train_accuracy)
        validation_accuracy = tf.reduce_mean(validation_accuracy)
        # Fixed: the format string has five placeholders but the original
        # supplied only three arguments, raising IndexError at runtime.
        # Epoch is reported 1-based for readability.
        print('Epoch {} / {} : train loss = {}, train accuracy = {}, validation accuracy = {}'.format(
            epoch + 1,
            epochs,
            epoch_loss,
            train_accuracy,
            validation_accuracy))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement