GitHub Repository: huggingface/notebooks
Path: blob/main/sagemaker/26_document_ai_donut/scripts/train.py
import os
import argparse
import shutil
import logging
import sys

import torch
from datasets import load_from_disk
from transformers import (
    set_seed,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DonutProcessor,
    VisionEncoderDecoderModel,
    VisionEncoderDecoderConfig,
)


def parse_args():
    """Parse the command line arguments."""
    parser = argparse.ArgumentParser()
    # model id and dataset path
    parser.add_argument(
        "--model_id",
        type=str,
        default="naver-clova-ix/donut-base",
        help="Model id to use for training.",
    )
    parser.add_argument(
        "--special_tokens",
        type=str,
        default=None,
        help="Comma-separated list of special tokens to add to the tokenizer.",
    )
    parser.add_argument("--dataset_path", type=str, default="lm_dataset", help="Path to the dataset.")
    # training hyperparameters: epochs, batch size, learning rate, and seed
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for.")
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Batch size to use for training.",
    )
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate to use for training.")
    parser.add_argument("--seed", type=int, default=42, help="Seed to use for training.")
    parser.add_argument(
        "--gradient_checkpointing",
        type=bool,
        default=False,
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--bf16",
        type=bool,
        # bf16 is supported on GPUs with compute capability 8.0 (Ampere) or newer
        default=torch.cuda.get_device_capability()[0] >= 8,
        help="Whether to use bf16 mixed precision.",
    )
    # parse_known_args ignores extra arguments injected by the training platform
    args = parser.parse_known_args()
    return args

def training_function(args):
    # set seed for reproducibility
    set_seed(args.seed)

    # set up logging to stdout
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # load dataset prepared with datasets.save_to_disk
    train_dataset = load_from_disk(args.dataset_path)
    image_size = list(torch.tensor(train_dataset[0]["pixel_values"][0]).shape)  # (height, width)
    logger.info(f"loaded train_dataset length is: {len(train_dataset)}")

    # load processor and add new special tokens to the tokenizer
    processor = DonutProcessor.from_pretrained(args.model_id)
    if args.special_tokens:
        special_tokens = args.special_tokens.split(",")
        processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
    # resize feature extractor to match the dataset images
    processor.feature_extractor.size = image_size[::-1]  # should be (width, height)
    processor.feature_extractor.do_align_long_axis = False

    # load model from huggingface.co
    config = VisionEncoderDecoderConfig.from_pretrained(
        args.model_id, use_cache=False if args.gradient_checkpointing else True
    )
    model = VisionEncoderDecoderModel.from_pretrained(args.model_id, config=config)

    # resize embedding layer to match vocabulary size & adjust image size and output sequence length
    model.decoder.resize_token_embeddings(len(processor.tokenizer))
    model.config.encoder.image_size = image_size
    model.config.decoder.max_length = len(max(train_dataset["labels"], key=len))
    # add pad token and task start token for the decoder
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<s>"])[0]

    # arguments for training
    output_dir = "/tmp"
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        per_device_train_batch_size=args.per_device_train_batch_size,
        bf16=args.bf16,
        tf32=True,
        gradient_checkpointing=args.gradient_checkpointing,
        logging_steps=10,
        save_total_limit=1,
        evaluation_strategy="no",
        save_strategy="epoch",
    )

    # create Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )

    # start training
    trainer.train()

    # save model and processor to the SageMaker model directory
    trainer.model.save_pretrained("/opt/ml/model/")
    processor.save_pretrained("/opt/ml/model/")

    # copy the inference script next to the model artifacts
    os.makedirs("/opt/ml/model/code", exist_ok=True)
    shutil.copyfile(
        os.path.join(os.path.dirname(__file__), "inference.py"),
        "/opt/ml/model/code/inference.py",
    )

def main():
    args, _ = parse_args()
    training_function(args)


if __name__ == "__main__":
    main()
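
Usage note (not part of train.py): the script follows SageMaker training conventions, with hyperparameters arriving as command line arguments and artifacts written to /opt/ml/model. Below is a minimal, hedged sketch of how such a script could be launched with the sagemaker Python SDK's HuggingFace estimator; the role and training_input_path values, the instance type, the framework versions, and the hyperparameter values are illustrative assumptions, not taken from this repository.

# Hedged launch sketch (assumptions: `role` is a valid SageMaker execution role ARN,
# `training_input_path` points to a dataset saved with datasets.save_to_disk and uploaded to S3,
# and the instance type / framework versions are only examples).
from sagemaker.huggingface import HuggingFace

role = "arn:aws:iam::123456789012:role/SageMakerExecutionRole"  # illustrative
training_input_path = "s3://my-bucket/donut-dataset"            # illustrative

huggingface_estimator = HuggingFace(
    entry_point="train.py",
    source_dir="./scripts",
    instance_type="ml.g5.2xlarge",   # illustrative Ampere GPU instance (bf16/tf32 capable)
    instance_count=1,
    role=role,
    transformers_version="4.26",     # illustrative framework versions
    pytorch_version="1.13",
    py_version="py39",
    hyperparameters={
        "model_id": "naver-clova-ix/donut-base",
        "dataset_path": "/opt/ml/input/data/training",  # where the 'training' channel is mounted
        "epochs": 3,
        "per_device_train_batch_size": 8,
        "lr": 4e-5,
        "special_tokens": "<s>,</s>",                   # illustrative task tokens
    },
)
huggingface_estimator.fit({"training": training_input_path})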