# from https://huggingface.co/transformers/training.html
import numpy as np
import tensorflow as tf

from datasets import load_dataset
from transformers import AutoConfig, AutoTokenizer, TFAutoModelForSequenceClassification

# Let TensorFlow allocate GPU memory on demand instead of reserving it all upfront.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before the GPUs have been initialized.
        print(e)
# Load the IMDB dataset and tokenize the review text with the BERT tokenizer.
raw_datasets = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

# 1000-example subsets for quick experiments; the full splits are kept for a complete run.
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
full_train_dataset = tokenized_datasets["train"]
full_eval_dataset = tokenized_datasets["test"]
# Build the classification model from the bert-base-cased config.
# Note: from_config initializes the weights randomly; no pretrained weights are loaded.
config = AutoConfig.from_pretrained('bert-base-cased')
model = TFAutoModelForSequenceClassification.from_config(config)
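# Alternative (sketch): to fine-tune starting from the pretrained BERT weights, as in the
# linked tutorial, the model would typically be loaded with from_pretrained instead:
# model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)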
# Convert the tokenized datasets into tf.data pipelines of (features, label) pairs.
tf_train_dataset = small_train_dataset.remove_columns(["text"]).with_format("tensorflow")
tf_eval_dataset = small_eval_dataset.remove_columns(["text"]).with_format("tensorflow")

bsize = 8
train_features = {x: tf_train_dataset[x] for x in tokenizer.model_input_names}
train_tf_dataset = tf.data.Dataset.from_tensor_slices((train_features, tf_train_dataset["label"]))
train_tf_dataset = train_tf_dataset.shuffle(len(tf_train_dataset)).batch(bsize)

eval_features = {x: tf_eval_dataset[x] for x in tokenizer.model_input_names}
eval_tf_dataset = tf.data.Dataset.from_tensor_slices((eval_features, tf_eval_dataset["label"]))
eval_tf_dataset = eval_tf_dataset.batch(bsize)

# Compile with sparse categorical cross-entropy on the raw logits, then train for 5 epochs.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=tf.metrics.SparseCategoricalAccuracy(),
)
model.fit(train_tf_dataset, validation_data=eval_tf_dataset, epochs=5)
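# Minimal evaluation sketch after training (not part of the original paste), assuming the
# variables above are in scope: run the trained model over the eval pipeline and compare
# argmax predictions against the labels. Keras' predict on a TF transformers model should
# return a dict-like output containing the logits.
preds = model.predict(eval_tf_dataset)["logits"]
pred_labels = np.argmax(preds, axis=1)
eval_accuracy = np.mean(pred_labels == np.array(tf_eval_dataset["label"]))
print(f"eval accuracy: {eval_accuracy:.3f}")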