"""
Title: Loading Hugging Face Transformers Checkpoints
Author: [Laxma Reddy Patlolla](https://github.com/laxmareddyp), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli)
Date created: 2025/06/17
Last modified: 2025/07/22
Description: How to load and run inference from KerasHub model checkpoints hosted on the HuggingFace Hub.
Accelerator: GPU
"""

"""
## Introduction

KerasHub has built-in converters for HuggingFace's `.safetensors` models.
Loading model weights from HuggingFace is therefore no more difficult than
using KerasHub's own presets.

### KerasHub built-in HuggingFace transformers converters

KerasHub simplifies the use of HuggingFace Transformers models through its
built-in converters. These converters automatically handle the process of translating
HuggingFace model checkpoints into a format that is compatible with the Keras ecosystem.
This means you can seamlessly load a wide variety of pretrained models from the HuggingFace
Hub directly into KerasHub with just a few lines of code.

Key advantages of using KerasHub converters:

- **Ease of Use**: Load HuggingFace models without manual conversion steps.
- **Broad Compatibility**: Access a vast range of models available on the HuggingFace Hub.
- **Seamless Integration**: Work with these models using familiar Keras APIs for training,
evaluation, and inference.

Fortunately, all of this happens behind the scenes, so you can focus on using
the models rather than managing the conversion process!

## Setup

Before you begin, make sure you have the necessary libraries installed.
You'll primarily need `keras` and `keras_hub`.

**Note:** Changing the backend after Keras has been imported might not work as expected.
Ensure `KERAS_BACKEND` is set at the beginning of your script. Similarly, when working
outside of Colab, you might use `os.environ["HF_TOKEN"] = "<YOUR_HF_TOKEN>"` to authenticate
with HuggingFace. When working in Google Colab, set your `HF_TOKEN` as a Colab secret instead.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # "tensorflow" or "torch"

import keras
import keras_hub

"""
### Changing precision

To perform inference and training on affordable hardware, you can adjust your
model's precision by configuring it through `keras.config` as follows:
"""

keras.config.set_dtype_policy("bfloat16")

"""
## Loading a HuggingFace model

KerasHub allows you to easily load models from HuggingFace Transformers.
Here's an example of how to load a Gemma causal language model.
In this particular case, you will need to consent to Google's license on
HuggingFace to be able to download the model weights.
"""

# This is not a Keras checkpoint; it is a HuggingFace Transformers checkpoint.
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")

"""
Let's try running some inference:
"""

gemma_lm.generate("I want to say", max_length=30)

"""
### Fine-tuning a Gemma Transformer checkpoint using the Keras `model.fit(...)` API

Once you have loaded HuggingFace weights, you can use the instantiated model
just like any other KerasHub model. For instance, you might fine-tune the model
on your own data like so:
"""

features = ["The quick brown fox jumped.", "I forgot my homework."]
gemma_lm.fit(x=features, batch_size=2)
"""
### Saving and uploading the new checkpoint

To store and share your fine-tuned model, KerasHub makes it easy to save or
upload it using standard methods. You can do this through familiar commands
such as:
"""

HF_USERNAME = "<YOUR_HF_USERNAME>"  # provide your Hugging Face username
gemma_lm.save_to_preset("./gemma-2b-finetuned")
keras_hub.upload_preset(f"hf://{HF_USERNAME}/gemma-2b-finetuned", "./gemma-2b-finetuned")

"""
By uploading your preset, you can then load it from anywhere using:
`loaded_model = keras_hub.models.GemmaCausalLM.from_preset("hf://YOUR_HF_USERNAME/gemma-2b-finetuned")`

For a comprehensive, step-by-step guide on uploading your model, refer to the official KerasHub upload documentation.
You can find all the details here: [KerasHub Upload Guide](https://keras.io/keras_hub/guides/upload/)

By integrating HuggingFace Transformers, KerasHub significantly expands your access to pretrained models.
The Hugging Face Hub now hosts more than 750,000 model checkpoints across domains such as NLP,
computer vision, audio, and more. Of these, approximately 400,000 models are currently compatible with KerasHub,
giving you access to a vast and diverse selection of state-of-the-art architectures for your projects.

With KerasHub, you can:

- **Tap into State-of-the-Art Models**: Easily experiment with the latest
architectures and pretrained weights from the research community and industry.
- **Reduce Development Time**: Leverage existing models instead of training from scratch,
saving significant time and computational resources.
- **Enhance Model Capabilities**: Find specialized models for a wide array of tasks,
from text generation and translation to image segmentation and object detection.

This seamless access empowers you to build more powerful and sophisticated AI applications with Keras.

## Use a wider range of frameworks

Keras 3, and by extension KerasHub, is designed for multi-framework compatibility.
This means you can run your models with different backend frameworks like JAX, TensorFlow, and PyTorch.
This flexibility allows you to:

- **Choose the Best Backend for Your Needs**: Select a backend based on performance characteristics,
hardware compatibility (e.g., TPUs with JAX), or existing team expertise.
- **Interoperability**: More easily integrate KerasHub models into existing
workflows that might be built on TensorFlow or PyTorch.
- **Future-Proofing**: Adapt to evolving framework landscapes without
rewriting your core model logic.

## Run transformer models on the JAX backend and on TPUs

To experiment with a model on JAX, set the Keras backend to JAX before constructing
the model and make sure your environment is connected to a TPU runtime. Keras will then
automatically leverage JAX's TPU support, allowing your model to train efficiently on
TPU hardware without further code changes.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")

"""
## Additional Examples

### Generation

Here's an example using Llama: loading a PyTorch Hugging Face Transformers checkpoint into
KerasHub and running it on the JAX backend.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

from keras_hub.models import Llama3CausalLM

# Get the model
causal_lm = Llama3CausalLM.from_preset("hf://NousResearch/Hermes-2-Pro-Llama-3-8B")

prompts = [
    """<|im_start|>system
You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>
<|im_start|>user
Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>
<|im_start|>assistant""",
]

# Generate from the model
causal_lm.generate(prompts, max_length=30)[0]
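"""
Generation is also configurable. For instance, the sampling strategy used by `generate()`
can be set through `compile()`. The snippet below is a minimal sketch that swaps in top-k
sampling on the same `causal_lm` instance; the `k` value and prompt are illustrative.
"""

# Swap the sampler used at generation time. `TopKSampler` is one of the
# built-in samplers available under `keras_hub.samplers`.
causal_lm.compile(sampler=keras_hub.samplers.TopKSampler(k=50))
causal_lm.generate("The planets of the solar system are", max_length=30)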
"""
## Comparing to Transformers

In the following table, we have compiled a detailed comparison of HuggingFace's Transformers library with KerasHub:

| Feature | HF Transformers | KerasHub |
|---------|-----------------|----------|
| Frameworks supported | PyTorch | JAX, PyTorch, TensorFlow |
| Trainer | HF Trainer | Keras `model.fit(...)` — supports nearly all features such as distributed training, learning rate scheduling, optimizer selection, etc. |
| Tokenizers | `AutoTokenizer` | [KerasHub Tokenizers](https://keras.io/keras_hub/api/tokenizers/) |
| Autoclass | `auto` keyword | KerasHub automatically [detects task-specific classes](https://x.com/fchollet/status/1922719664859381922) |
| Model loading | `AutoModel.from_pretrained()` | `keras_hub.models.<Task>.from_preset()`<br><br>KerasHub uses task-specific classes (e.g., `CausalLM`, `Classifier`, `Backbone`) with a `from_preset()` method to load pretrained models, analogous to HuggingFace's method.<br><br>Supports HF URLs, Kaggle URLs, and local directories |
| Model saving | `model.save_pretrained()`<br>`tokenizer.save_pretrained()` | `model.save_to_preset()` — saves the model (including tokenizer/preprocessor) into a local directory (preset). All components needed for reloading or uploading are saved. |
| Model uploading | `model.push_to_hub()` to the HuggingFace Hub | [KerasHub Upload Guide](https://keras.io/keras_hub/guides/upload/)<br>[Keras on Hugging Face](https://huggingface.co/keras) |
| Weights file sharding | Supported (large checkpoints are sharded into multiple files by default) | Large model weights are sharded for efficient upload/download |
| PEFT | Uses [HuggingFace PEFT](https://github.com/huggingface/peft) | Built-in LoRA support:<br>`backbone.enable_lora(rank=n)`<br>`backbone.save_lora_weights(filepath)`<br>`backbone.load_lora_weights(filepath)` |
| Core model abstractions | `PreTrainedModel`, `AutoModel`, task-specific models | `Backbone`, `Preprocessor`, `Task` |
| Model configs | `PretrainedConfig`: Base class for model configurations | Configurations stored as multiple JSON files in the preset directory: `config.json`, `preprocessor.json`, `task.json`, `tokenizer.json`, etc. |
| Preprocessing | Tokenizers/preprocessors often handled separately, then passed to the model | Built into task-specific models |
| Mixed precision training | Via training arguments | Keras global policy setting |
| Compatibility with SafeTensors | Default weights format | Of the 770k+ SafeTensors models on HF, those with a matching architecture in KerasHub can be loaded using `keras_hub.models.X.from_preset()` |

Go try loading other model weights! You can find more options on HuggingFace
and use them with `from_preset("hf://<namespace>/<model-name>")`.

Happy experimenting!
"""
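"""
As one last illustrative sketch: the Mistral architecture also has a KerasHub
implementation, so a checkpoint such as `mistralai/Mistral-7B-Instruct-v0.3` could be
loaded the same way. The checkpoint choice here is an assumption for illustration; you
may first need to accept the model's terms on HuggingFace, and the usual memory
requirements for a 7B-parameter model apply.
"""

# Hypothetical example: load a Mistral SafeTensors checkpoint from the HuggingFace Hub
# through the same `from_preset()` converter path used above.
mistral_lm = keras_hub.models.MistralCausalLM.from_preset(
    "hf://mistralai/Mistral-7B-Instruct-v0.3"
)
mistral_lm.generate("The three primary colors are", max_length=30)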