"""
Title: Getting Started with KerasHub
Author: [Matthew Watson](https://github.com/mattdangerw/), [Jonathan Bischof](https://github.com/jbischof)
Date created: 2022/12/15
Last modified: 2024/10/17
Description: An introduction to the KerasHub API.
Accelerator: GPU
"""

"""
**KerasHub** is a pretrained modeling library that aims to be simple, flexible, and fast.
The library provides [Keras 3](https://keras.io/keras_3/) implementations of popular
model architectures, paired with a collection of pretrained checkpoints available on
[Kaggle](https://www.kaggle.com/organizations/keras/models). Models can be used for both
training and inference on any of the TensorFlow, Jax, and Torch backends.

KerasHub is an extension of the core Keras API; KerasHub components are provided as
`keras.Layer`s and `keras.Model`s. If you are familiar with Keras, congratulations! You
already understand most of KerasHub.

This guide is meant to be an accessible introduction to the entire library. We will start
by using high-level APIs to classify images and generate text, then progressively show
deeper customization of models and training. Throughout the guide, we use Professor Keras,
the official Keras mascot, as a visual reference for the complexity of the material:

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_evolution.png)

As always, we'll keep our Keras guides focused on real-world code examples. You can play
with the code here at any time by clicking the Colab link at the top of the guide.
"""
"""
## Installation and Setup
"""

"""
To begin, let's install keras-hub. The library is available on PyPI, so we can simply
install it with pip.
"""

"""shell
pip install --upgrade --quiet keras-hub keras
"""

"""
Keras 3 was built to work on top of the TensorFlow, Jax, and Torch backends. You should
specify the backend first thing when writing Keras code, before any library imports. We
will use the Jax backend for this guide, but you can use `torch` or `tensorflow` without
changing a single line in the rest of this guide. That's the power of Keras 3!

We will also set `XLA_PYTHON_CLIENT_MEM_FRACTION`, which tells Jax to claim the whole
GPU's memory from the start.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
"""
61
Lastly, we need to do some extra setup to access the models used in this guide. Many
62
popular open LLMs, such as Gemma from Google and Llama from Meta, require accepting
63
a community license before accessing the model weights. We will be using Gemma in this
64
guide, so we can follow the following steps:
65
66
1. Go to the [Gemma 2](https://www.kaggle.com/models/keras/gemma2) model page, and accept
67
the license at the banner at the top.
68
2. Generate an Kaggle API key by going to [Kaggle settings](https://www.kaggle.com/settings)
69
and clicking "Create New Token" button under the "API" section.
70
3. Inside your colab notebook, click on the key icon on the left hand toolbar. Add two
71
secrets: `KAGGLE_USERNAME` with your username, and `KAGGLE_KEY` with the API key you just
72
created. Make these secrets visible to the notebook you are running.
73
"""
74
75
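"""
If you are not running in Colab, you can instead provide the same credentials as
environment variables. This is a sketch with placeholder values, not real credentials:
"""

# Uncomment and fill in to authenticate with Kaggle outside of Colab.
# os.environ["KAGGLE_USERNAME"] = "your_username"
# os.environ["KAGGLE_KEY"] = "your_api_key"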
"""
76
## API Quickstart
77
78
Before we begin, let's take a look at the key classes we will use in the KerasHub library.
79
80
* **Task**: e.g., `keras_hub.models.CausalLM`, `keras_hub.models.ImageClassifier`, and
81
`keras_hub.models.TextClassifier`.
82
* **What it does**: A task maps from raw image, audio, and text inputs to model
83
predictions.
84
* **Why it's important**: A task is the highest-level entry point to the KerasHub API. It
85
encapsulates both preprocessing and modeling into a single, easy-to-use class. Tasks can
86
be used both for fine-tuning and inference.
87
* **Has a**: `backbone` and `preprocessor`.
88
* **Inherits from**: `keras.Model`.
89
* **Backbone**: `keras_hub.models.Backbone`.
90
* **What it does**: Maps preprocessed tensor inputs to the latent space of the model.
91
* **Why it's important**: The backbone encapsulates the architecture and parameters of a
92
pretrained models in a way that is unspecialized to any particular task. A backbone can
93
be combined with arbitrary preprocessing and "head" layers mapping dense features to
94
predictions to accomplish any ML task.
95
* **Inherits from**: `keras.Model`.
96
* **Preprocessor**: e.g.,`keras_hub.models.CausalLMPreprocessor`,
97
`keras_hub.models.ImageClassifierPreprocessor`, and
98
`keras_hub.models.TextClassifierPreprocessor`.
99
* **What it does**: A preprocessor maps from raw image, audio and text inputs to
100
preprocessed tensor inputs.
101
* **Why it's important**: A preprocessing layer encapsulates all tasks specific
102
preprocessing, e.g. image resizing and text tokenization, in a way that can be used
103
standalone to precompute preprocessed inputs. Note that if you are using a high-level
104
task class, this preprocessing is already baked in by default.
105
* **Has a**: `tokenizer`, `audio_converter`, and/or `image_converter`.
106
* **Inherits from**: `keras.layers.Layer`.
107
* **Tokenizer**: `keras_hub.tokenizers.Tokenizer`.
108
* **What it does**: Converts strings to sequences of token ids.
109
* **Why it's important**: The raw bytes of a string are an inefficient representation of
110
text input, so we first map string inputs to integer token ids. This class encapsulated
111
the mapping of strings to ints and the reverse (via the `detokenize()` method).
112
* **Inherits from**: `keras.layers.Layer`.
113
* **ImageConverter**: `keras_hub.layers.ImageConverter`.
114
* **What it does**: Resizes and rescales image input.
115
* **Why it's important**: Image models often need to normalize image inputs to a specific
116
range, or resizing inputs to a specific size. This class encapsulates the image-specific
117
preprocessing.
118
* **Inherits from**: `keras.layers.Layer`.
119
* **AudioConveter**: `keras_hub.layers.AudioConveter`.
120
* **What it does**: Converts raw audio to model ready input.
121
* **Why it's important**: Audio models often need to preprocess raw audio input before
122
passing it to a model, e.g. by computing a spectrogram of the audio signal. This class
123
encapsulates the image specific preprocessing in an easy to use layer.
124
* **Inherits from**: `keras.layers.Layer`.
125
126
All of the classes listed here have a `from_preset()` constructor, which will instantiate
127
the component with weights and state for the given pre-trained model identifier. E.g.
128
`keras_hub.tokenizers.Tokenizer.from_preset("gemma2_2b_en")` will create a layer that
129
tokenizes text using a Gemma2 tokenizer vocabulary.
130
131
The figure below shows how all these core classes interact. Arrow indicate composition
132
not inheritance (e.g., a task *has a* backbone).
133
134
![png](/img/guides/getting_started/class-diagram.png)
135
"""
136
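"""
To make the `from_preset()` pattern concrete before we dive in, here is a minimal sketch
using the tokenizer example from above. It assumes the Kaggle setup for Gemma described
earlier is complete; only the tokenizer assets are downloaded here, not the full model
weights:
"""

import keras_hub

# Load just the tokenizer component for a pretrained Gemma 2 checkpoint.
tokenizer = keras_hub.tokenizers.Tokenizer.from_preset("gemma2_2b_en")
token_ids = tokenizer.tokenize("Hello from KerasHub!")
print(token_ids)  # A sequence of integer token ids.
print(tokenizer.detokenize(token_ids))  # Back to the original string.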
"""
## Classify an image

![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_beginner.png)
"""

"""
Enough setup! Let's have some fun with pre-trained models. Let's load a test image of a
California Quail and classify it.
"""

import keras
import numpy as np
import matplotlib.pyplot as plt

image_url = "https://upload.wikimedia.org/wikipedia/commons/a/aa/California_quail.jpg"
image_path = keras.utils.get_file(origin=image_url)
image = keras.utils.load_img(image_path)
plt.imshow(image)
"""
158
We can use a ResNet vision model trained on the ImageNet-1k database. This model will
159
give each input sample and output label from `[0, 1000)`, where each label corresponds to
160
some real word entity, like a "milk can" or a "porcupine." The dataset actually has a
161
specific label for quail, at index 85. Let's download the model and predict a label.
162
"""
163
import keras_hub

image_classifier = keras_hub.models.ImageClassifier.from_preset(
    "resnet_50_imagenet",
    activation="softmax",
)
batch = np.array([image])
image_classifier.preprocessor.image_size = (224, 224)
preds = image_classifier.predict(batch)
preds.shape
"""
These ImageNet labels aren't particularly "human readable," so we can use a built-in
utility function to decode the predictions to a set of class names.
"""

keras_hub.utils.decode_imagenet_predictions(preds)
"""
183
Looking good! The model weights successfully downloaded, and we predicted the
184
correct classification label for our quail image with near certainty.
185
186
This was our first example of the high-level **task** API mentioned in the API quickstart
187
above. An `keras_hub.models.ImageClassifier` is a task for classifying images, and can be
188
used with a number of different model architectures (ResNet, VGG, MobileNet, etc). You
189
can view the full list of models shipped directly by the Keras team on
190
[Kaggle](https://www.kaggle.com/organizations/keras/models).
191
192
A task is just a subclass of `keras.Model` — you can use `fit()`, `compile()`, and
193
`save()` on our `classifier` object same as any other model. But tasks come with a few
194
extras provided by the KerasHub library. The first and most important is `from_preset()`,
195
a special constructor you will see on many classes in KerasHub.
196
197
A **preset** is a directory of model state. It defines both the architecture we should
198
load and the pretrained weights that go with it. `from_preset()` allows us to load
199
**preset** directories from a number of different locations:
200
201
- A local directory.
202
- The Kaggle Model hub.
203
- The HuggingFace model hub.
204
205
You can take a look at the `keras_hub.models.ImageClassifier.from_preset` docs to better
206
understand all the options when constructing a Keras model from a preset.
207
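"""
Each of the following strings would be a valid argument to an
`ImageClassifier.from_preset()` call. The Kaggle and Hugging Face handles below are
illustrative placeholders, not real uploads:
"""

# A built-in preset name, which resolves to a Kaggle-hosted checkpoint:
# keras_hub.models.ImageClassifier.from_preset("resnet_50_imagenet")
# An explicit Kaggle handle (placeholder values):
# keras_hub.models.ImageClassifier.from_preset("kaggle://user/model/keras/variant")
# A Hugging Face Hub handle (placeholder values):
# keras_hub.models.ImageClassifier.from_preset("hf://user/my_preset")
# A local directory containing a saved preset:
# keras_hub.models.ImageClassifier.from_preset("./path/to/preset")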
"""
All tasks use two main sub-objects: a `keras_hub.models.Backbone` and a
`keras_hub.layers.Preprocessor`. You might already be familiar with the term **backbone**
from computer vision, where it is often used to describe a feature extractor network that
maps images to a latent space. A KerasHub backbone is this concept generalized; we use it
to refer to any pretrained model without a task-specific head. That is, a KerasHub
backbone maps raw images, audio, and text (or a combination of these inputs) to a
pretrained model's latent space. We can then map this latent space to any number of
task-specific outputs, depending on what we are trying to do with the model.

A **preprocessor** is just a Keras layer that does all the preprocessing for a specific
task. In our case, preprocessing will resize our input image and rescale it to the
range `[0, 1]` using some ImageNet-specific mean and variance data. Let's call our
task's preprocessor and backbone in succession to see what happens to our input shape.
"""
print("Raw input shape:", batch.shape)
resized_batch = image_classifier.preprocessor(batch)
print("Preprocessed input shape:", resized_batch.shape)
hidden_states = image_classifier.backbone(resized_batch)
print("Latent space shape:", hidden_states.shape)
"""
230
Our raw image is rescaled to `(224, 224)` during preprocessing and finally
231
downscaled to a `(7, 7)` image of 2048 feature vectors — the latent space of the
232
ResNet model. Note that ResNet can actually handle images of arbitrary sizes,
233
though performance will eventually fall off if your image is very different
234
sized than the pretrained data. If you'd like to disable the resizing in the
235
preprocessing layer, you can run `image_classifier.preprocessor.image_size = None`.
236
237
If you are ever wondering the exact structure of the task you loaded, you can
238
use `model.summary()` same as any Keras model. The model summary for tasks will
239
included extra information on model preprocessing.
240
"""
241
242
image_classifier.summary()
243
244
"""
245
## Generate text with an LLM
246
247
![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_intermediate.png)
248
"""
249
250
"""
251
Next up, let's try working with and generating text. The task we can use when generating
252
text is `keras_hub.models.CausalLM` (where LM is short for **L**anguage **M**odel). Let's
253
download the 2 billion parameter Gemma 2 model and try it out.
254
255
Since this is about 100x larger model than the ResNet model we just downloaded, we need to be
256
a little more careful about our GPU memory usage. We can use a half-precision type to
257
load each parameter of our ~2.5 billion as a two-byte float instead of four. To do this
258
we can pass `dtype` to the `from_preset()` constructor. `from_preset()` will forward any
259
kwargs to the main constructor for the class, so you can pass kwargs that work on all
260
Keras layers like `dtype`, `trainable`, and `name`.
261
"""
262
263
causal_lm = keras_hub.models.CausalLM.from_preset(
264
"gemma2_instruct_2b_en",
265
dtype="bfloat16",
266
)
267
268
"""
269
The model we just loaded was an instruction-tuned version of Gemma, which means the model
270
was further fine-tuned for chat. We can take advantage of these capabilities as long as
271
we stick to the particular template for text used when training the model. These special
272
tokens vary per model and can be hard to track, the [Kaggle model
273
page](https://www.kaggle.com/models/keras/gemma2/) will contain details such as this.
274
275
`CausalLM` comes with an extra function called `generate()` which can be used generate
276
predicted tokens in a loop and decode them as a string.
277
"""
278
279
template = "<start_of_turn>user\n{question}<end_of_turn>\n<start_of_turn>model"
280
281
question = """Write a python program to generate the first 1000 prime numbers.
282
Just show the actual code."""
283
print(causal_lm.generate(template.format(question=question), max_length=512))
284
285
"""
286
Note that on the Jax and TensorFlow backends, this `generate()` function is compiled, so
287
the second time you call for the same `max_length`, it will actually be much faster.
288
KerasHub will use Jax and TensorFlow to compute an optimized version of the generation
289
computational graph that can be reused.
290
"""
291
292
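"""
If you want to see the effect of compilation yourself, a rough way is to time two
identical calls. This is a quick sketch (not part of the original workflow); exact
numbers will vary with your hardware:
"""

import time

start = time.time()
causal_lm.generate("Why is the sky blue?", max_length=64)
print(f"First call: {time.time() - start:.2f}s")  # Includes tracing and compilation.

start = time.time()
causal_lm.generate("Why is the ocean salty?", max_length=64)
print(f"Second call: {time.time() - start:.2f}s")  # Reuses the compiled generation graph.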
question = "Share a very simple brownie recipe."
print(causal_lm.generate(template.format(question=question), max_length=512))
"""
As with our image classifier, we can use the model summary to see the details of our task
setup, including preprocessing.
"""

causal_lm.summary()
"""
Our text preprocessing includes a tokenizer, which is how all KerasHub models handle
input text. Let's try using it directly to get a better sense of how it works. All
tokenizers include `tokenize()` and `detokenize()` methods, to map strings to integer
sequences and integer sequences to strings. Directly calling the layer with
`tokenizer(inputs)` is equivalent to calling `tokenizer.tokenize(inputs)`.
"""

tokenizer = causal_lm.preprocessor.tokenizer
token_ids = tokenizer.tokenize("The quick brown fox jumps over the lazy dog.")
print(token_ids)
string = tokenizer.detokenize(token_ids)
print(string)
"""
317
The `generate()` function for `CausalLM` models involved a sampling step. The Gemma model
318
will be called once for each token we want to generate, and return a probability
319
distribution over all tokens. This distribution is then sampled to choose the next token
320
in the sequence.
321
322
For Gemma models, we default to greedy sampling, meaning we simply pick the most likely
323
output from the model at each step. But we can actually control this process with an
324
extra `sampler` argument to the standard `compile` function on all Keras models. Let's
325
try it out.
326
"""
327
328
causal_lm.compile(
329
sampler=keras_hub.samplers.TopKSampler(k=10, temperature=2.0),
330
)
331
332
question = "Share a very simple brownie recipe."
333
print(causal_lm.generate(template.format(question=question), max_length=512))
334
335
"""
336
Here we used a Top-K sampler, meaning we will randomly sample the partial distribution formed
337
by looking at just the top 10 predicted tokens at each time step. We also pass a `temperature` of 2,
338
which flattens our predicted distribution before we sample.
339
340
The net effect is that we will explore our model's distribution much more broadly each
341
time we generate output. Generation will now be a random process, each time we re-run
342
generate we will get a different result. We can note that the results feel "looser" than
343
greedy search — more minor mistakes and a less consistent tone.
344
345
You can look at all the samplers Keras supports at [keras_hub.samplers](https://keras.io/api/keras_hub/samplers/).
346
347
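"""
To get back to deterministic output, we can re-compile with a different sampler. String
identifiers map to the samplers in `keras_hub.samplers`, so the following minimal sketch
reverts to greedy search:
"""

causal_lm.compile(sampler="greedy")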
"""
Let's free up the memory from our large Gemma model before we jump to the next section.
"""

del causal_lm
"""
353
## Fine-tune and publish an image classifier
354
355
![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_advanced.png)
356
"""
357
358
"""
359
Now that we've tried running inference for both images and text, let's try running
360
training. We will take our ResNet image classifier from earlier and fine-tune it on
361
simple cats vs dogs dataset. We can start by downloading and extracting the data.
362
"""
363
364
import pathlib
365
366
extract_dir = keras.utils.get_file(
367
"cats_vs_dogs",
368
"https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip",
369
extract=True,
370
)
371
data_dir = pathlib.Path(extract_dir) / "PetImages"
372
373
"""
374
When working with lots of real-world image data, corrupted images are a common occurrence.
375
Let's filter out badly-encoded images that do not feature the string "JFIF" in their
376
header.
377
"""
378
379
num_skipped = 0
380
381
for path in data_dir.rglob("*.jpg"):
382
with open(path, "rb") as file:
383
is_jfif = b"JFIF" in file.peek(10)
384
if not is_jfif:
385
num_skipped += 1
386
os.remove(path)
387
388
print(f"Deleted {num_skipped} images.")
389
390
"""
391
We can load the dataset with `keras.utils.image_dataset_from_directory`. One important
392
thing to note here is that the `train_ds` and `val_ds` will both be returned as
393
`tf.data.Dataset` objects, including on the `torch` and `jax` backends.
394
395
KerasHub will use [tf.data](https://www.tensorflow.org/guide/data) as the default API for
396
running multi-threaded preprocessing on the CPU. `tf.data` is a powerful API for training
397
input pipelines that can scale up to complex, multi-host training jobs easily. Using it
398
does not restrict your choice of backend, a `tf.data.Dataset` can be as an iterator of
399
regular numpy data and passed to `fit()` on any Keras backend.
400
"""
401
402
train_ds, val_ds = keras.utils.image_dataset_from_directory(
403
data_dir,
404
validation_split=0.2,
405
subset="both",
406
seed=1337,
407
image_size=(256, 256),
408
batch_size=32,
409
)
410
411
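"""
Since `train_ds` is a `tf.data.Dataset`, we can peek at a batch as plain numpy arrays,
independent of the Keras backend we chose. A quick sketch:
"""

# Grab a single batch of images and labels as numpy arrays.
images, labels = next(iter(train_ds.as_numpy_iterator()))
print(images.shape, labels.shape)  # (32, 256, 256, 3) (32,)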
"""
412
At its simplest, training our classifier could consist of simply calling `fit()` on our
413
model with our dataset. But to make this example a little more interesting, let's show
414
how to customize preprocessing within a task.
415
416
In the first example, we saw how, by default, the preprocessing for our ResNet model resized
417
and rescaled our input. This preprocessing can be customized when we create our model. We
418
can use Keras' image preprocessing layers to create a `keras.layers.Pipeline` that will
419
rescale, randomly flip, and randomly rotate our input images. These random image
420
augmentations will allow our smaller dataset to function as a larger, more varied one.
421
Let's try it out.
422
"""
423
424
preprocessor = keras.layers.Pipeline(
425
[
426
keras.layers.Rescaling(1.0 / 255),
427
keras.layers.RandomFlip("horizontal"),
428
keras.layers.RandomRotation(0.2),
429
]
430
)
431
432
"""
433
Now that we have created a new layer for preprocessing, we can simply pass it to the
434
`ImageClassifier` during the `from_preset()` constructor. We can also pass
435
`num_classes=2` to match our two labels for "cat" and "dog." When `num_classes` is
436
specified like this, our head weights for the model will be randomly initialized
437
instead of containing the weights for our 1000 class image classification.
438
"""
439
440
image_classifier = keras_hub.models.ImageClassifier.from_preset(
441
"resnet_50_imagenet",
442
activation="softmax",
443
num_classes=2,
444
preprocessor=preprocessor,
445
)
446
447
"""
448
Note that if you want to preprocess your input data outside of Keras, you can simply
449
pass `preprocessor=None` to the task `from_preset()` call. In this case, KerasHub will
450
apply no preprocessing at all, and you are free to preprocess your data with any library
451
or workflow before passing your data to `fit()`.
452
453
Next, we can compile our model for fine-tuning. A KerasHub task is just a regular
454
`keras.Model` with some extra functionality, so we can `compile()` as normal for a
455
classification task.
456
"""
457
458
image_classifier.compile(
459
optimizer=keras.optimizers.Adam(1e-4),
460
loss="sparse_categorical_crossentropy",
461
metrics=["accuracy"],
462
)
463
464
"""
465
With that, we can simply run `fit()`. The image classifier will automatically apply our
466
preprocessing to each batch when training the model.
467
"""
468
469
image_classifier.fit(
470
train_ds,
471
validation_data=val_ds,
472
epochs=3,
473
)
474
475
"""
476
After three epochs of data, we achieve 99% accuracy on our cats vs dogs
477
validation dataset. This is unsurprising, given that the ImageNet pretrained weights we began
478
with could already classify some breeds of cats and dogs individually.
479
480
Now that we have a fine-tuned model let's try saving it. You can create a new saved preset with a
481
fine-tuned model for any task simply by running `task.save_to_preset()`.
482
"""
483
484
image_classifier.save_to_preset("cats_vs_dogs")
485
486
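"""
The preset we just saved is simply a directory on disk, so we can load the fine-tuned
task right back from it. A quick sketch (`restored_classifier` is just an illustrative
name):
"""

# Reload the fine-tuned classifier from the local preset directory.
restored_classifier = keras_hub.models.ImageClassifier.from_preset("./cats_vs_dogs")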
"""
487
One of the most powerful features of KerasHub is the ability upload models to Kaggle or
488
Huggingface models hub and share them with others. `keras_hub.upload_preset` allows you
489
to upload a saved preset.
490
491
In this case, we will upload to Kaggle. We have already authenticated with Kaggle to,
492
download the Gemma model earlier. Running the following cell well upload a new model
493
to Kaggle.
494
"""
495
496
from google.colab import userdata
497
498
username = userdata.get("KAGGLE_USERNAME")
499
keras_hub.upload_preset(
500
f"kaggle://{username}/resnet/keras/cats_vs_dogs",
501
"cats_vs_dogs",
502
)
503
504
"""
505
Let's take a look at a test image from our dataset.
506
"""
507
508
image = keras.utils.load_img(data_dir / "Cat" / "6779.jpg")
509
plt.imshow(image)
510
511
"""
512
If we wait for a few minutes for our model upload to finish processing on the Kaggle
513
side, we can go ahead and download the model we just created and use it to classify this
514
test image.
515
"""
516
517
image_classifier = keras_hub.models.ImageClassifier.from_preset(
518
f"kaggle://{username}/resnet/keras/cats_vs_dogs",
519
)
520
print(image_classifier.predict(np.array([image])))
521
522
"""
523
Congratulations on uploading your first model with KerasHub! If you want to share your
524
work with others, you can go to the model link printed out when we uploaded the model, and
525
turn the model public in settings.
526
527
Let's delete this model to free up memory before we move on to our final example for this
528
guide.
529
"""
530
531
del image_classifier
532
533
"""
534
## Building a custom text classifier
535
536
![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_expert.png)
537
"""
538
539
"""
540
As a final example for this getting started guide, let's take a look at how we can build
541
custom models from lower-level Keras and KerasHub components. We will build a text
542
classifier to classify movie reviews in the IMDb dataset as either positive or negative.
543
544
Let's download the dataset.
545
"""
546
547
extract_dir = keras.utils.get_file(
548
"imdb_reviews",
549
origin="https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz",
550
extract=True,
551
)
552
data_dir = pathlib.Path(extract_dir) / "aclImdb"
553
554
"""
555
The IMDb dataset contrains a large amount of unlabeled movie reviews. We don't need those
556
here, we can simply delete them.
557
"""
558
559
import shutil
560
561
shutil.rmtree(data_dir / "train" / "unsup")
562
563
"""
564
Next up, we can load our data with `keras.utils.text_dataset_from_directory`. As with our
565
image dataset creation above, the returned datasets will be `tf.data.Dataset` objects.
566
"""
567
568
raw_train_ds = keras.utils.text_dataset_from_directory(
569
data_dir / "train",
570
batch_size=2,
571
)
572
raw_val_ds = keras.utils.text_dataset_from_directory(
573
data_dir / "test",
574
batch_size=2,
575
)
576
577
"""
578
KerasHub is designed to be a layered API. At the top-most level, tasks aim to make it
579
easy to quickly tackle a problem. We could keep using the task API here, and create a
580
`keras_hub.models.TextClassifer` for a text classification model like BERT, and fine-tune
581
it in 10 or so lines of code.
582
583
Instead, to make our final example a little more interesting, let's show how we can use
584
lower-level API components to do something that isn't directly baked in to the library.
585
We will take the Gemma 2 model we used earlier, which is usually used for generating text,
586
and modify it to output classification predictions.
587
588
A common approach for classifying with a generative model would keep using it in a generative
589
context, by prompting it with the review and a question (`"Is this review positive or negative?"`).
590
But making an actual classifier is more useful if you want an actual probability score associated
591
with your labels.
592
593
Instead of loading the Gemma 2 model through the `CausalLM` task, we can load two
594
lower-level components: a **backbone** and a **tokenizer**. Much like the task classes we have
595
used so far, `keras_hub.models.Backbone` and `keras_hub.tokenizers.Tokenizer` both have a
596
`from_preset()` constructor for loading pretrained models. If you are running this code,
597
you will note you don't have to wait for a download as we use the model a second time,
598
the weights files are cached locally the first time we use the model.
599
"""
600
601
tokenizer = keras_hub.tokenizers.Tokenizer.from_preset(
602
"gemma2_instruct_2b_en",
603
)
604
backbone = keras_hub.models.Backbone.from_preset(
605
"gemma2_instruct_2b_en",
606
)
607
608
"""
609
We saw what the tokenizer does in the second example of this guide. We can use it to map
610
from string inputs to token ids in a way that matches the pretrained weights of the Gemma
611
model.
612
613
The backbone will map from a sequence of token ids to a sequence of embedded tokens in
614
the latent space of the model. We can use this rich representation to build a classifier.
615
616
Let's start by defining a custom preprocessing routine. `keras_hub.layers` contains a
617
collection of modeling and preprocessing layers, included some layers for token
618
preprocessing. We can use `keras_hub.layers.StartEndPacker`, which will append a special
619
start token to the beginning of each review, a special end token to the end, and finally
620
truncate or pad each review to a fixed length.
621
622
If we combine this with our `tokenizer`, we can build a preprocessing function that will
623
output batches of token ids with shape `(batch_size, sequence_length)`. We should also
624
output a padding mask that marks which tokens are padding tokens, so we can later exclude
625
these positions from our Transformer's attention computation. Most Transformer backbones
626
in KerasNLP take in a `"padding_mask"` input.
627
"""
628
629
packer = keras_hub.layers.StartEndPacker(
    start_value=tokenizer.start_token_id,
    end_value=tokenizer.end_token_id,
    pad_value=tokenizer.pad_token_id,
    sequence_length=None,
)


def preprocess(x, y=None, sequence_length=256):
    x = tokenizer(x)
    x = packer(x, sequence_length=sequence_length)
    x = {
        "token_ids": x,
        "padding_mask": x != tokenizer.pad_token_id,
    }
    return keras.utils.pack_x_y_sample_weight(x, y)
"""
648
With our preprocessing defined, we can simply use `tf.data.Dataset.map` to apply our
649
preprocessing to our input data.
650
"""
651
652
train_ds = raw_train_ds.map(preprocess, num_parallel_calls=16)
653
val_ds = raw_val_ds.map(preprocess, num_parallel_calls=16)
654
next(iter(train_ds))
655
656
"""
657
Running fine-tuning on a 2.5 billion parameter model is quite expensive compared to the
658
image classifier we trained earlier, for the simple reason that this model is 100x the
659
size of ResNet! To speed things up a bit, let's reduce the size of our training data to a
660
tenth of the original size. Of course, this is leaving some performance on the table
661
compared to full training, but it will keep things running quickly for our guide.
662
"""
663
664
train_ds = train_ds.take(1000)
665
val_ds = val_ds.take(1000)
666
667
"""
668
Next, we need to attach a classification head to our backbone model. In general, text
669
transformer backbones will output a tensor with shape
670
`(batch_size, sequence_length, hidden_dim)`. The main thing we will need to
671
classify with this input is to pool on the sequence dimension so we have a single
672
feature vector per input example.
673
674
Since the Gemma model is a generative model, information only passed from left to right
675
in the sequence. The only token representation that can "see" the entire movie review
676
input is the final token in each review. We can write a simple pooling layer to do this —
677
we will simply grab the last non-padding position of each input sequence. There's no special
678
process to writing a layer like this, we can use Keras and `keras.ops` normally.
679
"""
680
681
from keras import ops
682
683
684
class LastTokenPooler(keras.layers.Layer):
685
def call(self, inputs, padding_mask):
686
end_positions = ops.sum(padding_mask, axis=1, keepdims=True) - 1
687
end_positions = ops.cast(end_positions, "int")[:, :, None]
688
outputs = ops.take_along_axis(inputs, end_positions, axis=1)
689
return ops.squeeze(outputs, axis=1)
690
691
692
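"""
As a quick sanity check on the shapes, we can call the pooler on some random data. This
is a small sketch with a made-up batch, sequence, and feature size:
"""

dummy_states = np.random.rand(2, 5, 8)  # (batch, sequence, hidden)
dummy_mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding
pooled = LastTokenPooler()(dummy_states, dummy_mask)
print(pooled.shape)  # (2, 8): one feature vector per example, from the last real token.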
"""
693
With this pooling layer, we are ready to write our Gemma classifier. All task and backbone
694
models in KerasHub are [functional](https://keras.io/guides/functional_api/) models, so
695
we can easily manipulate the model structure. We will call our backbone on our inputs, add
696
our new pooling layer, and finally add a small feedforward network with a `"relu"` activation
697
in the middle. Let's try it out.
698
"""
699
700
inputs = backbone.input
701
x = backbone(inputs)
702
x = LastTokenPooler(
703
name="pooler",
704
)(x, inputs["padding_mask"])
705
x = keras.layers.Dense(
706
2048,
707
activation="relu",
708
name="pooled_dense",
709
)(x)
710
x = keras.layers.Dropout(
711
0.1,
712
name="output_dropout",
713
)(x)
714
outputs = keras.layers.Dense(
715
2,
716
activation="softmax",
717
name="output_dense",
718
)(x)
719
text_classifier = keras.Model(inputs, outputs)
720
text_classifier.summary()
721
722
"""
723
Before we train, there is one last trick we should employ to make this code run on free
724
tier colab GPUs. We can see from our model summary our model takes up almost 10 gigabytes
725
of space. An optimizer will need to make multiple copies of each parameter during
726
training, taking the total space of our model during training close to 30 or 40
727
gigabytes.
728
729
This would OOM many GPUs. A useful trick we can employ is to enable LoRA on our
730
backbone. LoRA is an approach which freezes the entire model, and only trains a low-parameter
731
decomposition of large weight matrices. You can read more about LoRA in this [Keras
732
example](https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/).
733
Let's try enabling it and re-printing our summary.
734
"""
735
736
backbone.enable_lora(4)
737
text_classifier.summary()
738
739
"""
740
After enabling LoRA, our model goes from 10GB of traininable parameters to just 20MB.
741
That means the space used by optimizer variables will no longer be a concern.
742
743
With all that set up, we can compile and train our model as normal.
744
"""
745
746
text_classifier.compile(
747
optimizer=keras.optimizers.Adam(5e-5),
748
loss="sparse_categorical_crossentropy",
749
metrics=["accuracy"],
750
)
751
text_classifier.fit(
752
train_ds,
753
validation_data=val_ds,
754
)
755
756
"""
757
We are able to achieve over ~93% accuracy on the movie review sentiment
758
classification problem. This is not bad, given that we only used a 10th of our
759
original dataset to train.
760
761
Taken together, the `backbone` and `tokenizer` we created in this example
762
allowed us access the full power of pretrained Gemma checkpoints, without
763
restricting what we could do with them. This is a central aim of the KerasHub
764
API. Simple workflows should be easy, and as you go deeper, you gain access to a
765
deeply customizable set of building blocks.
766
"""
767
768
"""
769
## Going further
770
771
This is just scratching the surface of what you can do with the KerasHub.
772
773
This guide shows a few of the high-level tasks that we ship with the KerasHub library,
774
but there are many tasks we did not cover here. Try [generating images with Stable
775
Diffusion](https://keras.io/guides/keras_hub/stable_diffusion_3_in_keras_hub/), for
776
example.
777
778
The most significant advantage of KerasHub is it gives you the flexibility to combine pre-trained
779
building blocks with the full power of Keras 3. You can train large LLMs on TPUs with model
780
parallelism with the [keras.distribution](https://keras.io/guides/distribution/) API. You can
781
quantize models with Keras' [quatize
782
method](https://keras.io/examples/keras_recipes/float8_training_and_inference_with_transfo
783
rmer/). You can write custom training loops and even mix in direct Jax, Torch, or
784
Tensorflow calls.
785
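"""
As one concrete example of that flexibility, post-training quantization in Keras 3 is a
single call on a model. This is a sketch rather than part of this guide's workflow;
`"int8"` is one of the modes accepted by `keras.Model.quantize`:
"""

# Quantize a trained model's weights to int8 (illustrative; left commented out here).
# text_classifier.quantize("int8")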
"""
See [keras.io/keras_hub](https://keras.io/keras_hub/) for a full list of guides and
examples to continue digging into the library.
"""