Advertisement
Guest User

bert_imdb_rocm_benchmark

a guest
Nov 1st, 2021
3,486
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.06 KB | None | 0 0
  1. #from https://huggingface.co/transformers/training.html
  2. import tensorflow as tf
  3. import numpy as np
  4. gpus = tf.config.experimental.list_physical_devices('GPU')
  5. if gpus:
  6.   try:
  7.     for gpu in gpus:
  8.       tf.config.experimental.set_memory_growth(gpu, True)
  9.   except RuntimeError as e:
  10.     print(e)
  11.  
  12. from datasets import load_dataset
  13. from transformers import AutoTokenizer, AutoConfig, TFAutoModelForSequenceClassification
  14. import tensorflow as tf
  15.  
  16. raw_datasets = load_dataset("imdb")
  17.  
  18. tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
  19.  
  20. def tokenize_function(examples):
  21.     return tokenizer(examples["text"], padding="max_length", truncation=True)
  22.  
  23. tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
  24.  
  25. small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
  26. small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
  27. full_train_dataset = tokenized_datasets["train"]
  28. full_eval_dataset = tokenized_datasets["test"]
  29.  
  30. config = AutoConfig.from_pretrained('bert-base-cased')
  31. model = TFAutoModelForSequenceClassification.from_config(config)
  32.  
  33. tf_train_dataset = small_train_dataset.remove_columns(["text"]).with_format("tensorflow")
  34. tf_eval_dataset = small_eval_dataset.remove_columns(["text"]).with_format("tensorflow")
  35.  
  36. bsize = 8
  37. train_features = {x: tf_train_dataset[x] for x in tokenizer.model_input_names}
  38. train_tf_dataset = tf.data.Dataset.from_tensor_slices((train_features, tf_train_dataset["label"]))
  39. train_tf_dataset = train_tf_dataset.shuffle(len(tf_train_dataset)).batch(bsize)
  40.  
  41. eval_features = {x: tf_eval_dataset[x] for x in tokenizer.model_input_names}
  42. eval_tf_dataset = tf.data.Dataset.from_tensor_slices((eval_features, tf_eval_dataset["label"]))
  43. eval_tf_dataset = eval_tf_dataset.batch(bsize)
  44.  
  45. model.compile(
  46.     optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
  47.     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  48.     metrics=tf.metrics.SparseCategoricalAccuracy(),
  49. )
  50.  
  51. model.fit(train_tf_dataset, validation_data=eval_tf_dataset, epochs=5)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement