
GitHub Repository: huggingface/notebooks
Path: blob/main/sagemaker/17_custom_inference_script/sagemaker-notebook.ipynb
Kernel: conda_pytorch_p39

Sentence Embeddings with Hugging Face Transformers, Sentence Transformers and Amazon SageMaker - Custom Inference for creating document embeddings with Hugging Face's Transformers

Welcome to this getting started guide. We will use the Hugging Face Inference DLCs and the Amazon SageMaker Python SDK to create a real-time inference endpoint running a Sentence Transformers model for document embeddings. Currently, the SageMaker Hugging Face Inference Toolkit supports the pipeline feature from Transformers for zero-code deployment. This means you can run compatible Hugging Face Transformers models without providing pre- and post-processing code. We only need to provide the environment variables HF_TASK and HF_MODEL_ID when creating our endpoint, and the Inference Toolkit takes care of the rest. This is a great feature if you are working with existing pipelines.

If you want to run other tasks, such as creating document embeddings, you can provide the pre- and post-processing code yourself via an inference.py script. The Hugging Face Inference Toolkit allows you to override the default methods of the HuggingFaceHandlerService.

The custom module can override the following methods:

  • model_fn(model_dir) overrides the default method for loading a model. The return value model will be used in the predict_fn for predictions.

    • model_dir is the path to your unzipped model.tar.gz.

  • input_fn(input_data, content_type) overrides the default method for pre-processing. The return value data will be used in predict_fn for predictions. The inputs are:

    • input_data is the raw body of your request.

    • content_type is the content type from the request header.

  • predict_fn(processed_data, model) overrides the default method for predictions. The return value predictions will be used in output_fn. The inputs are:

    • model is the value returned by the model_fn method.

    • processed_data is the value returned by the input_fn method.

  • output_fn(prediction, accept) overrides the default method for post-processing. The return value result will be the response to your request (e.g. JSON). The inputs are:

    • prediction is the result from predict_fn.

    • accept is the requested return content type from the HTTP request (its Accept header), e.g. application/json.
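To make the data flow between the four functions concrete, here is a minimal, self-contained sketch. It is not the toolkit's code: handle_request is a hypothetical stand-in for what HuggingFaceHandlerService does, the "model" is a dummy that counts words, and only application/json is handled (the toolkit's defaults may accept more content types).

```python
import json

def model_fn(model_dir):
    # Hypothetical dummy "model": returns the number of words in each input string.
    # A real model_fn would load weights from model_dir (the unzipped model.tar.gz).
    return lambda texts: [len(t.split()) for t in texts]

def input_fn(input_data, content_type):
    # Pre-processing: decode the raw request body (JSON only in this sketch).
    if content_type != "application/json":
        raise ValueError(f"unsupported content type: {content_type}")
    return json.loads(input_data)

def predict_fn(processed_data, model):
    # Prediction: accept a single string or a list of strings under "inputs".
    inputs = processed_data["inputs"]
    if isinstance(inputs, str):
        inputs = [inputs]
    return {"word_counts": model(inputs)}

def output_fn(prediction, accept):
    # Post-processing: serialize the prediction for the response body.
    if accept != "application/json":
        raise ValueError(f"unsupported accept type: {accept}")
    return json.dumps(prediction)

def handle_request(body, content_type="application/json", accept="application/json"):
    # Hypothetical request handler chaining model_fn -> input_fn -> predict_fn -> output_fn.
    # (A real server would load the model once at startup, not on every request.)
    model = model_fn("/opt/ml/model")  # placeholder path, unused by the dummy model
    data = input_fn(body, content_type)
    prediction = predict_fn(data, model)
    return output_fn(prediction, accept)

print(handle_request('{"inputs": ["hello world", "one two three"]}'))  # -> {"word_counts": [2, 3]}
```

In the notebook's own inference.py, only model_fn and predict_fn are overridden, so the default input_fn and output_fn handle JSON decoding and encoding.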

In this example, we are going to use a Sentence Transformers model to create sentence embeddings by applying mean pooling to the raw token representations.
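Mean pooling averages the token vectors of a sentence, but padding tokens must be excluded or they distort the average. A minimal, framework-free sketch for a single sentence (no batch dimension), assuming plain Python lists; the helper name masked_mean_pooling is hypothetical, and the mean_pooling function in inference.py does the same with PyTorch tensors across a batch.

```python
def masked_mean_pooling(token_embeddings, attention_mask):
    # token_embeddings: one vector per token, e.g. [[0.1, 0.2], ...]
    # attention_mask: 1 for real tokens, 0 for padding tokens
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            summed = [s + v for s, v in zip(summed, vec)]
            count += 1
    # Guard against division by zero, like torch.clamp(..., min=1e-9) in inference.py
    count = max(count, 1e-9)
    return [s / count for s in summed]

# Two real tokens and one padding token with a large, meaningless embedding
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
print(masked_mean_pooling(tokens, mask))  # -> [2.0, 3.0]
```

Without the mask, the padding vector would pull the result to roughly [34.67, 35.33].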

NOTE: You can run this demo in SageMaker Studio, on your local machine, or on SageMaker Notebook Instances.

Development Environment and Permissions

Installation

%pip install sagemaker --upgrade

Install git and git-lfs

# For notebook instances (Amazon Linux)
!sudo yum update -y
!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash
!sudo yum install git-lfs git -y
# For other environments (Ubuntu)
!sudo apt-get update -y
!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
!sudo apt-get install git-lfs git -y

Permissions

If you are going to use SageMaker in a local environment (not SageMaker Studio or Notebook Instances), you need access to an IAM role with the required permissions for SageMaker. You can find more about it in the Amazon SageMaker documentation.

import sagemaker
import boto3

sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")
Couldn't call 'get_role' to get Role ARN from role name philippschmid to get Role path.
sagemaker role arn: arn:aws:iam::558105141721:role/sagemaker_execution_role
sagemaker bucket: sagemaker-us-east-1-558105141721
sagemaker session region: us-east-1

Create a custom inference.py script

To use the custom inference script, you need to create an inference.py script. In our example, we are going to override model_fn to load our Sentence Transformer correctly and predict_fn to apply mean pooling and normalize the resulting embeddings.

We are going to use the sentence-transformers/all-MiniLM-L6-v2 model. It maps sentences and paragraphs to a 384-dimensional dense vector space and can be used for tasks like clustering or semantic search.
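Embeddings become useful once they are compared. A minimal sketch, assuming plain Python lists stand in for the model's 384-dimensional outputs: because the inference script L2-normalizes each embedding, cosine similarity reduces to a dot product, which is enough to rank documents for semantic search. The helper names (l2_normalize, dot, rank_by_similarity) are hypothetical.

```python
import math

def l2_normalize(vec):
    # Scale a vector to unit length (what F.normalize(p=2) does per row)
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def rank_by_similarity(query_vec, doc_vecs):
    # For unit-length vectors, the dot product equals cosine similarity
    q = l2_normalize(query_vec)
    scores = [dot(q, l2_normalize(d)) for d in doc_vecs]
    order = sorted(range(len(doc_vecs)), key=lambda i: scores[i], reverse=True)
    return order, scores

# Toy 3-dimensional vectors standing in for 384-dimensional embeddings
query = [1.0, 0.0, 1.0]
docs = [[0.0, 1.0, 0.0], [2.0, 0.0, 2.0], [1.0, 1.0, 0.0]]
order, scores = rank_by_similarity(query, docs)
print(order)  # -> [1, 2, 0]
```

Since the endpoint already returns normalized vectors, the l2_normalize calls are redundant for its output; they are kept so the sketch also works on raw vectors.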

!mkdir code
%%writefile code/inference.py
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

# Helper: Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def model_fn(model_dir):
    # Load tokenizer and model from the unpacked model.tar.gz directory
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModel.from_pretrained(model_dir)
    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    # unpack model and tokenizer
    model, tokenizer = model_and_tokenizer

    # Tokenize sentences
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # return dictionary, which will be JSON serializable:
    # a single string yields one vector, a list of strings yields one vector per sentence
    if isinstance(sentences, str):
        return {"vectors": sentence_embeddings[0].tolist()}
    return {"vectors": sentence_embeddings.tolist()}
Overwriting code/inference.py

Create model.tar.gz with inference script and model

To use our inference.py, we need to bundle it into a model.tar.gz archive together with all our model artifacts, e.g. pytorch_model.bin. The inference.py script will be placed into a code/ folder. We will use git and git-lfs to easily download our model from hf.co/models, then upload the archive to Amazon S3 so we can use it when creating our SageMaker endpoint.

repository = "sentence-transformers/all-MiniLM-L6-v2"
model_id = repository.split("/")[-1]
s3_location = f"s3://{sess.default_bucket()}/custom_inference/{model_id}/model.tar.gz"
  1. Download the model from hf.co/models with git clone.

!git lfs install
!git clone https://huggingface.co/$repository
Updated git hooks.
Git LFS initialized.
Cloning into 'all-MiniLM-L6-v2'...
remote: Enumerating objects: 25, done.
remote: Counting objects: 100% (25/25), done.
remote: Compressing objects: 100% (23/23), done.
remote: Total 25 (delta 3), reused 0 (delta 0)
Unpacking objects: 100% (25/25), 308.60 KiB | 454.00 KiB/s, done.
  2. Copy inference.py into the code/ directory of the model directory.

!cp -r code/ $model_id/code/
  3. Create a model.tar.gz archive with all the model artifacts and the inference.py script.

%cd $model_id
!tar zcvf model.tar.gz *
/Users/philipp/.Trash/all-MiniLM-L6-v2/all-MiniLM-L6-v2
a 1_Pooling
a 1_Pooling/config.json
a README.md
a code
a code/inference.py
a config.json
a config_sentence_transformers.json
a data_config.json
a modules.json
a pytorch_model.bin
a sentence_bert_config.json
a special_tokens_map.json
a tokenizer.json
a tokenizer_config.json
a train_script.py
a vocab.txt
  4. Upload the model.tar.gz to Amazon S3:

!aws s3 cp model.tar.gz $s3_location
upload: ./model.tar.gz to s3://sagemaker-us-east-1-558105141721/custom_inference/all-MiniLM-L6-v2/model.tar.gz
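Running tar after %cd $model_id places the model files and code/inference.py at the root of model.tar.gz rather than under an all-MiniLM-L6-v2/ folder. As an alternative sketch (not part of the original workflow), the same layout can be produced and checked with Python's standard tarfile module; build_model_archive and check_archive_layout are hypothetical helper names, and the hidden-file filter is an assumption chosen to mirror how the shell glob * skips .git and .gitattributes.

```python
import os
import tarfile
import tempfile

def build_model_archive(model_dir, archive_path):
    archive_abs = os.path.abspath(archive_path)
    # Mirror the shell glob `*`: skip hidden entries such as .git and .gitattributes,
    # and never add the archive to itself.
    entries = sorted(
        name for name in os.listdir(model_dir)
        if not name.startswith(".")
        and os.path.abspath(os.path.join(model_dir, name)) != archive_abs
    )
    with tarfile.open(archive_path, "w:gz") as tar:
        for name in entries:
            # arcname=name keeps each entry at the archive root (config.json, code/inference.py, ...)
            tar.add(os.path.join(model_dir, name), arcname=name)

def check_archive_layout(archive_path):
    # True if code/inference.py and config.json sit at the archive root
    with tarfile.open(archive_path, "r:gz") as tar:
        names = set(tar.getnames())
    return {"code/inference.py", "config.json"} <= names

# Usage with placeholder files in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    model_dir = os.path.join(tmp, "all-MiniLM-L6-v2")
    os.makedirs(os.path.join(model_dir, "code"))
    for rel in ["config.json", "code/inference.py", ".gitattributes"]:
        with open(os.path.join(model_dir, rel), "w") as f:
            f.write("placeholder")
    archive = os.path.join(model_dir, "model.tar.gz")
    build_model_archive(model_dir, archive)
    print(check_archive_layout(archive))  # -> True
```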

Create a custom HuggingFaceModel

After we have created our model.tar.gz archive and uploaded it to Amazon S3, we can create a custom HuggingFaceModel class. This class will be used to create and deploy our SageMaker endpoint.

from sagemaker.huggingface.model import HuggingFaceModel

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    model_data=s3_location,       # path to your model and script
    role=role,                    # iam role with permissions to create an Endpoint
    transformers_version="4.26",  # transformers version used
    pytorch_version="1.13",       # pytorch version used
    py_version='py39',            # python version used
)

# deploy the endpoint
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge"
)
-----------!

Request the Inference Endpoint using the HuggingFacePredictor

The .deploy() method returns a HuggingFacePredictor object, which can be used to send inference requests.

data = {
    "inputs": "the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .",
}

res = predictor.predict(data=data)
print(res)
{'vectors': [0.005078191868960857, -0.0036594511475414038, 0.016988741233944893, -0.0015786211006343365, 0.030203675851225853, 0.09331899881362915, -0.0235157310962677, 0.011795195750892162, 0.03421774506568909, -0.027907833456993103, -0.03260169178247452, 0.0679800882935524, 0.015223750844597816, 0.025948498398065567, -0.07854384928941727, -0.0023915462661534548, 0.10089637339115143, 0.0014981384156271815, -0.017778029665350914, 0.005812637507915497, 0.02445339597761631, -0.0710371807217598, 0.04755859822034836, 0.026360979303717613, -0.05716250091791153, -0.0940014198422432, 0.047949012368917465, 0.008600219152867794, 0.03297032043337822, -0.06984368711709976, -0.0552142858505249, -0.03234352916479111, -0.0003443364112172276, 0.012479404918849468, -0.07419367134571075, 0.08545409888029099, 0.019597113132476807, 0.005851477384567261, -0.08256848156452179, 0.010150186717510223, 0.028275227174162865, -0.0016121627995744348, 0.04174523428082466, -0.009756717830896378, 0.03546829894185066, -0.0673336461186409, 0.013293622992932796, -0.047809384763240814, -0.02249010093510151, 0.028243854641914368, -0.08043544739484787, -0.01009676605463028, -0.03514788672327995, -0.021383730694651604, -0.002246067626401782, -0.015066167339682579, 0.04234122484922409, -0.040479838848114014, 0.00787312351167202, -0.04465996101498604, 0.010779906995594501, 0.0038497159257531166, -0.027719097211956978, -0.007967316545546055, 0.02942546270787716, -0.012327964417636395, 0.0050182887353003025, 0.06450540572404861, 0.03108026832342148, 0.042792391031980515, 0.023805316537618637, -0.01616135612130165, 0.02578461915254593, -0.08669176697731018, -0.044727668166160583, 7.097257184796035e-05, -0.10924965143203735, -0.10867254436016083, -0.03139006346464157, -0.03511088714003563, 0.08570166677236557, -0.134019672870636, -0.0005924605648033321, 0.029533952474594116, 0.012721308507025242, 0.02152288891375065, 0.0707324892282486, -0.11056605726480484, -0.1083742305636406, 0.0982309952378273, 
-0.039475709199905396, -0.05996376648545265, -0.10398901998996735, 0.03040657937526703, -0.03018292225897312, -0.03471128270030022, -0.06378458440303802, 0.016372960060834885, 0.0583597756922245, 0.012307470664381981, 0.04363206401467323, -0.031246762722730637, -0.09203378111124039, -0.0062785972841084, 0.015498220920562744, -0.07184164226055145, 0.012648160569369793, 0.014564670622348785, -0.08191244304180145, 0.023379981517791748, -0.011096887290477753, 0.0394676998257637, -0.033372823148965836, 0.041654154658317566, 0.0863155946135521, 0.015705395489931107, 0.01734650880098343, 0.08271384239196777, 0.022032614797353745, 0.03559378534555435, 0.12214990705251694, 0.032827410846948624, 0.026021108031272888, -0.019847815856337547, 0.010051277466118336, -0.04892867058515549, -0.0174998976290226, -1.4977462088666326e-33, -0.01998828910291195, -0.020090218633413315, 0.009214007295668125, 0.029388802126049995, 0.01617312990128994, 0.003455288475379348, -0.07258066534996033, 0.049684278666973114, -0.06154271960258484, 0.05080917105078697, 0.05352963134646416, -0.011941409669816494, -0.0028067785315215588, -0.041576843708753586, -0.010775507427752018, 0.00046661923988722265, 0.004454561043530703, 0.030003147199749947, -0.0516991950571537, -0.030697643756866455, -0.07532348483800888, 0.05465441197156906, -0.0385969914495945, -0.04381357878446579, -0.03235914930701256, 0.017494583502411842, 0.005240216851234436, 0.06198848783969879, -0.03355488181114197, 0.011264801025390625, -0.02115759812295437, 0.00838891975581646, -0.058978889137506485, -0.00011408641876187176, 0.05079993978142738, 0.015300493687391281, -0.07043343037366867, -0.07872467488050461, 0.09050456434488297, 0.03952907398343086, -0.07477521151304245, 0.03615942969918251, -0.058201417326927185, 0.0326484851539135, -0.03198658302426338, 0.11224830150604248, -0.016622459515929222, 0.0504615381360054, -0.04651995375752449, 0.1277347207069397, 0.03776664286851883, 0.05948572978377342, 0.09149560332298279, 
-0.009857898578047752, 0.004627745598554611, 0.03188807889819145, 0.062271688133478165, -0.0659433975815773, 0.0032127737067639828, -0.13898129761219025, 0.026403773576021194, 0.08804035186767578, -0.05001967027783394, 0.05326379835605621, -0.02196440100669861, 0.07656972110271454, 0.013867619447410107, -0.016544628888368607, -0.009327870793640614, 0.021883144974708557, -0.1560947597026825, -0.07534021139144897, -0.01896633207798004, 0.012034989893436432, -0.07331383228302002, -0.04332052916288376, -0.03353505954146385, 0.007872307673096657, 0.16191385686397552, -0.058967869728803635, 0.024201923981308937, 0.011731469072401524, -0.002475024200975895, -0.060298558324575424, -0.023722389712929726, -0.04882300645112991, 0.000707246595993638, -0.018090907484292984, 0.07239993661642075, 0.07933493703603745, 0.054174549877643585, -0.03342485427856445, -0.007864750921726227, 0.06494550406932831, -0.08771026879549026, 1.13459770849573e-33, 0.06040865182876587, 0.006845973432064056, -0.09519106149673462, -0.004926742985844612, 0.02894597128033638, -0.0077415574342012405, -0.05669841915369034, -0.034497782588005066, 0.09411472827196121, 0.0011957630049437284, -0.03672650456428528, 0.023257385939359665, -0.029259465634822845, -0.004881837405264378, -0.034621454775333405, -0.1123257502913475, 0.041878167539834976, 0.01935793086886406, 0.019774673506617546, 0.0033800536766648293, 0.04810955002903938, -0.043293364346027374, -0.019849350675940514, -0.024460462853312492, 0.011674574576318264, 0.028871286660432816, -0.04594291001558304, -0.009591681882739067, -0.020649896934628487, -0.0767439752817154, 0.06008455529808998, -0.07102784514427185, -0.03325150907039642, -0.07066744565963745, -0.07285013049840927, 0.06852841377258301, 0.032675426453351974, -0.015307767316699028, -0.03120141103863716, -0.0008060619584284723, -0.012935955077409744, 0.01687614619731903, 0.010606919415295124, 0.05316408351063728, -0.016209596768021584, 0.05059502646327019, -0.016619250178337097, 
-0.003106643445789814, -0.09400973469018936, 0.02362005040049553, -0.1493453085422516, 0.03363995999097824, -0.013002770021557808, -0.0411999374628067, -0.03762894868850708, 0.01735512912273407, -0.02544626034796238, -0.015723178163170815, 0.007998578250408173, 0.04340173304080963, 0.006307568401098251, -0.031614888459444046, -0.03868135064840317, -0.11168476939201355, 0.04688170179724693, 0.02938792295753956, 0.007106451783329248, -0.023254472762346268, 0.006188348866999149, 0.032097551971673965, 0.02284681424498558, -0.020912854000926018, -0.016115304082632065, 0.006232560612261295, -0.06727242469787598, 0.0027730280999094248, -0.04707656428217888, -0.03735049441456795, 0.026144297793507576, -0.013619091361761093, -0.005712081212550402, -0.04333459213376045, -0.008567489683628082, -0.0026371825952082872, -0.04714951291680336, 0.1506747603416443, 0.060538701713085175, 0.015910591930150986, 0.0021603393834084272, 0.09120813012123108, 0.10193410515785217, 0.04816991090774536, 0.07890739291906357, -0.05583663284778595, -0.02227107249200344, -2.478202887346015e-08, -0.08490563929080963, 0.04434036836028099, 0.02475418709218502, -0.024806825444102287, 0.00536795100197196, -0.06101489067077637, 0.014922979287803173, 0.04093354195356369, 0.03936637192964554, 0.04489367827773094, 0.012824231758713722, -0.03051156736910343, 0.0662570372223854, 0.04904399439692497, 0.004838698077946901, 0.07400422543287277, 0.03470872715115547, 0.037787146866321564, -0.043043263256549835, 0.04372495785355568, 0.023403732106089592, 0.057728372514247894, 0.034502316266298294, -0.049777042120695114, -0.0041667199693620205, 0.06382499635219574, -0.007370579522103071, -0.002130263252183795, -0.04700297489762306, 0.10623563826084137, -5.87037175137084e-05, -0.012606821022927761, 0.03633716702461243, 0.024944987148046494, -0.06500178575515747, 0.07670733332633972, 0.01752745360136032, 0.019638163968920708, 0.05920606851577759, 0.021030694246292114, 0.033589065074920654, 0.014452814124524593, 
0.030615368857979774, 0.13622330129146576, 0.0162414088845253, 0.07696809619665146, 0.10586545616388321, 0.06321518868207932, -0.06497083604335785, 0.0035124991554766893, 0.03836303576827049, -0.049263447523117065, -0.0939357802271843, 0.04310446232557297, 0.047002870589494705, 0.02352922037243843, 0.06475073844194412, 0.12606267631053925, -0.03936544433236122, 0.0033126939088106155, -0.005963532254099846, 0.01087606605142355, -0.006803632713854313, 0.05783495306968689]}
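A quick client-side sanity check on a response like the one above: for a single input string, the script returns one vector, which should have 384 entries for all-MiniLM-L6-v2 and unit L2 norm because predict_fn calls F.normalize. A minimal sketch; check_embedding_response and its tolerance are hypothetical choices.

```python
import math

def check_embedding_response(res, expected_dim=384, tol=1e-3):
    # res is the dict returned by predictor.predict(...) for a single string,
    # e.g. {"vectors": [0.0050, -0.0036, ...]}
    vec = res["vectors"]
    norm = math.sqrt(sum(x * x for x in vec))
    # Expect the model's embedding size and (approximately) unit length
    return len(vec) == expected_dim and abs(norm - 1.0) < tol
```

In the notebook, this would be called as check_embedding_response(res) after the prediction cell.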

Delete model and endpoint

To clean up, we can delete the model and endpoint.

predictor.delete_model()
predictor.delete_endpoint()