Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/sagemaker/14_train_and_push_to_hub/scripts/train.py
Views: 2548
"""Fine-tune a Hugging Face sequence-classification model on SageMaker.

Loads pre-processed train/test datasets from the SageMaker input channels,
trains with ``transformers.Trainer``, evaluates, optionally pushes the model
to the Hugging Face Hub, and saves the final model to ``SM_MODEL_DIR``.
"""
import argparse
import logging
import os
import random
import sys

import numpy as np
import torch
from datasets import load_from_disk, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint


def _str2bool(value):
    """Parse a boolean CLI flag.

    FIX: the original used ``type=bool``, which is a classic argparse trap —
    any non-empty string (including ``"False"``) is truthy, so the flag could
    never be turned off from the command line. This helper accepts the usual
    textual spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--model_id", type=str)
    # FIX: was type=str with a float default — CLI values and the default had
    # different types and needed a float() patch downstream. Parse as float.
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    # FIX: was type=bool (see _str2bool above).
    parser.add_argument("--fp16", type=_str2bool, default=True)

    # Push to Hub Parameters
    parser.add_argument("--push_to_hub", type=_str2bool, default=True)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_strategy", type=str, default=None)
    parser.add_argument("--hub_token", type=str, default=None)

    # Data, model, and output directories. The SM_* environment variables are
    # injected by SageMaker; a KeyError here means the script is running
    # outside a SageMaker training job.
    parser.add_argument("--output_data_dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--output_dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])
    parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])

    args, _ = parser.parse_known_args()

    # make sure we have required parameters to push
    if args.push_to_hub:
        if args.hub_strategy is None:
            raise ValueError("--hub_strategy is required when pushing to Hub")
        if args.hub_token is None:
            raise ValueError("--hub_token is required when pushing to Hub")

        # sets hub id if not provided ("/" is not allowed in a repo name)
        if args.hub_model_id is None:
            args.hub_model_id = args.model_id.replace("/", "--")

    # Set up logging to stdout (captured by CloudWatch on SageMaker).
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # load datasets (pre-processed and saved with Dataset.save_to_disk)
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")

    # define metrics and metrics function
    # NOTE(review): datasets.load_metric is deprecated in favor of the
    # `evaluate` library — kept here to avoid a new dependency.
    metric = load_metric("accuracy")

    def compute_metrics(eval_pred):
        """Compute accuracy from the (logits, labels) pair Trainer passes in."""
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return metric.compute(predictions=predictions, references=labels)

    # Prepare model labels - useful in inference API.
    # Ids are stored as strings, matching the Hub config.json convention.
    labels = train_dataset.features["labels"].names
    num_labels = len(labels)
    label2id = {label: str(i) for i, label in enumerate(labels)}
    id2label = {str(i): label for i, label in enumerate(labels)}

    # download model from model hub
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_id, num_labels=num_labels, label2id=label2id, id2label=id2label
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)

    # define training args
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        # Overwrite only when a previous checkpoint exists (i.e. resuming).
        overwrite_output_dir=get_last_checkpoint(args.output_dir) is not None,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        warmup_steps=args.warmup_steps,
        fp16=args.fp16,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        logging_dir=f"{args.output_data_dir}/logs",
        learning_rate=args.learning_rate,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        # push to hub parameters
        push_to_hub=args.push_to_hub,
        hub_strategy=args.hub_strategy,
        hub_model_id=args.hub_model_id,
        hub_token=args.hub_token,
    )

    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # save best model, metrics and create model card
    trainer.create_model_card(model_name=args.hub_model_id)
    trainer.push_to_hub()

    # Saves the model to s3 uses os.environ["SM_MODEL_DIR"] to make sure checkpointing works
    trainer.save_model(os.environ["SM_MODEL_DIR"])