Path: blob/main/sagemaker/01_getting_started_pytorch/scripts/train.py
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from datasets import load_from_disk
import logging
import sys
import argparse
import os

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--learning_rate", type=float, default=5e-5)

    # Data, model, and output directories (injected by SageMaker as environment variables)
    parser.add_argument("--output_data_dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])
    parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])

    args, _ = parser.parse_known_args()

    # Set up logging to stdout so messages appear in the SageMaker job logs
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load train and test datasets (saved with datasets.save_to_disk)
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(f"loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f"loaded test_dataset length is: {len(test_dataset)}")

    # Compute-metrics function for binary classification
    def compute_metrics(pred):
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
        acc = accuracy_score(labels, preds)
        return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

    # Download model and tokenizer from the Hugging Face model hub
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=args.model_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        warmup_steps=args.warmup_steps,
        evaluation_strategy="epoch",
        logging_dir=f"{args.output_data_dir}/logs",
        learning_rate=args.learning_rate,
    )

    # Create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    # Train the model
    trainer.train()

    # Evaluate the model on the test set
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # Write eval results to a file that SageMaker uploads to the S3 output location
    with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer:
        logger.info("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            writer.write(f"{key} = {value}\n")

    # Save the model so SageMaker uploads it to S3 at the end of the job
    trainer.save_model(args.model_dir)
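For context, this script is the entry point of a SageMaker training job: the hyperparameters and the SM_* environment variables it reads are injected by the estimator that launches it. Below is a minimal launch sketch using the SageMaker Python SDK's HuggingFace estimator; the role, instance type, version pins, hyperparameter values, and S3 paths are illustrative assumptions, not part of this file.

# Minimal launch sketch (assumed role/instance/versions/S3 paths; adapt to your account).
from sagemaker.huggingface import HuggingFace

# Passed through to train.py as --epochs, --train_batch_size, --model_name (example values)
hyperparameters = {
    "epochs": 1,
    "train_batch_size": 32,
    "model_name": "distilbert-base-uncased",
}

huggingface_estimator = HuggingFace(
    entry_point="train.py",
    source_dir="./scripts",
    instance_type="ml.p3.2xlarge",           # assumption: single-GPU instance
    instance_count=1,
    role="<your-sagemaker-execution-role>",  # placeholder
    transformers_version="4.6",
    pytorch_version="1.7",
    py_version="py36",
    hyperparameters=hyperparameters,
)

# Each key in fit() becomes an input channel; SageMaker stages the data on the
# instance and exposes its local path as SM_CHANNEL_<NAME>, e.g. "train" ->
# SM_CHANNEL_TRAIN, which is how args.training_dir and args.test_dir get populated.
huggingface_estimator.fit({
    "train": "s3://<bucket>/train",  # placeholder S3 URIs
    "test": "s3://<bucket>/test",
})

Note that the datasets under those S3 prefixes are expected to be already tokenized and saved with datasets.save_to_disk, since the script loads them with load_from_disk and passes them straight to Trainer.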