"""
2
Title: Loading Hugging Face Transformers Checkpoints
3
Author: [Laxma Reddy Patlolla](https://github.com/laxmareddyp), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli)
4
Date created: 2025/06/17
5
Last modified: 2025/07/22
6
Description: How to load and run inference from KerasHub model checkpoints hosted on the HuggingFace Hub.
7
Accelerator: GPU
8
"""
9
10
"""
## Introduction

KerasHub has built-in converters for HuggingFace's `.safetensors` models.
Loading model weights from HuggingFace is therefore no more difficult than
using KerasHub's own presets.

### KerasHub built-in HuggingFace transformers converters

KerasHub simplifies the use of HuggingFace Transformers models through its
built-in converters. These converters automatically handle the process of translating
HuggingFace model checkpoints into a format that's compatible with the Keras ecosystem.
This means you can seamlessly load a wide variety of pretrained models from the HuggingFace
Hub directly into KerasHub with just a few lines of code.

Key advantages of using KerasHub converters:

- **Ease of Use**: Load HuggingFace models without manual conversion steps.
- **Broad Compatibility**: Access a vast range of models available on the HuggingFace Hub.
- **Seamless Integration**: Work with these models using familiar Keras APIs for training,
evaluation, and inference.

Fortunately, all of this happens behind the scenes, so you can focus on using
the models rather than managing the conversion process!

## Setup

Before you begin, make sure you have the necessary libraries installed.
You'll primarily need `keras` and `keras_hub`.

**Note:** Changing the backend after Keras has been imported might not work as expected,
so ensure `KERAS_BACKEND` is set at the beginning of your script. Similarly, when working
outside of Colab, you can use `os.environ["HF_TOKEN"] = "<YOUR_HF_TOKEN>"` to authenticate
to HuggingFace. When working in Google Colab, set your `HF_TOKEN` as a Colab secret instead.
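
For example, a minimal environment setup outside of Colab might look like the following.
This is an illustrative sketch rather than part of the guide's executed code, and the
token value is a placeholder:

```python
import os

# Must be set before `import keras`.
os.environ["KERAS_BACKEND"] = "jax"
# Only needed for gated or private models such as Gemma; "<YOUR_HF_TOKEN>" is a placeholder.
os.environ["HF_TOKEN"] = "<YOUR_HF_TOKEN>"
```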
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # "tensorflow" or "torch"

import keras
import keras_hub
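
# Optional sanity check (not part of the original guide): confirm which backend
# Keras picked up from the environment variable above.
print("Keras backend:", keras.config.backend())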

"""
### Changing precision

To perform inference and training on affordable hardware, you can adjust your
model’s precision by configuring it through `keras.config` as follows:
"""

import keras

keras.config.set_dtype_policy("bfloat16")

"""
## Loading a HuggingFace model

KerasHub allows you to easily load models from HuggingFace Transformers.
Here's an example of how to load a Gemma causal language model.
In this particular case, you will need to consent to Google's license on
HuggingFace before you can download the model weights.
"""

# This is not a Keras checkpoint; it is a Hugging Face Transformers checkpoint.
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")

"""
Let's try running some inference:
"""

gemma_lm.generate("I want to say", max_length=30)

"""
### Fine-tuning a Gemma Transformer checkpoint using the Keras `model.fit(...)` API

Once you have loaded HuggingFace weights, you can use the instantiated model
just like any other KerasHub model. For instance, you might fine-tune the model
on your own data like so:
"""

features = ["The quick brown fox jumped.", "I forgot my homework."]
gemma_lm.fit(x=features, batch_size=2)

"""
### Saving and uploading the new checkpoint

To store and share your fine-tuned model, KerasHub makes it easy to save or
upload it using standard methods. You can do this through familiar commands
such as:
"""

HF_USERNAME = "<YOUR_HF_USERNAME>"  # Provide your Hugging Face username.
gemma_lm.save_to_preset("./gemma-2b-finetuned")
keras_hub.upload_preset(f"hf://{HF_USERNAME}/gemma-2b-finetuned", "./gemma-2b-finetuned")

"""
By uploading your preset, you can then load it from anywhere using:
`loaded_model = keras_hub.models.GemmaCausalLM.from_preset("hf://YOUR_HF_USERNAME/gemma-2b-finetuned")`
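
Since `save_to_preset(...)` also leaves the preset on disk, you can reload the fine-tuned
model directly from the local directory without going through the Hub, for example with
`keras_hub.models.GemmaCausalLM.from_preset("./gemma-2b-finetuned")` (local-directory
loading is listed in the comparison table below).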

For a comprehensive, step-by-step guide on uploading your model, refer to the official KerasHub upload documentation.
You can find all the details here: [KerasHub Upload Guide](https://keras.io/keras_hub/guides/upload/)

By integrating HuggingFace Transformers, KerasHub significantly expands your access to pretrained models.
The Hugging Face Hub now hosts over 750k model checkpoints across various domains such as NLP,
Computer Vision, Audio, and more. Of these, approximately 400k models are currently compatible with KerasHub,
giving you access to a vast and diverse selection of state-of-the-art architectures for your projects.

With KerasHub, you can:

- **Tap into State-of-the-Art Models**: Easily experiment with the latest
architectures and pretrained weights from the research community and industry.
- **Reduce Development Time**: Leverage existing models instead of training from scratch,
saving significant time and computational resources.
- **Enhance Model Capabilities**: Find specialized models for a wide array of tasks,
from text generation and translation to image segmentation and object detection.

This seamless access empowers you to build more powerful and sophisticated AI applications with Keras.

## Use a wider range of frameworks

Keras 3, and by extension KerasHub, is designed for multi-framework compatibility.
This means you can run your models with different backend frameworks like JAX, TensorFlow, and PyTorch.
This flexibility allows you to:

- **Choose the Best Backend for Your Needs**: Select a backend based on performance characteristics,
hardware compatibility (e.g., TPUs with JAX), or existing team expertise (see the sketch after this list).
- **Interoperability**: More easily integrate KerasHub models into existing
workflows that might be built on TensorFlow or PyTorch.
- **Future-Proofing**: Adapt to evolving framework landscapes without
rewriting your core model logic.
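
For example, here is a minimal sketch (not part of this guide's executed code) of running
the same checkpoint on the PyTorch backend. It assumes a fresh Python process, since the
backend must be set before Keras is imported:

```python
import os

os.environ["KERAS_BACKEND"] = "torch"  # set before importing keras / keras_hub

import keras_hub

gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")
gemma_lm.generate("I want to say", max_length=30)
```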

## Run transformer models in JAX backend and on TPUs

To experiment with a model on the JAX backend, set Keras’s backend to JAX before
constructing the model, and make sure your environment is connected to a TPU runtime.
Keras will then automatically leverage JAX’s TPU support, allowing your model to train
efficiently on TPU hardware without further code changes.
"""

import os

# The backend must be set before Keras is imported (e.g., in a fresh runtime).
os.environ["KERAS_BACKEND"] = "jax"
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")

"""
## Additional Examples

### Generation

Here’s an example using Llama: loading a PyTorch Hugging Face Transformers checkpoint into KerasHub and running it on the JAX backend.
"""
import os

os.environ["KERAS_BACKEND"] = "jax"

from keras_hub.models import Llama3CausalLM

# Get the model
causal_lm = Llama3CausalLM.from_preset("hf://NousResearch/Hermes-2-Pro-Llama-3-8B")

prompts = [
    """<|im_start|>system
You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>
<|im_start|>user
Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>
<|im_start|>assistant""",
]

# Generate from the model
causal_lm.generate(prompts, max_length=30)[0]

"""
## Comparing to Transformers

In the following table, we have compiled a detailed comparison of HuggingFace's Transformers library with KerasHub:

| Feature | HF Transformers | KerasHub |
|---------|-----------------|----------|
| Frameworks supported | PyTorch | JAX, PyTorch, TensorFlow |
| Trainer | HF Trainer | Keras `model.fit(...)` — supports nearly all features such as distributed training, learning rate scheduling, optimizer selection, etc. |
| Tokenizers | `AutoTokenizer` | [KerasHub Tokenizers](https://keras.io/keras_hub/api/tokenizers/) |
| Autoclass | `auto` keyword | KerasHub automatically [detects task-specific classes](https://x.com/fchollet/status/1922719664859381922) |
| Model loading | `AutoModel.from_pretrained()` | `keras_hub.models.<Task>.from_preset()`<br><br>KerasHub uses task-specific classes (e.g., `CausalLM`, `Classifier`, `Backbone`) with a `from_preset()` method to load pretrained models, analogous to HuggingFace’s method.<br><br>Supports HF URLs, Kaggle URLs, and local directories |
| Model saving | `model.save_pretrained()`<br>`tokenizer.save_pretrained()` | `model.save_to_preset()` — saves the model (including tokenizer/preprocessor) into a local directory (preset). All components needed for reloading or uploading are saved. |
| Model uploading | Uploading weights to HF platform | [KerasHub Upload Guide](https://keras.io/keras_hub/guides/upload/)<br>[Keras on Hugging Face](https://huggingface.co/keras) |
| Weights file sharding | Supported | Large model weights are sharded for efficient upload/download |
| PEFT | Uses [HuggingFace PEFT](https://github.com/huggingface/peft) | Built-in LoRA support (see the sketch after this table):<br>`backbone.enable_lora(rank=n)`<br>`backbone.save_lora_weights(filepath)`<br>`backbone.load_lora_weights(filepath)` |
| Core model abstractions | `PreTrainedModel`, `AutoModel`, task-specific models | `Backbone`, `Preprocessor`, `Task` |
| Model configs | `PretrainedConfig`: Base class for model configurations | Configurations stored as multiple JSON files in preset directory: `config.json`, `preprocessor.json`, `task.json`, `tokenizer.json`, etc. |
| Preprocessing | Tokenizers/preprocessors often handled separately, then passed to the model | Built into task-specific models |
| Mixed precision training | Via training arguments | Keras global policy setting |
| Compatibility with SafeTensors | Default weights format | Of the 770k+ SafeTensors models on HF, those with a matching architecture in KerasHub can be loaded using `keras_hub.models.X.from_preset()` |
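
The PEFT row above mentions KerasHub's built-in LoRA support. As a minimal sketch (not
part of this guide's executed code; it assumes the `gemma_lm` task model loaded earlier
and a hypothetical output path), parameter-efficient fine-tuning might look like:

```python
# Enable LoRA on the backbone so only the low-rank adapter weights are trainable.
gemma_lm.backbone.enable_lora(rank=4)

# Fine-tune as usual; only the LoRA parameters are updated.
gemma_lm.fit(x=["The quick brown fox jumped."], batch_size=1)

# Save just the adapter weights, and load them back later.
gemma_lm.backbone.save_lora_weights("gemma.lora.h5")
gemma_lm.backbone.load_lora_weights("gemma.lora.h5")
```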

Go try loading other model weights! You can find more options on HuggingFace
and use them with `from_preset("hf://<namespace>/<model-name>")`.

Happy experimenting!
"""