Real-time collaboration for Jupyter Notebooks, Linux terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/sagemaker/09_image_classification_vision_transformer/scripts/train.py
Views: 2548
import argparse
import logging
import os
import random
import subprocess
import sys

import numpy as np
from datasets import load_from_disk, load_metric
from transformers import (
    Trainer,
    TrainingArguments,
    ViTFeatureExtractor,
    ViTForImageClassification,
    default_data_collator,
)

# Configure a global git identity as soon as the module is loaded.
# NOTE(review): presumably required so `trainer.push_to_hub` can create commits
# inside the training container — confirm.
# NOTE(review): the e-mail literal below looks like an e-mail-obfuscation
# placeholder left by a web scrape; confirm the intended address.
subprocess.run(["git", "config", "--global", "user.email", "[email protected]"], check=True)
subprocess.run(["git", "config", "--global", "user.name", "sagemaker"], check=True)


if __name__ == "__main__":
    """Fine-tune a ViT image classifier, evaluate it, save it, and optionally push it to the Hub."""

    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="/opt/ml/model")
    parser.add_argument("--extra_model_name", type=str, default="sagemaker")
    parser.add_argument("--dataset", type=str, default="cifar10")
    parser.add_argument("--task", type=str, default="image-classification")
    parser.add_argument("--use_auth_token", type=str, default="")

    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=32)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    # argparse converts the CLI string; the old type=str + later float() is no longer needed.
    parser.add_argument("--learning_rate", type=float, default=2e-5)

    # Default the data directories from the SM_CHANNEL_* environment variables.
    # .get() is used so that the script does not fail with KeyError before
    # parsing when those variables are unset but the paths are passed explicitly.
    parser.add_argument("--training_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--test_dir", type=str, default=os.environ.get("SM_CHANNEL_TEST"))

    args, _ = parser.parse_known_args()

    if args.training_dir is None or args.test_dir is None:
        parser.error(
            "--training_dir and --test_dir are required when "
            "SM_CHANNEL_TRAIN / SM_CHANNEL_TEST are not set"
        )

    # Set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # load datasets
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)
    label_feature = train_dataset.features["label"]
    label_names = label_feature.names
    num_classes = label_feature.num_classes

    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")

    metric_name = "accuracy"
    # Accuracy of the arg-max class prediction; valid for any number of classes.
    metric = load_metric(metric_name)

    def compute_metrics(eval_pred):
        """Return the metric computed on arg-max predictions vs. reference labels.

        ``eval_pred`` is unpacked as ``(predictions, labels)``; predictions are
        reduced with ``argmax`` over axis 1 before being compared to the labels.
        """
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return metric.compute(predictions=predictions, references=labels)

    # download model from model hub
    model = ViTForImageClassification.from_pretrained(args.model_name, num_labels=num_classes)

    # Replace the model's default label names with the dataset's class names.
    # Building both maps from the same enumeration keeps them mutually consistent.
    model.config.id2label = dict(enumerate(label_names))
    model.config.label2id = {name: idx for idx, name in enumerate(label_names)}

    # define training args
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,
        evaluation_strategy="epoch",
        logging_dir=f"{args.output_dir}/logs",
        learning_rate=args.learning_rate,
        load_best_model_at_end=True,
        metric_for_best_model=metric_name,
    )

    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=default_data_collator,
    )

    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # writes eval result to file which can be accessed later in s3 output;
    # the same key/value pairs are also logged so they show up in the job logs.
    with open(os.path.join(args.output_dir, "eval_results.txt"), "w", encoding="utf-8") as writer:
        logger.info("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            logger.info(f"{key} = {value}")
            writer.write(f"{key} = {value}\n")

    # Saves the model to s3
    trainer.save_model(args.output_dir)

    if args.use_auth_token:
        # Last path segment: "org/name" -> "name"; also safe for names without "/"
        # (the previous split("/")[1] raised IndexError for those).
        short_model_name = args.model_name.split("/")[-1]
        kwargs = {
            "finetuned_from": short_model_name,
            "tags": "image-classification",
            "dataset": args.dataset,
        }
        repo_name = f"{short_model_name}-{args.task}"
        if args.extra_model_name != "":
            repo_name = f"{repo_name}-{args.extra_model_name}"

        trainer.push_to_hub(
            repo_name=repo_name,
            use_auth_token=args.use_auth_token,
            **kwargs,
        )