Advertisement
Guest User

resnet50_cifar10_rocm_benchmark

a guest
Nov 1st, 2021
3,478
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
Python 0.96 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. gpus = tf.config.experimental.list_physical_devices('GPU')
  4. if gpus:
  5.   try:
  6.     for gpu in gpus:
  7.       tf.config.experimental.set_memory_growth(gpu, True)
  8.   except RuntimeError as e:
  9.     print(e)
  10.  
  11. bsize = 500
  12. (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  13. x_train = (x_train/255).astype('float32')
  14. x_test = (x_test/255).astype('float32')
  15. n_classes = np.max(y_train)+1
  16. train_dset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(bsize)
  17. test_dset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(bsize)
  18.  
  19. mod = tf.keras.applications.resnet_v2.ResNet50V2(weights=None,
  20.     input_shape = x_train.shape[1:],
  21.     include_top = True,
  22.     classes = n_classes,
  23.     classifier_activation='softmax')
  24. mod.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  25. #import pdb; pdb.set_trace()
  26. mod.fit(train_dset, validation_data = test_dset, epochs = 8)
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement