"""
Title: Function Calling with KerasHub models
Author: [Laxmareddy Patlolla](https://github.com/laxmareddyp), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli)
Date created: 2025/07/08
Last modified: 2025/07/10
Description: A guide to using the function calling feature in KerasHub with Gemma 3 and Mistral.
Accelerator: GPU
"""

"""
## Introduction

Tool calling (also called function calling) is a capability of modern large language models that lets them use external tools, such as Python functions, to answer questions and perform actions. Instead of only generating text, a tool-calling model can emit a call to a function you've provided, which your application then executes. This lets the model work with live data, act on external systems, and hand off exact computations.

In this guide, we'll walk through simple tool-calling examples with the Gemma 3 and Mistral models in KerasHub. We'll show you how to:

1. Define a tool (a Python function).
2. Tell the model about the tool.
3. Use the model to generate a call to the tool.
4. Execute the call and feed the result back to the model.
5. Get a final, natural-language response from the model.

Let's get started!
"""

"""
## Setup

First, let's import the necessary libraries and configure our environment.
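
Before loading any real model, the five steps from the introduction can be sketched end to end in plain Python. In this minimal sketch, `fake_generate` is a hypothetical stand-in for `gemma.generate()` that returns canned replies, and the prompt format is deliberately simplified. It only illustrates the control flow, not Gemma's actual prompt template:

```python
# Triple backtick, built indirectly so that no line of this block starts with it.
FENCE = "`" * 3
RATE = 0.85  # toy fixed USD -> EUR rate (same value as USD_TO_EUR_RATE)


def convert(amount, currency, new_currency):
    # Step 1: the tool is an ordinary Python function.
    if (currency, new_currency) == ("USD", "EUR"):
        return amount * RATE
    raise NotImplementedError(f"{currency} -> {new_currency}")


def fake_generate(prompt):
    # Hypothetical stand-in for gemma.generate(): canned replies only.
    if "tool_output" in prompt:
        # Echo the value that was fed back inside the tool_output block.
        value = prompt.rsplit("tool_output\n", 1)[1].split("\n")[0]
        return f"That is {value} EUR."
    return FENCE + "tool_code\nconvert(200, 'USD', 'EUR')\n" + FENCE


def answer(question):
    # Step 2: describe the tool to the model in the prompt (simplified format).
    prompt = "Tools: convert(amount, currency, new_currency)\nUser: " + question
    # Step 3: the model replies with a fenced tool call.
    reply = fake_generate(prompt)
    marker = FENCE + "tool_code"
    start = reply.find(marker)
    if start == -1:
        return reply  # no tool call, so the reply already is the answer
    end = reply.find(FENCE, start + len(marker))
    call = reply[start + len(marker) : end].strip()
    # Step 4: run the call with only `convert` in scope (NOT a real sandbox)
    # and feed the result back as a tool_output block.
    result = eval(call, {"__builtins__": {}}, {"convert": convert})
    prompt += "\n" + reply + "\n" + FENCE + "tool_output\n" + str(result) + "\n" + FENCE
    # Step 5: call the model again to get the final natural-language answer.
    return fake_generate(prompt)


print(answer("What is $200 in EUR?"))
```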
We'll use KerasHub to download and run the language models. The model weights are hosted on Kaggle, so you'll need to authenticate with Kaggle to access them.
"""

import os
import json
import random
import string
import re
import ast
import io
import sys
import contextlib

# Set the backend before importing Keras
os.environ["KERAS_BACKEND"] = "jax"

import keras
import keras_hub
import kagglehub
import numpy as np

# Constants
USD_TO_EUR_RATE = 0.85

# Set the default dtype policy to bfloat16 for improved performance and reduced
# memory usage on supported hardware (e.g., TPUs, some GPUs)
keras.config.set_dtype_policy("bfloat16")

# Authenticate with Kaggle
# In Google Colab, you can set KAGGLE_USERNAME and KAGGLE_KEY as secrets,
# and kagglehub.login() will automatically detect and use them:
# kagglehub.login()

"""
## Loading the Model

Next, we'll load the Gemma 3 model from KerasHub. We're using the `gemma3_instruct_4b` preset, the instruction-tuned variant of the 4-billion-parameter model. Gemma 3 does not rely on dedicated tool-calling tokens, so the tool and the expected call format are described directly in the prompt.
"""

try:
    gemma = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_4b")
    print("✅ Gemma 3 model loaded successfully")
except Exception as e:
    print(f"❌ Error loading Gemma 3 model: {e}")
    print("Please ensure you have the correct model preset and sufficient resources.")
    raise

"""
## Defining a Tool

Now, let's define a simple tool that we want our model to be able to use.
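
A tool is just an ordinary Python callable: the model never runs it directly, your application does. A minimal sketch, assuming a hypothetical `TOOL_REGISTRY` dict and `call_tool` helper (neither is part of KerasHub), of keeping tools in a name-to-function mapping so that only whitelisted names can ever be invoked:

```python
RATE = 0.85  # toy fixed USD -> EUR rate


def convert(amount, currency, new_currency):
    # Plain Python function: nothing model-specific about it.
    rates = {("USD", "EUR"): RATE, ("EUR", "USD"): 1 / RATE}
    key = (currency.upper(), new_currency.upper())
    if key not in rates:
        raise NotImplementedError(f"{key[0]} -> {key[1]} is not supported")
    return amount * rates[key]


# Hypothetical registry: the only names the model is allowed to call.
TOOL_REGISTRY = {"convert": convert}


def call_tool(name, arguments):
    # Reject tool names the model invents instead of running arbitrary code.
    if name not in TOOL_REGISTRY:
        raise KeyError(f"Unknown tool: {name!r}")
    return TOOL_REGISTRY[name](**arguments)


print(call_tool("convert", {"amount": 100, "currency": "usd", "new_currency": "eur"}))
```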
For this example, we'll create a Python function called `convert` that converts an amount from one currency to another.
"""


def convert(amount, currency, new_currency):
    """Convert the currency with the latest exchange rate

    Args:
        amount: The amount of currency to convert
        currency: The currency to convert from
        new_currency: The currency to convert to
    """
    # Input validation
    if amount < 0:
        raise ValueError("Amount cannot be negative")

    if not isinstance(currency, str) or not isinstance(new_currency, str):
        raise ValueError("Currency codes must be strings")

    # Normalize currency codes to uppercase to handle model-generated lowercase codes
    currency = currency.upper().strip()
    new_currency = new_currency.upper().strip()

    # In a real application, this function would call an API to get the latest
    # exchange rate. For this example, we'll just use a fixed rate.
    if currency == "USD" and new_currency == "EUR":
        return amount * USD_TO_EUR_RATE
    elif currency == "EUR" and new_currency == "USD":
        return amount / USD_TO_EUR_RATE
    else:
        raise NotImplementedError(
            f"Currency conversion from {currency} to {new_currency} is not supported."
        )


"""
## Telling the Model About the Tool

Now that we have a tool, we need to tell the Gemma 3 model about it. We do this with a carefully crafted prompt that includes:

1. A description of the tool-calling process.
2. The Python code for the tool, including its function signature and docstring.
3. The user's question.

Here's the prompt we'll use:
"""

message = '''
<start_of_turn>user
At each turn, if you decide to invoke any of the function(s), it should be wrapped with ```tool_code```. The python methods described below are imported and available, you can only use defined methods and must not reimplement them. The generated code should be readable and efficient. I will provide the response wrapped in ```tool_output```, use it to call more tools or generate a helpful, friendly response. When using a ```tool_call``` think step by step why and how it should be used.

The following Python methods are available:

```python
def convert(amount, currency, new_currency):
    """Convert the currency with the latest exchange rate

    Args:
        amount: The amount of currency to convert
        currency: The currency to convert from
        new_currency: The currency to convert to
    """
```

User: What is $200,000 in EUR?<end_of_turn>
<start_of_turn>model
'''

"""
## Generating the Tool Call

Now, let's pass this prompt to the model and see what it generates.
"""

print(gemma.generate(message))

"""
If everything works as intended, the model recognizes that it can use the `convert` function to answer the question and responds with a `tool_code` block containing the corresponding Python call, such as `print(convert(200000, "USD", "EUR"))`. The exact output may differ from this example.
"""

"""
## Executing the Tool Call and Getting a Final Answer

In a real application, you would now take this generated code, execute it, and feed the result back to the model.
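
Running model output through `eval()` or `exec()`, as the example that follows does, executes arbitrary code. When the model is expected to emit a single call such as `print(convert(200000, "USD", "EUR"))`, a safer option is to parse it with Python's `ast` module and dispatch only whitelisted functions with literal arguments. A minimal sketch, where `safe_call` and `ALLOWED_TOOLS` are hypothetical names and `convert` is a simplified copy of the tool:

```python
import ast


def convert(amount, currency, new_currency):
    # Simplified copy of the guide's tool with a fixed USD -> EUR rate.
    if (currency, new_currency) == ("USD", "EUR"):
        return amount * 0.85
    raise NotImplementedError(f"{currency} -> {new_currency}")


ALLOWED_TOOLS = {"convert": convert}


def safe_call(code, tools=ALLOWED_TOOLS):
    # Parse as a single expression; statements raise SyntaxError here.
    node = ast.parse(code.strip(), mode="eval").body
    # Unwrap print(...) so we get the tool's return value directly.
    if (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "print"
        and len(node.args) == 1
        and not node.keywords
    ):
        node = node.args[0]
    if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)):
        raise ValueError("Expected a single call to a named function")
    if node.func.id not in tools:
        raise ValueError(f"Function {node.func.id!r} is not allowed")
    # literal_eval only accepts literals, so nested calls or names are rejected.
    args = [ast.literal_eval(arg) for arg in node.args]
    kwargs = {}
    for kw in node.keywords:
        if kw.arg is None:  # **kwargs unpacking
            raise ValueError("Keyword unpacking is not allowed")
        kwargs[kw.arg] = ast.literal_eval(kw.value)
    return tools[node.func.id](*args, **kwargs)


print(safe_call('print(convert(200000, "USD", "EUR"))'))
```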
Let's create a practical example that shows how to do this:
"""

# First, let's get the model's response
response = gemma.generate(message)
print("Model's response:")
print(response)


# Extract the tool call from the response
def extract_tool_call(response_text):
    """Extract tool call from the model's response."""
    tool_call_pattern = r"```tool_code\s*\n(.*?)\n```"
    match = re.search(tool_call_pattern, response_text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return None


def capture_code_output(code_string, globals_dict=None, locals_dict=None):
    """
    Executes Python code and captures any stdout output.

    This function uses eval() and exec(), which can execute arbitrary code.
    NEVER use this function with untrusted code in production environments.
    Always validate and sanitize code from LLMs before execution.

    Args:
        code_string (str): The code to execute (expression or statements).
        globals_dict (dict, optional): Global variables for execution.
        locals_dict (dict, optional): Local variables for execution.

    Returns:
        The captured stdout output if any, otherwise the return value of the expression,
        or None if neither.
    """
    if globals_dict is None:
        globals_dict = {}
    if locals_dict is None:
        locals_dict = globals_dict

    output = io.StringIO()
    try:
        with contextlib.redirect_stdout(output):
            try:
                # Try to evaluate as an expression
                result = eval(code_string, globals_dict, locals_dict)
            except SyntaxError:
                # If not an expression, execute as statements
                exec(code_string, globals_dict, locals_dict)
                result = None
    except Exception as e:
        return f"Error during code execution: {e}"

    stdout_output = output.getvalue()
    if stdout_output.strip():
        return stdout_output
    return result


# Extract and execute the tool call
tool_code = extract_tool_call(response)
if tool_code:
    print(f"\nExtracted tool call: {tool_code}")
    try:
        local_vars = {"convert": convert}
        tool_result = capture_code_output(tool_code, globals_dict=local_vars)
        print(f"Tool execution result: {tool_result}")

        # Create the next message: replay the model's actual tool call,
        # then append the tool output as a user turn
        message_with_result = f'''
<start_of_turn>user
At each turn, if you decide to invoke any of the function(s), it should be wrapped with ```tool_code```. The python methods described below are imported and available, you can only use defined methods and must not reimplement them. The generated code should be readable and efficient. I will provide the response wrapped in ```tool_output```, use it to call more tools or generate a helpful, friendly response. When using a ```tool_call``` think step by step why and how it should be used.

The following Python methods are available:

```python
def convert(amount, currency, new_currency):
    """Convert the currency with the latest exchange rate

    Args:
        amount: The amount of currency to convert
        currency: The currency to convert from
        new_currency: The currency to convert to
    """
```

User: What is $200,000 in EUR?<end_of_turn>
<start_of_turn>model
```tool_code
{tool_code}
```<end_of_turn>
<start_of_turn>user
```tool_output
{tool_result}
```
<end_of_turn>
<start_of_turn>model
'''

        # Get the final response
        final_response = gemma.generate(message_with_result)
        print("\nFinal response:")
        print(final_response)

    except Exception as e:
        print(f"Error executing tool call: {e}")
else:
    print("No tool call found in the response")

"""
## Automated Tool Call Execution Loop

Let's create a more sophisticated example that automatically handles multiple tool calls in a conversation:
"""


def automated_tool_calling_example():
    """Demonstrate automated tool calling with a conversation loop."""

    conversation_history = []
    max_turns = 5

    # Initial user message
    user_message = "What is $500 in EUR, and then what is that amount in USD?"

    # Build the base prompt once, outside the loop
    base_prompt = f'''
<start_of_turn>user
At each turn, if you decide to invoke any of the function(s), it should be wrapped with ```tool_code```. The python methods described below are imported and available, you can only use defined methods and must not reimplement them. The generated code should be readable and efficient. I will provide the response wrapped in ```tool_output```, use it to call more tools or generate a helpful, friendly response. When using a ```tool_call``` think step by step why and how it should be used.

The following Python methods are available:

```python
def convert(amount, currency, new_currency):
    """Convert the currency with the latest exchange rate

    Args:
        amount: The amount of currency to convert
        currency: The currency to convert from
        new_currency: The currency to convert to
    """
```

User: {user_message}<end_of_turn>
<start_of_turn>model
'''

    for turn in range(max_turns):
        print(f"\n--- Turn {turn + 1} ---")

        # Build conversation context by appending history to base prompt
        context = base_prompt
        for hist in conversation_history:
            context += hist + "\n"

        # Get model response
        response = gemma.generate(context, strip_prompt=True)
        print(f"Model response: {response}")

        # Extract tool call
        tool_code = extract_tool_call(response)

        if tool_code:
            print(f"Executing: {tool_code}")
            try:
                local_vars = {"convert": convert}
                tool_result = capture_code_output(tool_code, globals_dict=local_vars)
                conversation_history.append(
                    f"```tool_code\n{tool_code}\n```<end_of_turn>"
                )
                conversation_history.append(
                    f"<start_of_turn>user\n```tool_output\n{tool_result}\n```<end_of_turn>"
                )
                conversation_history.append("<start_of_turn>model\n")
                print(f"Tool result: {tool_result}")
            except Exception as e:
                print(f"Error executing tool: {e}")
                break
        else:
            print("No tool call found - conversation complete")
            conversation_history.append(response)
            break

    print("\n--- Final Conversation ---")
    # Print the base prompt once, followed by every turn in the history
    print(base_prompt)
    for hist in conversation_history:
        print(hist)


# Run the automated example
print("Running automated tool calling example:")
automated_tool_calling_example()

"""
## Mistral

Mistral takes a different approach to tool calling than Gemma: tool definitions and tool calls use a specific JSON format and are delimited by dedicated control tokens. Similar JSON-based tool-calling formats are used by other models, such as Qwen and Llama.

We will now extend the example to a more exciting use case: building a flight booking agent. This agent will be able to search for appropriate flights and book them automatically.

To do this, we will first download the Mistral model using KerasHub. For agentic AI with Mistral, low-level access to tokenization is necessary because of the control tokens. Therefore, we will instantiate the tokenizer and model separately, and disable the model's preprocessor.
"""

tokenizer = keras_hub.tokenizers.MistralTokenizer.from_preset(
    "kaggle://keras/mistral/keras/mistral_0.3_instruct_7b_en"
)

try:
    mistral = keras_hub.models.MistralCausalLM.from_preset(
        "kaggle://keras/mistral/keras/mistral_0.3_instruct_7b_en", preprocessor=None
    )
    print("✅ Mistral model loaded successfully")
except Exception as e:
    print(f"❌ Error loading Mistral model: {e}")
    print("Please ensure you have the correct model preset and sufficient resources.")
    raise

"""
Next, we'll define functions for tokenization. The `preprocess` function will take a tokenized conversation in list form and format it correctly for the model.
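
In NumPy terms, this formatting amounts to concatenating the token lists into a single row, right-padding that row with zeros, and building a mask that marks the real tokens. A toy sketch with made-up token IDs and `sequence_length=8` (the guide uses 8192):

```python
import numpy as np

# Made-up token IDs for three message segments (not real Mistral IDs).
toy_messages = [[1], [5, 6, 7], [8, 9]]
sequence_length = 8  # the guide uses 8192

# Concatenate all segments into a single row: shape (1, 6).
tokens = np.expand_dims(np.concatenate(toy_messages), 0)
length = tokens.shape[1]

# Right-pad with zeros up to sequence_length.
token_ids = np.pad(tokens, ((0, 0), (0, sequence_length - length)))
# 1 marks a real token, 0 marks padding.
padding_mask = np.expand_dims(np.arange(sequence_length) < length, 0).astype(int)

print(token_ids)  # [[1 5 6 7 8 9 0 0]]
print(padding_mask)  # [[1 1 1 1 1 1 0 0]]
```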
We'll also create an additional function, `encode_instruction`, for tokenizing text and adding instruction control tokens.
"""


def preprocess(messages, sequence_length=8192):
    """Preprocess tokenized messages for the Mistral model.

    Args:
        messages: List of tokenized message sequences
        sequence_length: Maximum sequence length for the model

    Returns:
        Dictionary containing token_ids and padding_mask
    """
    concatd = np.expand_dims(np.concatenate(messages), 0)

    # Truncate if the sequence is too long
    if concatd.shape[1] > sequence_length:
        concatd = concatd[:, :sequence_length]

    # Calculate padding needed
    padding_needed = max(0, sequence_length - concatd.shape[1])

    return {
        "token_ids": np.pad(concatd, ((0, 0), (0, padding_needed))),
        "padding_mask": np.expand_dims(
            np.arange(sequence_length) < concatd.shape[1], 0
        ).astype(int),
    }


def encode_instruction(text):
    """Encode instruction text with Mistral control tokens.

    Args:
        text: The instruction text to encode

    Returns:
        List of tokenized sequences with instruction control tokens
    """
    return [
        [tokenizer.token_to_id("[INST]")],
        tokenizer(text),
        [tokenizer.token_to_id("[/INST]")],
    ]


"""
Now, we'll define a function, `try_parse_funccall`, to handle the model's function calls. These calls are identified by the `[TOOL_CALLS]` control token. The function will parse the subsequent data, which is in JSON format. Mistral also requires us to add a random call ID to each function call.
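
In the code that follows, the call ID is a random 9-character alphanumeric string, and each tool result is serialized as a JSON object that echoes that ID back as `call_id`. A small sketch of both steps for one hand-written call (the call and its result are made up for illustration):

```python
import json
import random
import string

# A made-up tool call, shaped like the JSON that follows [TOOL_CALLS].
call = {
    "name": "convert_currency",
    "arguments": {"amount": 22, "currency": "USD", "new_currency": "EUR"},
}

# Random 9-character alphanumeric ID, generated the same way as in
# try_parse_funccall.
call["id"] = "".join(random.choices(string.ascii_letters + string.digits, k=9))

# Pretend the tool returned this value.
content = "18.70"

# The JSON that gets tokenized between [TOOL_RESULTS] and [/TOOL_RESULTS];
# call_id ties the result back to the call that produced it.
payload = json.dumps({"content": content, "call_id": call["id"]})
print(payload)
```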
Finally, the function will call the matching tool and encode its results using the `[TOOL_RESULTS]` control token. To keep the conversation consistent, it also replaces the model's raw output after `[TOOL_CALLS]` with the re-serialized tool calls (now including their IDs), followed by the end-of-sequence token `</s>`.
"""


def try_parse_funccall(response):
    """Parse function calls from Mistral model response and execute tools.

    Args:
        response: Tokenized model response

    Returns:
        List of tokenized sequences including tool results
    """
    # find the tool call in the response, if any
    tool_call_id = tokenizer.token_to_id("[TOOL_CALLS]")
    pos = np.where(response == tool_call_id)[0]
    if not len(pos):
        return [response]
    pos = pos[0]

    try:
        decoder = json.JSONDecoder()
        # raw_decode parses the first JSON value and ignores any trailing text
        tool_calls, _ = decoder.raw_decode(
            tokenizer.detokenize(response[pos + 1 :]).lstrip()
        )
        if not isinstance(tool_calls, list) or not all(
            isinstance(item, dict) for item in tool_calls
        ):
            return [response]

        res = []  # Tool results
        for call in tool_calls:
            # assign a random call ID
            call["id"] = "".join(
                random.choices(string.ascii_letters + string.digits, k=9)
            )
            if call["name"] not in tools:
                continue  # Skip unknown tools
            res.append([tokenizer.token_to_id("[TOOL_RESULTS]")])
            res.append(
                tokenizer(
                    json.dumps(
                        {
                            "content": tools[call["name"]](**call["arguments"]),
                            "call_id": call["id"],
                        }
                    )
                )
            )
            res.append([tokenizer.token_to_id("[/TOOL_RESULTS]")])
        # Keep the model's turn up to [TOOL_CALLS], then the truncated tool calls
        # (now carrying their IDs) and </s>, followed by the tool results.
        return [
            response[: pos + 1],
            tokenizer(json.dumps(tool_calls)),
            [tokenizer.token_to_id("</s>")],
        ] + res
    except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
        # Log the error for debugging
        print(f"Error parsing tool call: {e}")
        return [response]


"""
We will extend our set of tools to include functions for currency conversion, finding flights, and booking flights.
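
Each of these tools also needs a JSON description for the model, like the `tool_definitions` list defined later in this guide. Instead of writing that skeleton by hand, it can be derived from a function's signature with Python's `inspect` module. A minimal sketch, assuming a hypothetical `make_tool_definition` helper and a simple mapping from type annotations to JSON-schema type names (unannotated parameters fall back to `"string"`):

```python
import inspect
import json

# Assumed mapping from Python annotations to JSON-schema type names.
JSON_TYPES = {int: "integer", float: "number", str: "string", bool: "boolean"}


def make_tool_definition(func, description):
    # Hypothetical helper: builds the {"type": "function", ...} skeleton used by
    # tool_definitions. Per-parameter descriptions still have to be added by hand.
    properties, required = {}, []
    for name, param in inspect.signature(func).parameters.items():
        # Unannotated or unknown annotations fall back to "string".
        properties[name] = {"type": JSON_TYPES.get(param.annotation, "string")}
        if param.default is inspect.Parameter.empty:
            required.append(name)  # no default value -> required argument
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


def find_flights(origin: str, destination: str, date: str):
    return []  # mock implementation


print(json.dumps(make_tool_definition(find_flights, "Query flights for a given date"), indent=2))
```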
For this example, we'll use mock implementations for these functions, meaning they will return dummy data instead of interacting with real services.
"""

tools = {
    "convert_currency": lambda amount, currency, new_currency: (
        f"{amount*USD_TO_EUR_RATE:.2f}"
        if currency == "USD" and new_currency == "EUR"
        else (
            f"{amount/USD_TO_EUR_RATE:.2f}"
            if currency == "EUR" and new_currency == "USD"
            else f"Error: Unsupported conversion from {currency} to {new_currency}"
        )
    ),
    "find_flights": lambda origin, destination, date: [
        {"id": 1, "price": "USD 220", "stops": 2, "duration": 4.5},
        {"id": 2, "price": "USD 22", "stops": 1, "duration": 2.0},
        {"id": 3, "price": "USD 240", "stops": 2, "duration": 13.2},
    ],
    "book_flight": lambda id: {
        "status": "success",
        "message": f"Flight {id} booked successfully",
    },
}

"""
It's crucial to inform the model about these available functions at the very beginning of the conversation. To do this, we will define the available tools in a specific JSON format, as shown in the following code block.
"""

tool_definitions = [
    {
        "type": "function",
        "function": {
            "name": "convert_currency",
            "description": "Convert the currency with the latest exchange rate",
            "parameters": {
                "type": "object",
                "properties": {
                    "amount": {"type": "number", "description": "The amount"},
                    "currency": {
                        "type": "string",
                        "description": "The currency to convert from",
                    },
                    "new_currency": {
                        "type": "string",
                        "description": "The currency to convert to",
                    },
                },
                "required": ["amount", "currency", "new_currency"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "find_flights",
            "description": "Query price, time, number of stopovers and duration in hours for flights for a given date",
            "parameters": {
                "type": "object",
                "properties": {
                    "origin": {
                        "type": "string",
                        "description": "The city to depart from",
                    },
                    "destination": {
                        "type": "string",
                        "description": "The destination city",
                    },
                    "date": {
                        "type": "string",
                        "description": "The date in YYYYMMDD format",
                    },
                },
                "required": ["origin", "destination", "date"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "book_flight",
            "description": "Book the flight with the given id",
            "parameters": {
                "type": "object",
                "properties": {
                    "id": {
                        "type": "number",
                        "description": "The numeric id of the flight to book",
                    },
                },
                "required": ["id"],
            },
        },
    },
]

"""
We will define the conversation as a `messages` list. At the very beginning of this list, we need to include a Beginning-Of-Sequence (BOS) token. This is followed by the tool definitions, which must be wrapped in `[AVAILABLE_TOOLS]` and `[/AVAILABLE_TOOLS]` control tokens.
"""

messages = [
    [tokenizer.token_to_id("<s>")],
    [tokenizer.token_to_id("[AVAILABLE_TOOLS]")],
    tokenizer(json.dumps(tool_definitions)),
    [tokenizer.token_to_id("[/AVAILABLE_TOOLS]")],
]

"""
Now, let's get started! We will task the model with the following: **Book the most comfortable flight from Linz to London on the 24th of July 2025, but only if it costs less than 20€ as of the latest exchange rate.**
"""

messages.extend(
    encode_instruction(
        "Book the most comfortable flight from Linz to London on the 24th of July 2025, but only if it costs less than 20€ as of the latest exchange rate."
    )
)

"""
In an agentic AI system, the model interacts with its tools through a sequence of messages. We will continue to handle these messages until the flight is successfully booked.

For educational purposes, we will output the tool calls issued by the model; typically, a user would not see this level of detail. It's important to note that each model turn must be cut off right after its tool-call JSON before it is added back to the conversation.
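
At the text level, this truncation can be done with `json.JSONDecoder.raw_decode`, which parses the first JSON value in a string and returns the index where that value ends, so anything after it can be dropped. A minimal sketch on a made-up model output:

```python
import json

# Made-up detokenized output: a tool call followed by extra "babble".
text = ' [{"name": "book_flight", "arguments": {"id": 2}}] Done! I also booked flight 3.'

decoder = json.JSONDecoder()
# raw_decode does not skip leading whitespace, so start at the first "[".
start = text.index("[")
tool_calls, end = decoder.raw_decode(text, start)

# Keep the text only up to the end of the JSON value; drop the rest.
truncated = text[:end]
print(tool_calls)  # [{'name': 'book_flight', 'arguments': {'id': 2}}]
print(truncated)
```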
Without this truncation, a less capable model may "babble", outputting redundant or confused data.
"""

flight_booked = False
max_iterations = 10  # Prevent infinite loops
iteration_count = 0

while not flight_booked and iteration_count < max_iterations:
    iteration_count += 1
    # query the model; stop at </s> (token id 2)
    res = mistral.generate(
        preprocess(messages), max_length=8192, stop_token_ids=[2], strip_prompt=True
    )
    # keep only the generated tokens, dropping the trailing padding
    response_tokens = res["token_ids"][0, : np.argmax(~res["padding_mask"][0])]
    # output the model's response, add separator line for legibility
    response_text = tokenizer.detokenize(response_tokens)
    print(response_text, f"\n\n\n{'-'*100}\n\n")

    # Check for tool calls and track booking status
    tool_call_id = tokenizer.token_to_id("[TOOL_CALLS]")
    pos = np.where(response_tokens == tool_call_id)[0]
    if len(pos) > 0:
        try:
            decoder = json.JSONDecoder()
            tool_calls, _ = decoder.raw_decode(
                tokenizer.detokenize(response_tokens[pos[0] + 1 :]).lstrip()
            )
            if isinstance(tool_calls, list):
                for call in tool_calls:
                    if isinstance(call, dict) and call.get("name") == "book_flight":
                        # The model has requested a booking: stop after this turn
                        flight_booked = True
                        break
        except (json.JSONDecodeError, KeyError, TypeError, ValueError):
            pass

    # perform tool calls and extend `messages`
    messages.extend(try_parse_funccall(response_tokens))

if not flight_booked:
    print("Maximum iterations reached. Flight booking was not completed.")

"""
For clarity, here's the conversation as the model receives it, i.e., with each model turn truncated after its tool-call JSON:

* **User:**
```
Book the most comfortable flight from Linz to London on the 24th of July 2025, but only if it costs less than 20€ as of the latest exchange rate.
```

* **Model:**
```
[{"name": "find_flights", "arguments": {"origin": "Linz", "destination": "London", "date": "20250724"}}]
```

* **Tool Output:**
```
[{"id": 1, "price": "USD 220", "stops": 2, "duration": 4.5}, {"id": 2, "price": "USD 22", "stops": 1, "duration": 2.0}, {"id": 3, "price": "USD 240", "stops": 2, "duration": 13.2}]
```

* **Model:**
```
Now let's convert the price from USD to EUR using the latest exchange rate:

[{"name": "convert_currency", "arguments": {"amount": 22, "currency": "USD", "new_currency": "EUR"}}]
```

* **Tool Output:**
```
"18.70"
```

* **Model:**
```
The price of the flight with the id 2 in EUR is 18.70. Since it is below the 20€ limit, let's book this flight:

[{"name": "book_flight", "arguments": {"id": 2}}]
```

You might have to run the model a few times to obtain output as good as the example above. As a 7-billion-parameter model, Mistral can still make mistakes, such as misinterpreting data, producing malformed tool calls, or making incorrect decisions. However, continued progress in this field is paving the way for increasingly capable agentic AI.
"""

"""
## Conclusion

Tool calling is a powerful feature that allows large language models to interact with the real world, access live data, and perform complex calculations. By defining a set of tools and telling the model about them, you can create sophisticated applications that go far beyond simple text generation.
"""