GitHub Repository: huggingface/notebooks
Path: blob/main/sagemaker/26_document_ai_donut/scripts/inference.py
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"


def model_fn(model_dir):
    # Load the Donut processor and model from the model directory
    processor = DonutProcessor.from_pretrained(model_dir)
    model = VisionEncoderDecoderModel.from_pretrained(model_dir)

    # Move the model to the GPU if one is available
    model.to(device)

    return model, processor


def predict_fn(data, model_and_processor):
    # unpack model and processor
    model, processor = model_and_processor

    image = data.get("inputs")
    pixel_values = processor.feature_extractor(image, return_tensors="pt").pixel_values
    task_prompt = "<s>"  # start-of-sequence token for the decoder, since there is no user prompt
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # run inference
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # process output: decode the generated token sequence and convert it to JSON
    prediction = processor.batch_decode(outputs.sequences)[0]
    prediction = processor.token2json(prediction)

    return prediction
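
For local testing outside SageMaker, the two handlers can be exercised directly. The sketch below is an assumption, not part of the repository script: it presumes a hypothetical local directory "./donut-model" holding a fine-tuned Donut checkpoint and a sample image "sample_document.png", and mimics how the SageMaker Hugging Face inference toolkit would call model_fn and predict_fn.

# Local smoke test (a minimal sketch; paths and filenames are assumptions)
if __name__ == "__main__":
    from PIL import Image

    # The inference toolkit would normally call model_fn with the unpacked
    # model artifact directory; here we point it at a local checkpoint.
    model_and_processor = model_fn("./donut-model")

    # predict_fn expects the deserialized request payload: a dict whose
    # "inputs" key holds the document image.
    image = Image.open("sample_document.png").convert("RGB")
    prediction = predict_fn({"inputs": image}, model_and_processor)
    print(prediction)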