GitHub Repository: huggingface/notebooks
Path: blob/main/sagemaker/02_getting_started_tensorflow/scripts/train.py
import argparse
import logging
import os
import sys

import tensorflow as tf
from datasets import load_dataset
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification, DataCollatorWithPadding, create_optimizer


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=8)
    parser.add_argument("--model_id", type=str)
parser.add_argument("--learning_rate", type=str, default=3e-5)
21
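    # SM_OUTPUT_DATA_DIR, SM_MODEL_DIR, and SM_NUM_GPUS are environment variables that the
    # SageMaker training toolkit sets inside the training container.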
    # Data, model, and output directories
    parser.add_argument("--output_data_dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])

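    # SageMaker passes the hyperparameters configured on the estimator as additional command-line
    # arguments, so parse_known_args() is used to ignore anything not declared above.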
    args, _ = parser.parse_known_args()

    # Set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)

    # Load DatasetDict
    dataset = load_dataset("imdb")

    # Preprocess train dataset
    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True)

    encoded_dataset = dataset.map(preprocess_function, batched=True)

    # define tokenizer_columns
    # tokenizer_columns is the list of keys from the dataset that get passed to the TensorFlow model
    tokenizer_columns = ["attention_mask", "input_ids"]

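    # DataCollatorWithPadding pads each batch dynamically to the longest sequence in that batch,
    # which is why no fixed padding was applied during tokenization above.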
    # convert to TF datasets
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
    encoded_dataset["train"] = encoded_dataset["train"].rename_column("label", "labels")
    tf_train_dataset = encoded_dataset["train"].to_tf_dataset(
        columns=tokenizer_columns,
        label_cols=["labels"],
        shuffle=True,
        batch_size=args.train_batch_size,
        collate_fn=data_collator,
    )
    encoded_dataset["test"] = encoded_dataset["test"].rename_column("label", "labels")
    tf_validation_dataset = encoded_dataset["test"].to_tf_dataset(
        columns=tokenizer_columns,
        label_cols=["labels"],
        shuffle=False,
        batch_size=args.eval_batch_size,
        collate_fn=data_collator,
    )

    # Prepare model labels - useful in inference API
    labels = encoded_dataset["train"].features["labels"].names
    num_labels = len(labels)
    label2id, id2label = dict(), dict()
    for i, label in enumerate(labels):
        label2id[label] = str(i)
        id2label[str(i)] = label

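    # Loading with num_labels typically puts a newly initialized classification head on top of the
    # pretrained encoder; label2id/id2label are stored in the model config for the inference API.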
    # download model from model hub
    model = TFAutoModelForSequenceClassification.from_pretrained(
        args.model_id, num_labels=num_labels, label2id=label2id, id2label=id2label
    )

    # create Adam optimizer with learning rate scheduling
    batches_per_epoch = len(encoded_dataset["train"]) // args.train_batch_size
    total_train_steps = int(batches_per_epoch * args.epochs)

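    # create_optimizer returns an (optimizer, lr_schedule) pair; with num_warmup_steps=0 the
    # learning rate decays linearly from init_lr to zero over num_train_steps.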
    optimizer, _ = create_optimizer(init_lr=args.learning_rate, num_warmup_steps=0, num_train_steps=total_train_steps)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # define metric and compile model
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # Training
    logger.info("*** Train ***")
    train_results = model.fit(
        tf_train_dataset,
        epochs=args.epochs,
        validation_data=tf_validation_dataset,
    )

    output_eval_file = os.path.join(args.output_data_dir, "train_results.txt")

    with open(output_eval_file, "w") as writer:
        logger.info("***** Train results *****")
        logger.info(train_results)
        for key, value in train_results.history.items():
            logger.info(" %s = %s", key, value)
            writer.write("%s = %s\n" % (key, value))

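    # SageMaker packages everything written to SM_MODEL_DIR into model.tar.gz and uploads it to S3
    # when the training job completes.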
    # Save result
    model.save_pretrained(args.model_dir)
    tokenizer.save_pretrained(args.model_dir)