Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TheLastBen
GitHub Repository: TheLastBen/fast-stable-diffusion
Path: blob/main/AUTOMATIC1111_files/blocks.py
540 views
1
from __future__ import annotations
2
3
import copy
4
import inspect
5
import json
6
import os
7
import random
8
import secrets
9
import sys
10
import threading
11
import time
12
import warnings
13
import webbrowser
14
from abc import abstractmethod
15
from collections import defaultdict
16
from pathlib import Path
17
from types import ModuleType
18
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Literal, cast
19
20
import anyio
21
import requests
22
from anyio import CapacityLimiter
23
from gradio_client import serializing
24
from gradio_client import utils as client_utils
25
from gradio_client.documentation import document, set_documentation_group
26
from packaging import version
27
28
from gradio import (
29
analytics,
30
components,
31
external,
32
networking,
33
queueing,
34
routes,
35
strings,
36
themes,
37
utils,
38
wasm_utils,
39
)
40
from gradio.context import Context
41
from gradio.deprecation import check_deprecated_parameters, warn_deprecation
42
from gradio.exceptions import (
43
DuplicateBlockError,
44
InvalidApiNameError,
45
InvalidBlockError,
46
)
47
from gradio.helpers import EventData, create_tracker, skip, special_args
48
from gradio.themes import Default as DefaultTheme
49
from gradio.themes import ThemeClass as Theme
50
from gradio.tunneling import (
51
BINARY_FILENAME,
52
BINARY_FOLDER,
53
BINARY_PATH,
54
BINARY_URL,
55
CURRENT_TUNNELS,
56
)
57
from gradio.utils import (
58
GRADIO_VERSION,
59
TupleNoPrint,
60
check_function_inputs_match,
61
component_or_layout_class,
62
concurrency_count_warning,
63
delete_none,
64
get_cancel_function,
65
get_continuous_fn,
66
)
67
68
# Best-effort import of the `spaces` package, which is only available when
# running on Hugging Face Spaces hardware. The broad `except Exception` is
# deliberate: any failure (missing package, import-time error) simply means
# Spaces-specific auto-wrapping is disabled.
try:
    import spaces  # type: ignore
except Exception:
    spaces = None

# Group everything documented in this module under "blocks" in the API docs.
set_documentation_group("blocks")
74
75
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
76
from fastapi.applications import FastAPI
77
78
from gradio.components import Component
79
80
# Registry of the themes that ship with Gradio, keyed by theme name, so a
# string like "soft" or "default" can be resolved without hitting the HF Hub.
BUILT_IN_THEMES: dict[str, Theme] = {
    t.name: t
    for t in [
        themes.Base(),
        themes.Default(),
        themes.Monochrome(),
        themes.Soft(),
        themes.Glass(),
    ]
}
90
91
92
class Block:
    """Base class for every Gradio component and layout element.

    A Block owns a unique id allocated from the global ``Context`` counter,
    registers itself with the active Blocks app when rendered, and provides
    ``set_event_trigger`` for wiring event listeners to backend functions.
    """

    def __init__(
        self,
        *,
        render: bool = True,
        elem_id: str | None = None,
        elem_classes: list[str] | str | None = None,
        visible: bool = True,
        root_url: str | None = None,  # URL that is prepended to all file paths
        _skip_init_processing: bool = False,  # Used for loading from Spaces
        **kwargs,
    ):
        # Ids are handed out by the global Context counter so every block in
        # the process gets a unique, monotonically increasing id.
        self._id = Context.id
        Context.id += 1
        self.visible = visible
        self.elem_id = elem_id
        # Normalize a single CSS class string into a one-element list.
        self.elem_classes = (
            [elem_classes] if isinstance(elem_classes, str) else elem_classes
        )
        self.root_url = root_url
        # Random token used to authorize access to this block's files when shared.
        self.share_token = secrets.token_urlsafe(32)
        self._skip_init_processing = _skip_init_processing
        self.parent: BlockContext | None = None
        self.is_rendered: bool = False

        if render:
            self.render()
        # Warn about any deprecated keyword arguments that were passed in.
        check_deprecated_parameters(self.__class__.__name__, kwargs=kwargs)

    def render(self):
        """
        Adds self into appropriate BlockContext
        """
        if Context.root_block is not None and self._id in Context.root_block.blocks:
            raise DuplicateBlockError(
                f"A block with id: {self._id} has already been rendered in the current Blocks."
            )
        # Attach to the current (possibly nested) layout context, if any.
        if Context.block is not None:
            Context.block.add(self)
        # Register with the root Blocks app, if one is active.
        if Context.root_block is not None:
            Context.root_block.blocks[self._id] = self
            self.is_rendered = True
            # IOComponents track temp files that the app must be able to serve.
            if isinstance(self, components.IOComponent):
                Context.root_block.temp_file_sets.append(self.temp_files)
        return self

    def unrender(self):
        """
        Removes self from BlockContext if it has been rendered (otherwise does nothing).
        Removes self from the layout and collection of blocks, but does not delete any event triggers.
        """
        if Context.block is not None:
            try:
                Context.block.children.remove(self)
            except ValueError:
                # Not a child of the current context; nothing to remove.
                pass
        if Context.root_block is not None:
            try:
                del Context.root_block.blocks[self._id]
                self.is_rendered = False
            except KeyError:
                # Was never registered with the root Blocks; nothing to do.
                pass
        return self

    def get_block_name(self) -> str:
        """
        Gets block's class name.

        If it is template component it gets the parent's class name.

        @return: class name
        """
        return (
            self.__class__.__base__.__name__.lower()
            if hasattr(self, "is_template")
            else self.__class__.__name__.lower()
        )

    def get_expected_parent(self) -> type[BlockContext] | None:
        # Overridden by components that must live inside a specific layout
        # (see BlockContext.fill_expected_parents). None means "no preference".
        return None

    def set_event_trigger(
        self,
        event_name: str,
        fn: Callable | None,
        inputs: Component | list[Component] | set[Component] | None,
        outputs: Component | list[Component] | None,
        preprocess: bool = True,
        postprocess: bool = True,
        scroll_to_output: bool = False,
        show_progress: str = "full",
        api_name: str | None | Literal[False] = None,
        js: str | None = None,
        no_target: bool = False,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        cancels: list[int] | None = None,
        every: float | None = None,
        collects_event_data: bool | None = None,
        trigger_after: int | None = None,
        trigger_only_on_success: bool = False,
    ) -> tuple[dict[str, Any], int]:
        """
        Adds an event to the component's dependencies.
        Parameters:
            event_name: event name
            fn: Callable function
            inputs: input list
            outputs: output list
            preprocess: whether to run the preprocess methods of components
            postprocess: whether to run the postprocess methods of components
            scroll_to_output: whether to scroll to output of dependency on trigger
            show_progress: whether to show progress animation while running.
            api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
            js: Experimental parameter (API may change): Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
            no_target: if True, sets "targets" to [], used for Blocks "load" event
            queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
            batch: whether this function takes in a batch of inputs
            max_batch_size: the maximum batch size to send to the function
            cancels: a list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
            collects_event_data: whether to collect event data for this event
            trigger_after: if set, this event will be triggered after 'trigger_after' function index
            trigger_only_on_success: if True, this event will only be triggered if the previous event was successful (only applies if `trigger_after` is set)
        Returns: dependency information, dependency index
        """
        # Support for singular parameter
        if isinstance(inputs, set):
            # A set of inputs means fn receives a single dict keyed by component.
            inputs_as_dict = True
            inputs = sorted(inputs, key=lambda x: x._id)
        else:
            inputs_as_dict = False
            if inputs is None:
                inputs = []
            elif not isinstance(inputs, list):
                inputs = [inputs]

        if isinstance(outputs, set):
            outputs = sorted(outputs, key=lambda x: x._id)
        else:
            if outputs is None:
                outputs = []
            elif not isinstance(outputs, list):
                outputs = [outputs]

        # Cancel-functions are synthesized elsewhere, so skip signature checks.
        if fn is not None and not cancels:
            check_function_inputs_match(fn, inputs, inputs_as_dict)

        if Context.root_block is None:
            raise AttributeError(
                f"{event_name}() and other events can only be called within a Blocks context."
            )
        if every is not None and every <= 0:
            raise ValueError("Parameter every must be positive or None")
        if every and batch:
            raise ValueError(
                f"Cannot run {event_name} event in a batch and every {every} seconds. "
                "Either batch is True or every is non-zero but not both."
            )

        # `every` turns fn into a generator that re-runs on an interval.
        if every and fn:
            fn = get_continuous_fn(fn, every)
        elif every:
            raise ValueError("Cannot set a value for `every` without a `fn`.")

        # Inspect fn for special gr.Progress / gr.EventData parameters.
        _, progress_index, event_data_index = (
            special_args(fn) if fn else (None, None, None)
        )
        Context.root_block.fns.append(
            BlockFunction(
                fn,
                inputs,
                outputs,
                preprocess,
                postprocess,
                inputs_as_dict,
                progress_index is not None,
            )
        )
        # De-duplicate api_name against already-registered dependencies.
        if api_name is not None and api_name is not False:
            api_name_ = utils.append_unique_suffix(
                api_name, [dep["api_name"] for dep in Context.root_block.dependencies]
            )
            if api_name != api_name_:
                warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
                api_name = api_name_

        if collects_event_data is None:
            collects_event_data = event_data_index is not None

        # The dependency dict is the JSON-serializable record of this event
        # that is shipped to the frontend in the app config.
        dependency = {
            "targets": [self._id] if not no_target else [],
            "trigger": event_name,
            "inputs": [block._id for block in inputs],
            "outputs": [block._id for block in outputs],
            "backend_fn": fn is not None,
            "js": js,
            "queue": False if fn is None else queue,
            "api_name": api_name,
            "scroll_to_output": False if utils.get_space() else scroll_to_output,
            "show_progress": show_progress,
            "every": every,
            "batch": batch,
            "max_batch_size": max_batch_size,
            "cancels": cancels or [],
            "types": {
                "continuous": bool(every),
                "generator": inspect.isgeneratorfunction(fn) or bool(every),
            },
            "collects_event_data": collects_event_data,
            "trigger_after": trigger_after,
            "trigger_only_on_success": trigger_only_on_success,
        }
        Context.root_block.dependencies.append(dependency)
        return dependency, len(Context.root_block.dependencies) - 1

    def get_config(self):
        """Return the serializable base config shared by all blocks."""
        return {
            "visible": self.visible,
            "elem_id": self.elem_id,
            "elem_classes": self.elem_classes,
            "root_url": self.root_url,
        }

    @staticmethod
    @abstractmethod
    def update(**kwargs) -> dict:
        # Overridden by each concrete component to build an update dict.
        return {}

    @classmethod
    def get_specific_update(cls, generic_update: dict[str, Any]) -> dict:
        """Re-dispatch a generic update dict through this component's update()."""
        # Copy so the caller's dict is not mutated by the del below.
        generic_update = generic_update.copy()
        del generic_update["__type__"]
        specific_update = cls.update(**generic_update)
        return specific_update
328
329
330
class BlockContext(Block):
    """A Block that can contain child blocks (layouts such as Row/Column/Tab).

    Used as a context manager: entering makes it the current layout context so
    newly rendered blocks become its children; exiting restores the parent.
    """

    def __init__(
        self,
        visible: bool = True,
        render: bool = True,
        **kwargs,
    ):
        """
        Parameters:
            visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
            render: If False, this will not be included in the Blocks config file at all.
        """
        self.children: list[Block] = []
        Block.__init__(self, visible=visible, render=render, **kwargs)

    def add_child(self, child: Block):
        # Append only; does not set child.parent (see add() for that).
        self.children.append(child)

    def __enter__(self):
        # Save the enclosing context and make self the active layout context.
        self.parent = Context.block
        Context.block = self
        return self

    def add(self, child: Block):
        child.parent = self
        self.children.append(child)

    def fill_expected_parents(self):
        """Wrap children that require a specific parent layout in pseudo-parents.

        Consecutive children sharing the same expected parent type are grouped
        under a single, unrendered pseudo-parent instance.
        """
        children = []
        pseudo_parent = None
        for child in self.children:
            expected_parent = child.get_expected_parent()
            if not expected_parent or isinstance(self, expected_parent):
                # Child is happy here; also break any pseudo-parent grouping run.
                pseudo_parent = None
                children.append(child)
            else:
                if pseudo_parent is not None and isinstance(
                    pseudo_parent, expected_parent
                ):
                    # Continue the current grouping run.
                    pseudo_parent.add_child(child)
                else:
                    # Start a new pseudo-parent of the required layout type.
                    pseudo_parent = expected_parent(render=False)
                    pseudo_parent.parent = self
                    children.append(pseudo_parent)
                    pseudo_parent.add_child(child)
                    if Context.root_block:
                        Context.root_block.blocks[pseudo_parent._id] = pseudo_parent
                child.parent = pseudo_parent
        self.children = children

    def __exit__(self, *args):
        # Subclasses may set allow_expected_parents=False to skip regrouping.
        if getattr(self, "allow_expected_parents", True):
            self.fill_expected_parents()
        Context.block = self.parent

    def postprocess(self, y):
        """
        Any postprocessing needed to be performed on a block context.
        """
        return y
390
391
392
class BlockFunction:
    """Bookkeeping record for one event-handler function of a Blocks app.

    Stores the callable, its input/output components, pre/postprocessing
    flags, and cumulative run statistics.
    """

    def __init__(
        self,
        fn: Callable | None,
        inputs: list[Component],
        outputs: list[Component],
        preprocess: bool,
        postprocess: bool,
        inputs_as_dict: bool,
        tracks_progress: bool = False,
    ):
        self.fn = fn
        self.inputs = inputs
        self.outputs = outputs
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.tracks_progress = tracks_progress
        self.inputs_as_dict = inputs_as_dict
        # Cumulative run statistics, updated after each invocation elsewhere.
        self.total_runtime = 0
        self.total_runs = 0
        self.name = getattr(fn, "__name__", "fn") if fn is not None else None
        self.spaces_auto_wrap()

    def spaces_auto_wrap(self):
        """Wrap `fn` via the `spaces` package when running on HF Spaces."""
        # Only applies when the `spaces` package imported successfully AND we
        # are actually running inside a Space.
        if spaces is not None and utils.get_space() is not None:
            self.fn = spaces.gradio_auto_wrap(self.fn)

    def __str__(self):
        summary = {
            "fn": self.name,
            "preprocess": self.preprocess,
            "postprocess": self.postprocess,
        }
        return str(summary)

    def __repr__(self):
        return self.__str__()
433
434
435
class class_or_instancemethod(classmethod):  # noqa: N801
    """Descriptor that binds to the class when accessed on the class,
    and to the instance when accessed on an instance."""

    def __get__(self, instance, type_):
        if instance is None:
            # Class access: behave like a normal classmethod.
            return super().__get__(instance, type_)
        # Instance access: bind the raw function to the instance instead.
        return self.__func__.__get__(instance, type_)
439
440
441
def postprocess_update_dict(block: Block, update_dict: dict, postprocess: bool = True):
    """
    Converts a dictionary of updates into a format that can be sent to the frontend.
    E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
    Into -> {"__type__": "update", "value": 2.0, "mode": "static"}

    Parameters:
        block: The Block that is being updated with this update dictionary.
        update_dict: The original update dictionary
        postprocess: Whether to postprocess the "value" key of the update dictionary.
    """
    # Resolve a generic update into the component-specific update first.
    if update_dict.get("__type__", "") == "generic_update":
        update_dict = block.get_specific_update(update_dict)

    # NO_VALUE is a sentinel meaning "leave the value unchanged".
    if update_dict.get("value") is components._Keywords.NO_VALUE:
        update_dict.pop("value")

    # Map the `interactive` flag onto the frontend's `mode` field.
    interactive = update_dict.pop("interactive", None)
    if interactive is not None:
        update_dict["mode"] = "dynamic" if interactive else "static"

    cleaned = delete_none(update_dict, skip_value=True)
    if postprocess and "value" in cleaned:
        assert isinstance(
            block, components.IOComponent
        ), f"Component {block.__class__} does not support value"
        cleaned["value"] = block.postprocess(cleaned["value"])
    return cleaned
466
467
468
def convert_component_dict_to_list(
    outputs_ids: list[int], predictions: dict
) -> list | dict:
    """
    Converts a dictionary of component updates into a list of updates in the order of
    the outputs_ids and including every output component. Leaves other types of dictionaries unchanged.
    E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
    Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
    """
    block_keys = [isinstance(key, Block) for key in predictions]
    if all(block_keys):
        # Every output not explicitly returned gets a skip() placeholder.
        ordered = [skip() for _ in outputs_ids]
        for component, value in predictions.items():
            if component._id not in outputs_ids:
                raise ValueError(
                    f"Returned component {component} not specified as output of function."
                )
            ordered[outputs_ids.index(component._id)] = value
        return utils.resolve_singleton(ordered)
    if any(block_keys):
        # Mixing Component keys with plain keys is ambiguous — reject it.
        raise ValueError(
            "Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order."
        )
    # No Component keys at all: treat as an ordinary dict and pass through.
    return predictions
493
494
495
def _find_component(config: dict, component_id: int) -> dict | None:
    """Return the component config with the given id, or None if absent."""
    for component in config["components"]:
        if component["id"] == component_id:
            return component
    return None


def _component_api_entry(
    component: dict,
    serialize: bool,
    after_new_format: bool,
    label: str,
    *,
    include_example: bool,
) -> dict:
    """Build one parameter/return entry of the API docs for a component.

    Parameters:
        component: the component's config dictionary.
        serialize: whether to report the serialized type information.
        after_new_format: True if the config is newer than 3.28.3 and may
            embed per-component `api_info` directly.
        label: human-readable label for this parameter/return value.
        include_example: whether to include the "example_input" key
            (used for inputs but not for outputs).
    """
    component_type = component["type"]
    serializer = serializing.COMPONENT_MAPPING[component_type]()
    # The config has the most specific API info (taking into account the
    # parameters of the component), so we use that if it exists. Otherwise,
    # we fall back to the serializer's generic API info.
    if component.get("api_info") and after_new_format:
        info = component["api_info"]
        example = component["example_inputs"]["serialized"]
    else:
        assert isinstance(serializer, serializing.Serializable)
        info = serializer.api_info()
        example = serializer.example_inputs()["raw"]
    python_info = info["info"]
    if serialize and info["serialized_info"]:
        python_info = serializer.serialized_info()
        if (
            isinstance(serializer, serializing.FileSerializable)
            and component["props"].get("file_count", "single") != "single"
        ):
            python_info = serializer._multiple_file_serialized_info()
    entry = {
        "label": label,
        "type": info["info"],
        "python_type": {
            "type": client_utils.json_schema_to_python_type(python_info),
            "description": python_info.get("description", ""),
        },
        "component": component_type.capitalize(),
    }
    if include_example:
        entry["example_input"] = example
    entry["serializer"] = serializing.COMPONENT_MAPPING[component_type].__name__
    return entry


def get_api_info(config: dict, serialize: bool = True):
    """
    Gets the information needed to generate the API docs from a Blocks config.
    Parameters:
        config: a Blocks config dictionary
        serialize: If True, returns the serialized version of the typed information. If False, returns the raw version.
    """
    api_info = {"named_endpoints": {}, "unnamed_endpoints": {}}
    mode = config.get("mode", None)
    # Configs newer than 3.28.3 may embed per-component api_info directly.
    after_new_format = version.parse(config.get("version", "2.0")) > version.Version(
        "3.28.3"
    )

    for d, dependency in enumerate(config["dependencies"]):
        dependency_info = {"parameters": [], "returns": []}
        skip_endpoint = False

        for i in dependency["inputs"]:
            component = _find_component(config, i)
            if component is None:
                skip_endpoint = True  # if component not found, skip endpoint
                break
            component_type = component["type"]
            if component_type in client_utils.SKIP_COMPONENTS:
                continue
            if (
                not component.get("serializer")
                and component_type not in serializing.COMPONENT_MAPPING
            ):
                skip_endpoint = True  # if component not serializable, skip endpoint
                break
            label = component["props"].get("label", f"parameter_{i}")
            dependency_info["parameters"].append(
                _component_api_entry(
                    component, serialize, after_new_format, label, include_example=True
                )
            )

        for o in dependency["outputs"]:
            component = _find_component(config, o)
            if component is None:
                skip_endpoint = True  # if component not found, skip endpoint
                break
            component_type = component["type"]
            if component_type in client_utils.SKIP_COMPONENTS:
                continue
            if (
                not component.get("serializer")
                and component_type not in serializing.COMPONENT_MAPPING
            ):
                skip_endpoint = True  # if component not serializable, skip endpoint
                break
            label = component["props"].get("label", f"value_{o}")
            dependency_info["returns"].append(
                _component_api_entry(
                    component, serialize, after_new_format, label, include_example=False
                )
            )

        # Endpoints with no backend function have nothing to call.
        if not dependency["backend_fn"]:
            skip_endpoint = True

        if skip_endpoint:
            continue
        if dependency["api_name"] is not None and dependency["api_name"] is not False:
            api_info["named_endpoints"][f"/{dependency['api_name']}"] = dependency_info
        elif (
            dependency["api_name"] is False
            or mode == "interface"
            or mode == "tabbed_interface"
        ):
            pass  # Skip unnamed endpoints in interface mode
        else:
            api_info["unnamed_endpoints"][str(d)] = dependency_info

    return api_info
634
635
636
@document("launch", "queue", "integrate", "load")
637
class Blocks(BlockContext):
638
"""
639
Blocks is Gradio's low-level API that allows you to create more custom web
640
applications and demos than Interfaces (yet still entirely in Python).
641
642
643
Compared to the Interface class, Blocks offers more flexibility and control over:
644
(1) the layout of components (2) the events that
645
trigger the execution of functions (3) data flows (e.g. inputs can trigger outputs,
646
which can trigger the next level of outputs). Blocks also offers ways to group
647
together related demos such as with tabs.
648
649
650
The basic usage of Blocks is as follows: create a Blocks object, then use it as a
651
context (with the "with" statement), and then define layouts, components, or events
652
within the Blocks context. Finally, call the launch() method to launch the demo.
653
654
Example:
655
import gradio as gr
656
def update(name):
657
return f"Welcome to Gradio, {name}!"
658
659
with gr.Blocks() as demo:
660
gr.Markdown("Start typing below and then click **Run** to see the output.")
661
with gr.Row():
662
inp = gr.Textbox(placeholder="What is your name?")
663
out = gr.Textbox()
664
btn = gr.Button("Run")
665
btn.click(fn=update, inputs=inp, outputs=out)
666
667
demo.launch()
668
Demos: blocks_hello, blocks_flipper, blocks_speech_text_sentiment, generate_english_german, sound_alert
669
Guides: blocks-and-event-listeners, controlling-layout, state-in-blocks, custom-CSS-and-JS, custom-interpretations-with-blocks, using-blocks-like-functions
670
"""
671
672
    def __init__(
        self,
        theme: Theme | str | None = None,
        analytics_enabled: bool | None = None,
        mode: str = "blocks",
        title: str = "Gradio",
        css: str | None = None,
        **kwargs,
    ):
        """
        Parameters:
            theme: a Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the HF Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
            analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
            mode: a human-friendly name for the kind of Blocks or Interface being created.
            title: The tab title to display when this is opened in a browser window.
            css: custom css or path to custom css file to apply to entire Blocks
        """
        self.limiter = None
        # Resolve the theme: None -> Default, string -> built-in or HF Hub
        # theme, anything invalid falls back to the Default theme with a warning.
        if theme is None:
            theme = DefaultTheme()
        elif isinstance(theme, str):
            if theme.lower() in BUILT_IN_THEMES:
                theme = BUILT_IN_THEMES[theme.lower()]
            else:
                try:
                    theme = Theme.from_hub(theme)
                except Exception as e:
                    warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
                    theme = DefaultTheme()
        if not isinstance(theme, Theme):
            warnings.warn("Theme should be a class loaded from gradio.themes")
            theme = DefaultTheme()
        self.theme: Theme = theme
        self.theme_css = theme._get_theme_css()
        self.stylesheets = theme._stylesheets
        self.encrypt = False
        self.share = False
        self.enable_queue = None
        self.max_threads = 40
        self.pending_streams = defaultdict(dict)
        self.show_error = True
        # `css` may be either a path to a css file or a literal css string.
        if css is not None and os.path.exists(css):
            with open(css) as css_file:
                self.css = css_file.read()
        else:
            self.css = css

        # For analytics_enabled and allow_flagging: (1) first check for
        # parameter, (2) check for env variable, (3) default to True/"manual"
        self.analytics_enabled = (
            analytics_enabled
            if analytics_enabled is not None
            else analytics.analytics_enabled()
        )
        if self.analytics_enabled:
            # Version check runs in a background thread so it never blocks startup.
            if not wasm_utils.IS_WASM:
                t = threading.Thread(target=analytics.version_check)
                t.start()
        else:
            os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
        super().__init__(render=False, **kwargs)
        self.blocks: dict[int, Block] = {}
        self.fns: list[BlockFunction] = []
        self.dependencies = []
        self.mode = mode

        # Launch/runtime state, populated by launch() and friends.
        self.is_running = False
        self.local_url = None
        self.share_url = None
        self.width = None
        self.height = None
        self.api_open = True

        self.space_id = utils.get_space()
        self.favicon_path = None
        self.auth = None
        self.dev_mode = True
        self.app_id = random.getrandbits(64)
        self.temp_file_sets = []
        self.title = title
        self.show_api = True

        # Only used when an Interface is loaded from a config
        self.predict = None
        self.input_components = None
        self.output_components = None
        self.__name__ = None
        self.api_mode = None
        self.progress_tracking = None
        self.ssl_verify = True

        self.allowed_paths = []
        self.blocked_paths = []
        self.root_path = os.environ.get("GRADIO_ROOT_PATH", "")
        self.root_urls = set()

        if self.analytics_enabled:
            # A theme is "custom" if it matches none of the built-in themes.
            is_custom_theme = not any(
                self.theme.to_dict() == built_in_theme.to_dict()
                for built_in_theme in BUILT_IN_THEMES.values()
            )
            data = {
                "mode": self.mode,
                "custom_css": self.css is not None,
                "theme": self.theme.name,
                "is_custom_theme": is_custom_theme,
                "version": GRADIO_VERSION,
            }
            analytics.initiated_analytics(data)
781
782
    @classmethod
    def from_config(
        cls,
        config: dict,
        fns: list[Callable],
        root_url: str,
    ) -> Blocks:
        """
        Factory method that creates a Blocks from a config and list of functions. Used
        internally by the gradio.external.load() method.

        Parameters:
            config: a dictionary containing the configuration of the Blocks.
            fns: a list of functions that are used in the Blocks. Must be in the same order as the dependencies in the config.
            root_url: an external url to use as a root URL when serving files for components in the Blocks.
        """
        # Deep-copy so the mutations below never leak back to the caller's config.
        config = copy.deepcopy(config)
        components_config = config["components"]
        for component_config in components_config:
            # for backwards compatibility, extract style into props
            if "style" in component_config["props"]:
                component_config["props"].update(component_config["props"]["style"])
                del component_config["props"]["style"]
        theme = config.get("theme", "default")
        original_mapping: dict[int, Block] = {}
        root_urls = {root_url}

        def get_block_instance(id: int) -> Block:
            # Find the component's config by id and instantiate the matching class.
            for block_config in components_config:
                if block_config["id"] == id:
                    break
            else:
                raise ValueError(f"Cannot find block with id {id}")
            cls = component_or_layout_class(block_config["type"])
            block_config["props"].pop("type", None)
            block_config["props"].pop("name", None)
            # If a Gradio app B is loaded into a Gradio app A, and B itself loads a
            # Gradio app C, then the root_urls of the components in A need to be the
            # URL of C, not B. The else clause below handles this case.
            if block_config["props"].get("root_url") is None:
                block_config["props"]["root_url"] = f"{root_url}/"
            else:
                root_urls.add(block_config["props"]["root_url"])
            # Any component has already processed its initial value, so we skip that step here
            block = cls(**block_config["props"], _skip_init_processing=True)
            return block

        def iterate_over_children(children_list):
            # Recursively instantiate the layout tree, entering each container
            # block as a context so its children attach correctly.
            for child_config in children_list:
                id = child_config["id"]
                block = get_block_instance(id)

                original_mapping[id] = block

                children = child_config.get("children")
                if children is not None:
                    assert isinstance(
                        block, BlockContext
                    ), f"Invalid config, Block with id {id} has children but is not a BlockContext."
                    with block:
                        iterate_over_children(children)

        # Fields that are derived (recomputed by set_event_trigger) rather
        # than passed through from the config.
        derived_fields = ["types"]

        with Blocks(theme=theme) as blocks:
            # ID 0 should be the root Blocks component
            original_mapping[0] = Context.root_block or blocks

            iterate_over_children(config["layout"]["children"])

            first_dependency = None

            # add the event triggers
            for dependency, fn in zip(config["dependencies"], fns):
                # We used to add a "fake_event" to the config to cache examples
                # without removing it. This was causing bugs in calling gr.load
                # We fixed the issue by removing "fake_event" from the config in examples.py
                # but we still need to skip these events when loading the config to support
                # older demos
                if dependency["trigger"] == "fake_event":
                    continue
                for field in derived_fields:
                    dependency.pop(field, None)
                targets = dependency.pop("targets")
                trigger = dependency.pop("trigger")
                dependency.pop("backend_fn")
                dependency.pop("documentation", None)
                # Translate component ids back into the Block instances
                # created above.
                dependency["inputs"] = [
                    original_mapping[i] for i in dependency["inputs"]
                ]
                dependency["outputs"] = [
                    original_mapping[o] for o in dependency["outputs"]
                ]
                dependency.pop("status_tracker", None)
                # The remote app already pre/postprocessed; don't do it again.
                dependency["preprocess"] = False
                dependency["postprocess"] = False

                for target in targets:
                    dependency = original_mapping[target].set_event_trigger(
                        event_name=trigger, fn=fn, **dependency
                    )[0]
                    if first_dependency is None:
                        first_dependency = dependency

            # Allows some use of Interface-specific methods with loaded Spaces
            if first_dependency and Context.root_block:
                blocks.predict = [fns[0]]
                blocks.input_components = [
                    Context.root_block.blocks[i] for i in first_dependency["inputs"]
                ]
                blocks.output_components = [
                    Context.root_block.blocks[o] for o in first_dependency["outputs"]
                ]
                blocks.__name__ = "Interface"
                blocks.api_mode = True

        blocks.root_urls = root_urls
        return blocks
900
901
def __str__(self):
902
return self.__repr__()
903
904
def __repr__(self):
905
num_backend_fns = len([d for d in self.dependencies if d["backend_fn"]])
906
repr = f"Gradio Blocks instance: {num_backend_fns} backend functions"
907
repr += f"\n{'-' * len(repr)}"
908
for d, dependency in enumerate(self.dependencies):
909
if dependency["backend_fn"]:
910
repr += f"\nfn_index={d}"
911
repr += "\n inputs:"
912
for input_id in dependency["inputs"]:
913
block = self.blocks[input_id]
914
repr += f"\n |-{block}"
915
repr += "\n outputs:"
916
for output_id in dependency["outputs"]:
917
block = self.blocks[output_id]
918
repr += f"\n |-{block}"
919
return repr
920
921
@property
922
def expects_oauth(self):
923
"""Return whether the app expects user to authenticate via OAuth."""
924
return any(
925
isinstance(block, (components.LoginButton, components.LogoutButton))
926
for block in self.blocks.values()
927
)
928
929
def render(self):
    """Render this Blocks instance into the currently active Blocks context.

    Merges this instance's components, functions and event dependencies into
    ``Context.root_block`` (re-indexing dependency references by the current
    dependency count) and attaches its children to ``Context.block``.

    Raises:
        DuplicateBlockError: if this Blocks, or any non-State component in it,
            has already been rendered in the current context.

    Returns:
        self, to allow chained/fluent use.
    """
    if Context.root_block is not None:
        if self._id in Context.root_block.blocks:
            raise DuplicateBlockError(
                f"A block with id: {self._id} has already been rendered in the current Blocks."
            )
        overlapping_ids = set(Context.root_block.blocks).intersection(self.blocks)
        for id in overlapping_ids:
            # State components are allowed to be reused between Blocks
            if not isinstance(self.blocks[id], components.State):
                raise DuplicateBlockError(
                    "At least one block in this Blocks has already been rendered."
                )

        Context.root_block.blocks.update(self.blocks)
        Context.root_block.fns.extend(self.fns)
        # All dependency indices from this instance must be shifted by the
        # number of dependencies already registered on the root block.
        dependency_offset = len(Context.root_block.dependencies)
        for i, dependency in enumerate(self.dependencies):
            api_name = dependency["api_name"]
            if api_name is not None and api_name is not False:
                # Deduplicate api_name against the root block's endpoints.
                api_name_ = utils.append_unique_suffix(
                    api_name,
                    [dep["api_name"] for dep in Context.root_block.dependencies],
                )
                if api_name != api_name_:
                    warnings.warn(
                        f"api_name {api_name} already exists, using {api_name_}"
                    )
                    dependency["api_name"] = api_name_
            dependency["cancels"] = [
                c + dependency_offset for c in dependency["cancels"]
            ]
            if dependency.get("trigger_after") is not None:
                dependency["trigger_after"] += dependency_offset
            # Recreate the cancel function so that it has the latest
            # dependency fn indices. This is necessary to properly cancel
            # events in the backend
            if dependency["cancels"]:
                updated_cancels = [
                    Context.root_block.dependencies[i]
                    for i in dependency["cancels"]
                ]
                new_fn = BlockFunction(
                    get_cancel_function(updated_cancels)[0],
                    [],
                    [],
                    False,
                    True,
                    False,
                )
                Context.root_block.fns[dependency_offset + i] = new_fn
            Context.root_block.dependencies.append(dependency)
        Context.root_block.temp_file_sets.extend(self.temp_file_sets)
        Context.root_block.root_urls.update(self.root_urls)

    if Context.block is not None:
        Context.block.children.extend(self.children)
    return self
987
988
def is_callable(self, fn_index: int = 0) -> bool:
    """Checks if a particular Blocks function is callable (i.e. not stateful or a generator)."""
    block_fn = self.fns[fn_index]
    dependency = self.dependencies[fn_index]

    # Generator functions (sync or async) must run through the queue and
    # therefore cannot be invoked directly.
    fn = block_fn.fn
    if inspect.isasyncgenfunction(fn):
        return False
    if inspect.isgeneratorfunction(fn):
        return False
    # Stateful components (gr.State) need per-session state, which direct
    # calls do not provide — check both sides of the dependency.
    for component_id in dependency["inputs"]:
        if getattr(self.blocks[component_id], "stateful", False):
            return False
    for component_id in dependency["outputs"]:
        if getattr(self.blocks[component_id], "stateful", False):
            return False

    return True
1007
1008
def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
    """
    Allows Blocks objects to be called as functions. Supply the parameters to the
    function as positional arguments. To choose which function to call, use the
    fn_index parameter, which must be a keyword argument.

    Parameters:
        *inputs: the parameters to pass to the function
        fn_index: the index of the function to call (defaults to 0, which for Interfaces, is the default prediction function)
        api_name: The api_name of the dependency to call. Will take precedence over fn_index.
    """
    # An explicit api_name overrides fn_index: find the first dependency
    # registered under that name.
    if api_name is not None:
        inferred_fn_index = next(
            (
                i
                for i, d in enumerate(self.dependencies)
                if d.get("api_name") == api_name
            ),
            None,
        )
        if inferred_fn_index is None:
            raise InvalidApiNameError(
                f"Cannot find a function with api_name {api_name}"
            )
        fn_index = inferred_fn_index
    if not (self.is_callable(fn_index)):
        raise ValueError(
            "This function is not callable because it is either stateful or is a generator. Please use the .launch() method instead to create an interactive user interface."
        )

    inputs = list(inputs)
    # Round-trip through the API pipeline: serialize -> process -> deserialize,
    # so direct calls behave the same as frontend requests.
    processed_inputs = self.serialize_data(fn_index, inputs)
    batch = self.dependencies[fn_index]["batch"]
    if batch:
        # Batch functions expect a list of values per input component.
        processed_inputs = [[inp] for inp in processed_inputs]

    outputs = client_utils.synchronize_async(
        self.process_api,
        fn_index=fn_index,
        inputs=processed_inputs,
        request=None,
        state={},
    )
    outputs = outputs["data"]

    if batch:
        # Unwrap the singleton batch added above.
        outputs = [out[0] for out in outputs]

    processed_outputs = self.deserialize_data(fn_index, outputs)
    processed_outputs = utils.resolve_singleton(processed_outputs)

    return processed_outputs
1060
1061
async def call_function(
    self,
    fn_index: int,
    processed_input: list[Any],
    iterator: AsyncIterator[Any] | None = None,
    requests: routes.Request | list[routes.Request] | None = None,
    event_id: str | None = None,
    event_data: EventData | None = None,
):
    """
    Calls function with given index and preprocessed input, and measures process time.
    Parameters:
        fn_index: index of function to call
        processed_input: preprocessed input to pass to function
        iterator: iterator to use if function is a generator
        requests: requests to pass to function
        event_id: id of event in queue
        event_data: data associated with event trigger
    """
    block_fn = self.fns[fn_index]
    assert block_fn.fn, f"function with index {fn_index} not defined."
    is_generating = False
    request = requests[0] if isinstance(requests, list) else requests
    start = time.time()
    fn = utils.get_function_with_locals(block_fn.fn, self, event_id)

    if iterator is None:  # If not a generator function that has already run
        if block_fn.inputs_as_dict:
            # Some fns take a single {component: value} dict instead of
            # positional args.
            processed_input = [dict(zip(block_fn.inputs, processed_input))]

        # Inject special arguments (gr.Request, gr.Progress, event data)
        # into the argument list at the positions the fn declares them.
        processed_input, progress_index, _ = special_args(
            block_fn.fn, processed_input, request, event_data
        )
        progress_tracker = (
            processed_input[progress_index] if progress_index is not None else None
        )

        if progress_tracker is not None and progress_index is not None:
            progress_tracker, fn = create_tracker(
                self, event_id, fn, progress_tracker.track_tqdm
            )
            processed_input[progress_index] = progress_tracker

        if inspect.iscoroutinefunction(fn):
            prediction = await fn(*processed_input)
        else:
            # Run sync fns in a worker thread so the event loop stays free.
            prediction = await anyio.to_thread.run_sync(
                fn, *processed_input, limiter=self.limiter
            )
    else:
        prediction = None

    if inspect.isgeneratorfunction(fn) or inspect.isasyncgenfunction(fn):
        if not self.enable_queue:
            raise ValueError("Need to enable queue to use generators.")
        try:
            if iterator is None:
                iterator = cast(AsyncIterator[Any], prediction)
            if inspect.isgenerator(iterator):
                # Wrap sync generators so they can be awaited uniformly.
                iterator = utils.SyncToAsyncIterator(iterator, self.limiter)
            prediction = await utils.async_iteration(iterator)
            is_generating = True
        except StopAsyncIteration:
            # Generator exhausted: emit FINISHED_ITERATING sentinels, one
            # per output, and drop the iterator.
            n_outputs = len(self.dependencies[fn_index].get("outputs"))
            prediction = (
                components._Keywords.FINISHED_ITERATING
                if n_outputs == 1
                else (components._Keywords.FINISHED_ITERATING,) * n_outputs
            )
            iterator = None

    duration = time.time() - start

    return {
        "prediction": prediction,
        "duration": duration,
        "is_generating": is_generating,
        "iterator": iterator,
    }
1140
1141
def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]:
    """Serialize each raw input value through its input component.

    Used by __call__ so that direct invocations go through the same wire
    format as API requests. Raises InvalidBlockError when an input id is
    not rendered in this context.
    """
    dependency = self.dependencies[fn_index]
    serialized_inputs = []

    for position, component_id in enumerate(dependency["inputs"]):
        try:
            component = self.blocks[component_id]
        except KeyError as e:
            raise InvalidBlockError(
                f"Input component with id {component_id} used in {dependency['trigger']}() event is not defined in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
            ) from e
        assert isinstance(
            component, components.IOComponent
        ), f"{component.__class__} Component with id {component_id} not a valid input component."
        serialized_inputs.append(component.serialize(inputs[position]))

    return serialized_inputs
1159
1160
def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]:
    """Deserialize each wire-format output value through its output component.

    Mirror image of serialize_data: converts API payloads back into Python
    values (downloading files into the component temp dir as needed).
    Raises InvalidBlockError when an output id is not rendered here.
    """
    dependency = self.dependencies[fn_index]
    deserialized_outputs = []

    for position, component_id in enumerate(dependency["outputs"]):
        try:
            component = self.blocks[component_id]
        except KeyError as e:
            raise InvalidBlockError(
                f"Output component with id {component_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
            ) from e
        assert isinstance(
            component, components.IOComponent
        ), f"{component.__class__} Component with id {component_id} not a valid output component."
        value = component.deserialize(
            outputs[position],
            save_dir=component.DEFAULT_TEMP_DIR,
            root_url=component.root_url,
            hf_token=Context.hf_token,
        )
        deserialized_outputs.append(value)

    return deserialized_outputs
1183
1184
def validate_inputs(self, fn_index: int, inputs: list[Any]):
1185
block_fn = self.fns[fn_index]
1186
dependency = self.dependencies[fn_index]
1187
1188
dep_inputs = dependency["inputs"]
1189
1190
# This handles incorrect inputs when args are changed by a JS function
1191
# Only check not enough args case, ignore extra arguments (for now)
1192
# TODO: make this stricter?
1193
if len(inputs) < len(dep_inputs):
1194
name = (
1195
f" ({block_fn.name})"
1196
if block_fn.name and block_fn.name != "<lambda>"
1197
else ""
1198
)
1199
1200
wanted_args = []
1201
received_args = []
1202
for input_id in dep_inputs:
1203
block = self.blocks[input_id]
1204
wanted_args.append(str(block))
1205
for inp in inputs:
1206
v = f'"{inp}"' if isinstance(inp, str) else str(inp)
1207
received_args.append(v)
1208
1209
wanted = ", ".join(wanted_args)
1210
received = ", ".join(received_args)
1211
1212
# JS func didn't pass enough arguments
1213
raise ValueError(
1214
f"""An event handler{name} didn't receive enough input values (needed: {len(dep_inputs)}, got: {len(inputs)}).
1215
Check if the event handler calls a Javascript function, and make sure its return value is correct.
1216
Wanted inputs:
1217
[{wanted}]
1218
Received inputs:
1219
[{received}]"""
1220
)
1221
1222
def preprocess_data(self, fn_index: int, inputs: list[Any], state: dict[int, Any]):
    """Convert frontend payloads into the Python values the fn expects.

    Validates argument count first; then, unless preprocessing is disabled
    for this fn, runs each value through its component's preprocess().
    Stateful components are fed from the per-session ``state`` dict instead
    of the request payload.
    """
    block_fn = self.fns[fn_index]
    dependency = self.dependencies[fn_index]

    self.validate_inputs(fn_index, inputs)

    if block_fn.preprocess:
        processed_input = []
        for i, input_id in enumerate(dependency["inputs"]):
            try:
                block = self.blocks[input_id]
            except KeyError as e:
                raise InvalidBlockError(
                    f"Input component with id {input_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
                ) from e
            assert isinstance(
                block, components.Component
            ), f"{block.__class__} Component with id {input_id} not a valid input component."
            if getattr(block, "stateful", False):
                # State values come from the session, not the request.
                processed_input.append(state.get(input_id))
            else:
                processed_input.append(block.preprocess(inputs[i]))
    else:
        # preprocess=False: pass raw payloads straight through.
        processed_input = inputs
    return processed_input
1247
1248
def validate_outputs(self, fn_index: int, predictions: Any | list[Any]):
1249
block_fn = self.fns[fn_index]
1250
dependency = self.dependencies[fn_index]
1251
1252
dep_outputs = dependency["outputs"]
1253
1254
if type(predictions) is not list and type(predictions) is not tuple:
1255
predictions = [predictions]
1256
1257
if len(predictions) < len(dep_outputs):
1258
name = (
1259
f" ({block_fn.name})"
1260
if block_fn.name and block_fn.name != "<lambda>"
1261
else ""
1262
)
1263
1264
wanted_args = []
1265
received_args = []
1266
for output_id in dep_outputs:
1267
block = self.blocks[output_id]
1268
wanted_args.append(str(block))
1269
for pred in predictions:
1270
v = f'"{pred}"' if isinstance(pred, str) else str(pred)
1271
received_args.append(v)
1272
1273
wanted = ", ".join(wanted_args)
1274
received = ", ".join(received_args)
1275
1276
raise ValueError(
1277
f"""An event handler{name} didn't receive enough output values (needed: {len(dep_outputs)}, received: {len(predictions)}).
1278
Wanted outputs:
1279
[{wanted}]
1280
Received outputs:
1281
[{received}]"""
1282
)
1283
1284
def postprocess_data(
    self, fn_index: int, predictions: list | dict, state: dict[int, Any]
):
    """Convert fn return values into frontend payloads for each output.

    Handles the special cases in order: a {component: value} dict return,
    single-output (non-batch) wrapping, FINISHED_ITERATING sentinels from
    exhausted generators, stateful outputs (stored into ``state`` rather
    than returned), and gr.update() dicts.
    """
    block_fn = self.fns[fn_index]
    dependency = self.dependencies[fn_index]
    batch = dependency["batch"]

    if type(predictions) is dict and len(predictions) > 0:
        # fn returned {component: value}: reorder into the outputs list.
        predictions = convert_component_dict_to_list(
            dependency["outputs"], predictions
        )

    if len(dependency["outputs"]) == 1 and not (batch):
        # Single output fns return a bare value; normalize to a list.
        predictions = [
            predictions,
        ]

    self.validate_outputs(fn_index, predictions)  # type: ignore

    output = []
    for i, output_id in enumerate(dependency["outputs"]):
        try:
            if predictions[i] is components._Keywords.FINISHED_ITERATING:
                # Generator exhausted for this output: emit a no-op value.
                output.append(None)
                continue
        except (IndexError, KeyError) as err:
            raise ValueError(
                "Number of output components does not match number "
                f"of values returned from from function {block_fn.name}"
            ) from err

        try:
            block = self.blocks[output_id]
        except KeyError as e:
            raise InvalidBlockError(
                f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
            ) from e

        if getattr(block, "stateful", False):
            # State outputs are persisted server-side, not sent back
            # (unless the value is a gr.update() dict, which is ignored).
            if not utils.is_update(predictions[i]):
                state[output_id] = predictions[i]
            output.append(None)
        else:
            prediction_value = predictions[i]
            if utils.is_update(prediction_value):
                assert isinstance(prediction_value, dict)
                prediction_value = postprocess_update_dict(
                    block=block,
                    update_dict=prediction_value,
                    postprocess=block_fn.postprocess,
                )
            elif block_fn.postprocess:
                assert isinstance(
                    block, components.Component
                ), f"{block.__class__} Component with id {output_id} not a valid output component."
                prediction_value = block.postprocess(prediction_value)
            output.append(prediction_value)

    return output
1343
1344
def handle_streaming_outputs(
1345
self,
1346
fn_index: int,
1347
data: list,
1348
session_hash: str | None,
1349
run: int | None,
1350
) -> list:
1351
if session_hash is None or run is None:
1352
return data
1353
if run not in self.pending_streams[session_hash]:
1354
self.pending_streams[session_hash][run] = {}
1355
stream_run = self.pending_streams[session_hash][run]
1356
1357
from gradio.events import StreamableOutput
1358
1359
for i, output_id in enumerate(self.dependencies[fn_index]["outputs"]):
1360
block = self.blocks[output_id]
1361
if isinstance(block, StreamableOutput) and block.streaming:
1362
first_chunk = output_id not in stream_run
1363
binary_data, output_data = block.stream_output(
1364
data[i], f"{session_hash}/{run}/{output_id}", first_chunk
1365
)
1366
if first_chunk:
1367
stream_run[output_id] = []
1368
self.pending_streams[session_hash][run][output_id].append(binary_data)
1369
data[i] = output_data
1370
return data
1371
1372
async def process_api(
    self,
    fn_index: int,
    inputs: list[Any],
    state: dict[int, Any],
    request: routes.Request | list[routes.Request] | None = None,
    iterators: dict[int, Any] | None = None,
    session_hash: str | None = None,
    event_id: str | None = None,
    event_data: EventData | None = None,
) -> dict[str, Any]:
    """
    Processes API calls from the frontend. First preprocesses the data,
    then runs the relevant function, then postprocesses the output.
    Parameters:
        fn_index: Index of function to run.
        inputs: input data received from the frontend
        state: data stored from stateful components for session (key is input block id)
        request: the gr.Request object containing information about the network request (e.g. IP address, headers, query parameters, username)
        iterators: the in-progress iterators for each generator function (key is function index)
        session_hash: hash identifying the client session (used to key pending streams)
        event_id: id of event that triggered this API call
        event_data: data associated with the event trigger itself
    Returns: a dict with the postprocessed data plus generator/timing metadata
    """
    block_fn = self.fns[fn_index]
    batch = self.dependencies[fn_index]["batch"]

    if batch:
        # Batch mode: inputs is a list-of-lists (one list per component).
        max_batch_size = self.dependencies[fn_index]["max_batch_size"]
        batch_sizes = [len(inp) for inp in inputs]
        batch_size = batch_sizes[0]
        if inspect.isasyncgenfunction(block_fn.fn) or inspect.isgeneratorfunction(
            block_fn.fn
        ):
            raise ValueError("Gradio does not support generators in batch mode.")
        if not all(x == batch_size for x in batch_sizes):
            raise ValueError(
                f"All inputs to a batch function must have the same length but instead have sizes: {batch_sizes}."
            )
        if batch_size > max_batch_size:
            raise ValueError(
                f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
            )

        # Transpose per-component lists into per-sample rows, preprocess
        # each row, then transpose back for the fn call (and mirror the
        # same dance for postprocessing).
        inputs = [
            self.preprocess_data(fn_index, list(i), state) for i in zip(*inputs)
        ]
        result = await self.call_function(
            fn_index, list(zip(*inputs)), None, request, event_id, event_data
        )
        preds = result["prediction"]
        data = [
            self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
        ]
        data = list(zip(*data))
        is_generating, iterator = None, None
    else:
        # Non-batch: reuse an in-flight iterator if this fn is a generator
        # mid-run (in that case no new inputs are preprocessed).
        old_iterator = iterators.get(fn_index, None) if iterators else None
        if old_iterator:
            inputs = []
        else:
            inputs = self.preprocess_data(fn_index, inputs, state)
        was_generating = old_iterator is not None
        result = await self.call_function(
            fn_index, inputs, old_iterator, request, event_id, event_data
        )
        data = self.postprocess_data(fn_index, result["prediction"], state)
        is_generating, iterator = result["is_generating"], result["iterator"]
        if is_generating or was_generating:
            data = self.handle_streaming_outputs(
                fn_index,
                data,
                session_hash=session_hash,
                run=id(old_iterator) if was_generating else id(iterator),
            )

    block_fn.total_runtime += result["duration"]
    block_fn.total_runs += 1
    return {
        "data": data,
        "is_generating": is_generating,
        "iterator": iterator,
        "duration": result["duration"],
        "average_duration": block_fn.total_runtime / block_fn.total_runs,
    }
1457
1458
async def create_limiter(self):
    """(Re)create the anyio capacity limiter bounding sync-fn worker threads.

    max_threads == 40 (anyio's default thread pool size) is treated as
    "use the default limiter", signalled by setting None.
    """
    if self.max_threads == 40:
        self.limiter = None
    else:
        self.limiter = CapacityLimiter(total_tokens=self.max_threads)
1464
1465
def get_config(self):
    """A Blocks instance is laid out as a single top-level column."""
    return dict(type="column")
1467
1468
def get_config_file(self):
    """Build the full JSON-serializable config sent to the frontend.

    Includes app-level settings, the recursive layout tree, per-component
    configs (with serializer/API metadata where available), and the raw
    dependency list.
    """
    config = {
        "version": routes.VERSION,
        "mode": self.mode,
        "dev_mode": self.dev_mode,
        "analytics_enabled": self.analytics_enabled,
        "components": [],
        "css": self.css,
        "title": self.title or "Gradio",
        "space_id": self.space_id,
        "enable_queue": getattr(self, "enable_queue", False),  # launch attributes
        "show_error": getattr(self, "show_error", False),
        "show_api": self.show_api,
        "is_colab": utils.colab_check(),
        "stylesheets": self.stylesheets,
        "theme": self.theme.name,
    }

    def get_layout(block):
        # Leaf components have no children; containers recurse.
        if not isinstance(block, BlockContext):
            return {"id": block._id}
        children_layout = []
        for child in block.children:
            children_layout.append(get_layout(child))
        return {"id": block._id, "children": children_layout}

    config["layout"] = get_layout(self)

    for _id, block in self.blocks.items():
        props = block.get_config() if hasattr(block, "get_config") else {}
        block_config = {
            "id": _id,
            "type": block.get_block_name(),
            "props": utils.delete_none(props),
        }
        serializer = utils.get_serializer_name(block)
        if serializer:
            assert isinstance(block, serializing.Serializable)
            block_config["serializer"] = serializer
            block_config["api_info"] = block.api_info()  # type: ignore
            block_config["example_inputs"] = block.example_inputs()  # type: ignore
        config["components"].append(block_config)
    config["dependencies"] = self.dependencies
    return config
1512
1513
def __enter__(self):
    """Enter this Blocks as the active rendering context.

    The outermost `with gr.Blocks()` becomes the root block; nested Blocks
    remember their parent so __exit__ can restore it.
    """
    if Context.block is None:
        Context.root_block = self
    # Capture the previous context BEFORE replacing it with self.
    self.parent = Context.block
    Context.block = self
    self.exited = False
    return self
1520
1521
def __exit__(self, *args):
    """Leave the rendering context: finalize layout, config and app.

    Restores the parent context, attaches load events, rebuilds the config
    and FastAPI app, and records whether any fn uses progress tracking.
    """
    super().fill_expected_parents()
    Context.block = self.parent
    # Configure the load events before root_block is reset
    self.attach_load_events()
    if self.parent is None:
        # Outermost Blocks: rendering is finished.
        Context.root_block = None
    else:
        # Nested Blocks: hand children up to the enclosing context.
        self.parent.children.extend(self.children)
    self.config = self.get_config_file()
    self.app = routes.App.create_app(self)
    self.progress_tracking = any(block_fn.tracks_progress for block_fn in self.fns)
    self.exited = True
1534
1535
@class_or_instancemethod
def load(
    self_or_cls,  # noqa: N805
    fn: Callable | None = None,
    inputs: list[Component] | None = None,
    outputs: list[Component] | None = None,
    api_name: str | None | Literal[False] = None,
    scroll_to_output: bool = False,
    show_progress: str = "full",
    queue=None,
    batch: bool = False,
    max_batch_size: int = 4,
    preprocess: bool = True,
    postprocess: bool = True,
    every: float | None = None,
    _js: str | None = None,
    *,
    name: str | None = None,
    src: str | None = None,
    api_key: str | None = None,
    alias: str | None = None,
    **kwargs,
) -> Blocks | dict[str, Any] | None:
    """
    For reverse compatibility reasons, this is both a class method and an instance
    method, the two of which, confusingly, do two completely different things.


    Class method: loads a demo from a Hugging Face Spaces repo and creates it locally and returns a block instance. Warning: this method will be deprecated. Use the equivalent `gradio.load()` instead.


    Instance method: adds event that runs as soon as the demo loads in the browser. Example usage below.
    Parameters:
        name: Class Method - the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
        src: Class Method - the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
        api_key: Class Method - optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens. Warning: only provide this if you are loading a trusted private Space as it can be read by the Space you are loading.
        alias: Class Method - optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
        fn: Instance Method - the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
        inputs: Instance Method - List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
        outputs: Instance Method - List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
        api_name: Instance Method - Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
        scroll_to_output: Instance Method - If True, will scroll to output component on completion
        show_progress: Instance Method - If True, will show progress animation while pending
        queue: Instance Method - If True, will place the request on the queue, if the queue exists
        batch: Instance Method - If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
        max_batch_size: Instance Method - Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
        preprocess: Instance Method - If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
        postprocess: Instance Method - If False, will not run postprocessing of component data before returning 'fn' output to the browser.
        every: Instance Method - Run this event 'every' number of seconds. Interpreted in seconds. Queue must be enabled.
    Example:
        import gradio as gr
        import datetime
        with gr.Blocks() as demo:
            def get_time():
                return datetime.datetime.now().time()
            dt = gr.Textbox(label="Current time")
            demo.load(get_time, inputs=None, outputs=dt)
        demo.launch()
    """
    # class_or_instancemethod binds either the class or the instance to
    # self_or_cls; a type means the (deprecated) class-method path.
    if isinstance(self_or_cls, type):
        warn_deprecation(
            "gr.Blocks.load() will be deprecated. Use gr.load() instead."
        )
        if name is None:
            raise ValueError(
                "Blocks.load() requires passing parameters as keyword arguments"
            )
        return external.load(
            name=name, src=src, hf_token=api_key, alias=alias, **kwargs
        )
    else:
        # Instance path: register a "load" event with no target component,
        # fired when the demo opens in the browser.
        from gradio.events import Dependency

        dep, dep_index = self_or_cls.set_event_trigger(
            event_name="load",
            fn=fn,
            inputs=inputs,
            outputs=outputs,
            api_name=api_name,
            preprocess=preprocess,
            postprocess=postprocess,
            scroll_to_output=scroll_to_output,
            show_progress=show_progress,
            js=_js,
            queue=queue,
            batch=batch,
            max_batch_size=max_batch_size,
            every=every,
            no_target=True,
        )
        return Dependency(self_or_cls, dep, dep_index)
1626
1627
def clear(self):
    """Resets the layout of the Blocks object.

    Drops all rendered components, registered functions and event wiring,
    and returns self so calls can be chained.
    """
    self.children = []
    self.dependencies = []
    self.fns = []
    self.blocks = {}
    return self
1634
1635
@concurrency_count_warning
@document()
def queue(
    self,
    concurrency_count: int = 1,
    status_update_rate: float | Literal["auto"] = "auto",
    client_position_to_load_data: int | None = None,
    default_enabled: bool | None = None,
    api_open: bool = True,
    max_size: int | None = None,
):
    """
    By enabling the queue you can control the rate of processed requests, let users know their position in the queue, and set a limit on maximum number of events allowed.
    Parameters:
        concurrency_count: Number of worker threads that will be processing requests from the queue concurrently. Increasing this number will increase the rate at which requests are processed, but will also increase the memory usage of the queue.
        status_update_rate: If "auto", Queue will send status estimations to all clients whenever a job is finished. Otherwise Queue will send status at regular intervals set by this parameter as the number of seconds.
        client_position_to_load_data: DEPRECATED. This parameter is deprecated and has no effect.
        default_enabled: Deprecated and has no effect.
        api_open: If True, the REST routes of the backend will be open, allowing requests made directly to those endpoints to skip the queue.
        max_size: The maximum number of events the queue will store at any given moment. If the queue is full, new events will not be added and a user will receive a message saying that the queue is full. If None, the queue size will be unlimited.
    Example: (Blocks)
        with gr.Blocks() as demo:
            button = gr.Button(label="Generate Image")
            button.click(fn=image_generator, inputs=gr.Textbox(), outputs=gr.Image())
        demo.queue(max_size=10)
        demo.launch()
    Example: (Interface)
        demo = gr.Interface(image_generator, gr.Textbox(), gr.Image())
        demo.queue(max_size=20)
        demo.launch()
    """
    if default_enabled is not None:
        warn_deprecation(
            "The default_enabled parameter of queue has no effect and will be removed "
            "in a future version of gradio."
        )
    self.enable_queue = True
    self.api_open = api_open
    if client_position_to_load_data is not None:
        warn_deprecation(
            "The client_position_to_load_data parameter is deprecated."
        )
    if utils.is_zero_gpu_space():
        # ZeroGPU Spaces override the user settings to safe values.
        concurrency_count = self.max_threads
        max_size = 1 if max_size is None else max_size
    self._queue = queueing.Queue(
        live_updates=status_update_rate == "auto",
        concurrency_count=concurrency_count,
        update_intervals=status_update_rate if status_update_rate != "auto" else 1,
        max_size=max_size,
        blocks_dependencies=self.dependencies,
    )
    # Rebuild config/app so the queue settings are reflected immediately.
    self.config = self.get_config_file()
    self.app = routes.App.create_app(self)
    return self
1690
1691
def validate_queue_settings(self):
    """Raise ValueError for any dependency whose settings require the queue
    (progress tracking, per-event queueing, cancellation, batching) while the
    queue has not been enabled on the app."""
    if not self.enable_queue and self.progress_tracking:
        raise ValueError("Progress tracking requires queuing to be enabled.")

    for fn_index, dependency in enumerate(self.dependencies):
        if not self.enable_queue and self.queue_enabled_for_fn(fn_index):
            raise ValueError(
                f"The queue is enabled for event {dependency['api_name'] if dependency['api_name'] else fn_index} "
                "but the queue has not been enabled for the app. Please call .queue() "
                "on your app. Consult https://gradio.app/docs/#blocks-queue for information on how "
                "to configure the queue."
            )
        for cancel_index in dependency["cancels"]:
            if not self.queue_enabled_for_fn(cancel_index):
                raise ValueError(
                    "Queue needs to be enabled! "
                    "You may get this error by either 1) passing a function that uses the yield keyword "
                    "into an interface without enabling the queue or 2) defining an event that cancels "
                    "another event without enabling the queue. Both can be solved by calling .queue() "
                    "before .launch()"
                )
        queue_off = dependency["queue"] is False
        queue_default_off = dependency["queue"] is None and not self.enable_queue
        if dependency["batch"] and (queue_off or queue_default_off):
            raise ValueError("In order to use batching, the queue must be enabled.")
1717
1718
def launch(
    self,
    inline: bool | None = None,
    inbrowser: bool = False,
    share: bool | None = None,
    debug: bool = False,
    enable_queue: bool | None = None,
    max_threads: int = 40,
    auth: Callable | tuple[str, str] | list[tuple[str, str]] | None = None,
    auth_message: str | None = None,
    prevent_thread_lock: bool = False,
    show_error: bool = False,
    server_name: str | None = None,
    server_port: int | None = None,
    show_tips: bool = False,
    height: int = 500,
    width: int | str = "100%",
    encrypt: bool | None = None,
    favicon_path: str | None = None,
    ssl_keyfile: str | None = None,
    ssl_certfile: str | None = None,
    ssl_keyfile_password: str | None = None,
    ssl_verify: bool = True,
    quiet: bool = False,
    show_api: bool = True,
    file_directories: list[str] | None = None,
    allowed_paths: list[str] | None = None,
    blocked_paths: list[str] | None = None,
    root_path: str | None = None,
    _frontend: bool = True,
    app_kwargs: dict[str, Any] | None = None,
) -> tuple[FastAPI, str, str]:
    """
    Launches a simple web server that serves the demo. Can also be used to create a
    public link used by anyone to access the demo from their browser by setting share=True.

    Parameters:
        inline: whether to display in the interface inline in an iframe. Defaults to True in python notebooks; False otherwise.
        inbrowser: whether to automatically launch the interface in a new tab on the default browser.
        share: whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. If not provided, it is set to False by default every time, except when running in Google Colab. When localhost is not accessible (e.g. Google Colab), setting share=False is not supported.
        debug: if True, blocks the main thread from running. If running in Google Colab, this is needed to print the errors in the cell output.
        auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
        auth_message: If provided, HTML message provided on login page.
        prevent_thread_lock: If True, the interface will block the main thread while the server is running.
        show_error: If True, any errors in the interface will be displayed in an alert modal and printed in the browser console log
        server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT. If None, will search for an available port starting at 7860.
        server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME. If None, will use "127.0.0.1".
        show_tips: if True, will occasionally show tips about new Gradio features
        enable_queue: DEPRECATED (use .queue() method instead.) if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
        max_threads: the maximum number of total threads that the Gradio app can generate in parallel. The default is inherited from the starlette library (currently 40). Applies whether the queue is enabled or not. But if queuing is enabled, this parameter is increased to be at least the concurrency_count of the queue.
        width: The width in pixels of the iframe element containing the interface (used if inline=True)
        height: The height in pixels of the iframe element containing the interface (used if inline=True)
        encrypt: DEPRECATED. Has no effect.
        favicon_path: If a path to a file (.png, .gif, or .ico) is provided, it will be used as the favicon for the web page.
        ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
        ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
        ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
        ssl_verify: If False, skips certificate validation which allows self-signed certificates to be used.
        quiet: If True, suppresses most print statements.
        show_api: If True, shows the api docs in the footer of the app. Default True. If the queue is enabled, then api_open parameter of .queue() will determine if the api docs are shown, independent of the value of show_api.
        file_directories: This parameter has been renamed to `allowed_paths`. It will be removed in a future version.
        allowed_paths: List of complete filepaths or parent directories that gradio is allowed to serve (in addition to the directory containing the gradio python file). Must be absolute paths. Warning: if you provide directories, any files in these directories or their subdirectories are accessible to all users of your app.
        blocked_paths: List of complete filepaths or parent directories that gradio is not allowed to serve (i.e. users of your app are not allowed to access). Must be absolute paths. Warning: takes precedence over `allowed_paths` and all other directories exposed by Gradio by default.
        root_path: The root path (or "mount point") of the application, if it's not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application. For example, if the application is served at "https://example.com/myapp", the `root_path` should be set to "/myapp". Can be set by environment variable GRADIO_ROOT_PATH. Defaults to "".
        app_kwargs: Additional keyword arguments to pass to the underlying FastAPI app as a dictionary of parameter keys and argument values. For example, `{"docs_url": "/docs"}`
    Returns:
        app: FastAPI app object that is running the demo
        local_url: Locally accessible link to the demo
        share_url: Publicly accessible link to the demo (if share=True, otherwise None)
    Example: (Blocks)
        import gradio as gr
        def reverse(text):
            return text[::-1]
        with gr.Blocks() as demo:
            button = gr.Button(value="Reverse")
            button.click(reverse, gr.Textbox(), gr.Textbox())
        demo.launch(share=True, auth=("username", "password"))
    Example: (Interface)
        import gradio as gr
        def reverse(text):
            return text[::-1]
        demo = gr.Interface(reverse, "text", "text")
        demo.launch(share=True, auth=("username", "password"))
    """
    # Finalize the Blocks context if the user never exited the `with` block.
    if not self.exited:
        self.__exit__()

    self.dev_mode = False
    # Normalize a single (user, pass) tuple into a list of tuples.
    if (
        auth
        and not callable(auth)
        and not isinstance(auth[0], tuple)
        and not isinstance(auth[0], list)
    ):
        self.auth = [auth]
    else:
        self.auth = auth
    self.auth_message = auth_message
    self.show_tips = show_tips
    self.show_error = show_error
    self.height = height
    self.width = width
    self.favicon_path = favicon_path
    self.ssl_verify = ssl_verify
    if root_path is None:
        self.root_path = os.environ.get("GRADIO_ROOT_PATH", "")
    else:
        self.root_path = root_path

    if enable_queue is not None:
        self.enable_queue = enable_queue
        warn_deprecation(
            "The `enable_queue` parameter has been deprecated. "
            "Please use the `.queue()` method instead.",
        )
    if encrypt is not None:
        warn_deprecation(
            "The `encrypt` parameter has been deprecated and has no effect.",
        )

    # Queue defaults: on in HF Spaces unless explicitly disabled; off elsewhere
    # unless explicitly enabled.
    if self.space_id:
        self.enable_queue = self.enable_queue is not False
    else:
        self.enable_queue = self.enable_queue is True
    if self.enable_queue and not hasattr(self, "_queue"):
        self.queue()
    self.show_api = self.api_open if self.enable_queue else show_api

    if file_directories is not None:
        warn_deprecation(
            "The `file_directories` parameter has been renamed to `allowed_paths`. "
            "Please use that instead.",
        )
        if allowed_paths is None:
            allowed_paths = file_directories
    self.allowed_paths = allowed_paths or []
    self.blocked_paths = blocked_paths or []

    if not isinstance(self.allowed_paths, list):
        raise ValueError("`allowed_paths` must be a list of directories.")
    if not isinstance(self.blocked_paths, list):
        raise ValueError("`blocked_paths` must be a list of directories.")

    self.validate_queue_settings()

    self.config = self.get_config_file()
    # Ensure enough worker threads for the queue's concurrency_count.
    self.max_threads = max(
        self._queue.max_thread_count if self.enable_queue else 0, max_threads
    )

    if self.is_running:
        assert isinstance(
            self.local_url, str
        ), f"Invalid local_url: {self.local_url}"
        if not (quiet):
            print(
                "Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n----"
            )
    else:
        if wasm_utils.IS_WASM:
            server_name = "xxx"
            server_port = 99999
            local_url = ""
            server = None

            # In the Wasm environment, we only need the app object
            # which the frontend app will directly communicate with through the Worker API,
            # and we don't need to start a server.
            # So we just create the app object and register it here,
            # and avoid using `networking.start_server` that would start a server that don't work in the Wasm env.
            from gradio.routes import App

            app = App.create_app(self, app_kwargs=app_kwargs)
            wasm_utils.register_app(app)
        else:
            (
                server_name,
                server_port,
                local_url,
                app,
                server,
            ) = networking.start_server(
                self,
                server_name,
                server_port,
                ssl_keyfile,
                ssl_certfile,
                ssl_keyfile_password,
                app_kwargs=app_kwargs,
            )
        self.server_name = server_name
        self.local_url = local_url
        self.server_port = server_port
        self.server_app = app
        self.server = server
        self.is_running = True
        self.is_colab = utils.colab_check()
        self.is_kaggle = utils.kaggle_check()

        # NOTE(review): protocol is hard-coded to "https" in this fork
        # (upstream derives it from the ssl settings) — presumably because
        # the Colab proxy always serves https; confirm before changing.
        self.protocol = (
            "https"
        )

        if self.enable_queue:
            self._queue.set_url(self.local_url)

        if not wasm_utils.IS_WASM:
            # Cannot run async functions in background other than app's scope.
            # Workaround by triggering the app endpoint
            requests.get(f"{self.local_url}startup-events", verify=ssl_verify)
        else:
            pass
            # TODO: Call the startup endpoint in the Wasm env too.

    utils.launch_counter()
    self.is_sagemaker = utils.sagemaker_check()
    # Hosted-notebook environments can't reach localhost, so default to sharing.
    if share is None:
        if self.is_colab and self.enable_queue:
            if not quiet:
                print(
                    "Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n"
                )
            self.share = True
        elif self.is_kaggle:
            if not quiet:
                print(
                    "Kaggle notebooks require sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n"
                )
            self.share = True
        elif self.is_sagemaker:
            if not quiet:
                print(
                    "Sagemaker notebooks may require sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n"
                )
            self.share = True
        else:
            self.share = False
    else:
        self.share = share

    # If running in a colab or not able to access localhost,
    # a shareable link must be created.
    if (
        _frontend
        and not wasm_utils.IS_WASM
        and not networking.url_ok(self.local_url)
        and not self.share
    ):
        raise ValueError(
            "When localhost is not accessible, a shareable link must be created. Please set share=True or check your proxy settings to allow access to localhost."
        )

    if self.is_colab:
        if not quiet:
            if debug:
                print(strings.en["COLAB_DEBUG_TRUE"])
            else:
                print(strings.en["COLAB_DEBUG_FALSE"])
            if not self.share:
                print(strings.en["COLAB_WARNING"].format(self.server_port))
        if self.enable_queue and not self.share:
            raise ValueError(
                "When using queueing in Colab, a shareable link must be created. Please set share=True."
            )
    else:
        if not self.share:
            print(f'Running on local URL: https://{self.server_name}')

    if self.share:
        if self.space_id:
            raise RuntimeError("Share is not supported when you are in Spaces")
        if wasm_utils.IS_WASM:
            raise RuntimeError("Share is not supported in the Wasm environment")
        try:
            if self.share_url is None:
                self.share_url = networking.setup_tunnel(
                    self.server_name, self.server_port, self.share_token
                )
            print(strings.en["SHARE_LINK_DISPLAY"].format(self.share_url))
            if not (quiet):
                print('\u2714 Connected')
        except (RuntimeError, requests.exceptions.ConnectionError):
            if self.analytics_enabled:
                analytics.error_analytics("Not able to set up tunnel")
            self.share_url = None
            self.share = False
            if Path(BINARY_PATH).exists():
                print(strings.en["COULD_NOT_GET_SHARE_LINK"])
            else:
                print(
                    strings.en["COULD_NOT_GET_SHARE_LINK_MISSING_FILE"].format(
                        BINARY_PATH,
                        BINARY_URL,
                        BINARY_FILENAME,
                        BINARY_FOLDER,
                    )
                )
    else:
        if not quiet and not wasm_utils.IS_WASM:
            print('\u2714 Connected')
        self.share_url = None

    if inbrowser and not wasm_utils.IS_WASM:
        link = self.share_url if self.share and self.share_url else self.local_url
        webbrowser.open(link)

    # Check if running in a Python notebook in which case, display inline
    if inline is None:
        inline = utils.ipython_check()
    if inline:
        try:
            from IPython.display import HTML, Javascript, display  # type: ignore

            if self.share and self.share_url:
                while not networking.url_ok(self.share_url):
                    time.sleep(0.25)
                display(
                    HTML(
                        f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
                    )
                )
            elif self.is_colab:
                # modified from /usr/local/lib/python3.7/dist-packages/google/colab/output/_util.py within Colab environment
                code = """(async (port, path, width, height, cache, element) => {
                    if (!google.colab.kernel.accessAllowed && !cache) {
                        return;
                    }
                    element.appendChild(document.createTextNode(''));
                    const url = await google.colab.kernel.proxyPort(port, {cache});

                    const external_link = document.createElement('div');
                    external_link.innerHTML = `
                        <div style="font-family: monospace; margin-bottom: 0.5rem">
                            Running on <a href=${new URL(path, url).toString()} target="_blank">
                                https://localhost:${port}${path}
                            </a>
                        </div>
                    `;
                    element.appendChild(external_link);

                    const iframe = document.createElement('iframe');
                    iframe.src = new URL(path, url).toString();
                    iframe.height = height;
                    iframe.allow = "autoplay; camera; microphone; clipboard-read; clipboard-write;"
                    iframe.width = width;
                    iframe.style.border = 0;
                    element.appendChild(iframe);
                })""" + "({port}, {path}, {width}, {height}, {cache}, window.element)".format(
                    port=json.dumps(self.server_port),
                    path=json.dumps("/"),
                    width=json.dumps(self.width),
                    height=json.dumps(self.height),
                    cache=json.dumps(False),
                )

                display(Javascript(code))
            else:
                display(
                    HTML(
                        f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
                    )
                )
        except ImportError:
            pass

    if getattr(self, "analytics_enabled", False):
        data = {
            "launch_method": "browser" if inbrowser else "inline",
            "is_google_colab": self.is_colab,
            "is_sharing_on": self.share,
            "share_url": self.share_url,
            "enable_queue": self.enable_queue,
            "show_tips": self.show_tips,
            "server_name": server_name,
            "server_port": server_port,
            "is_space": self.space_id is not None,
            "mode": self.mode,
        }
        analytics.launched_analytics(self, data)

    utils.show_tip(self)

    # Block main thread if debug==True
    # FIX: the original `debug or env == 1 and not IS_WASM` parsed as
    # `debug or (env == 1 and not IS_WASM)` because `and` binds tighter than
    # `or`, so debug=True would block the single Wasm thread. Parenthesize so
    # neither trigger ever blocks in the Wasm environment.
    if (debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1) and not wasm_utils.IS_WASM:
        self.block_thread()
    # Block main thread if running in a script to stop script from exiting
    is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))

    if (
        not prevent_thread_lock
        and not is_in_interactive_mode
        # In the Wasm env, we don't have to block the main thread because the server won't be shut down after the execution finishes.
        # Moreover, we MUST NOT do it because there is only one thread in the Wasm env and blocking it will stop the subsequent code from running.
        and not wasm_utils.IS_WASM
    ):
        self.block_thread()

    return TupleNoPrint((self.server_app, self.local_url, self.share_url))
def integrate(
2118
self,
2119
comet_ml=None,
2120
wandb: ModuleType | None = None,
2121
mlflow: ModuleType | None = None,
2122
) -> None:
2123
"""
2124
A catch-all method for integrating with other libraries. This method should be run after launch()
2125
Parameters:
2126
comet_ml: If a comet_ml Experiment object is provided, will integrate with the experiment and appear on Comet dashboard
2127
wandb: If the wandb module is provided, will integrate with it and appear on WandB dashboard
2128
mlflow: If the mlflow module is provided, will integrate with the experiment and appear on ML Flow dashboard
2129
"""
2130
analytics_integration = ""
2131
if comet_ml is not None:
2132
analytics_integration = "CometML"
2133
comet_ml.log_other("Created from", "Gradio")
2134
if self.share_url is not None:
2135
comet_ml.log_text(f"gradio: {self.share_url}")
2136
comet_ml.end()
2137
elif self.local_url:
2138
comet_ml.log_text(f"gradio: {self.local_url}")
2139
comet_ml.end()
2140
else:
2141
raise ValueError("Please run `launch()` first.")
2142
if wandb is not None:
2143
analytics_integration = "WandB"
2144
if self.share_url is not None:
2145
wandb.log(
2146
{
2147
"Gradio panel": wandb.Html(
2148
'<iframe src="'
2149
+ self.share_url
2150
+ '" width="'
2151
+ str(self.width)
2152
+ '" height="'
2153
+ str(self.height)
2154
+ '" frameBorder="0"></iframe>'
2155
)
2156
}
2157
)
2158
else:
2159
print(
2160
"The WandB integration requires you to "
2161
"`launch(share=True)` first."
2162
)
2163
if mlflow is not None:
2164
analytics_integration = "MLFlow"
2165
if self.share_url is not None:
2166
mlflow.log_param("Gradio Interface Share Link", self.share_url)
2167
else:
2168
mlflow.log_param("Gradio Interface Local Link", self.local_url)
2169
if self.analytics_enabled and analytics_integration:
2170
data = {"integration": analytics_integration}
2171
analytics.integration_analytics(data)
2172
2173
def close(self, verbose: bool = True) -> None:
    """
    Closes the Interface that was launched and frees the port.
    """
    try:
        if self.enable_queue:
            self._queue.close()
        running_server = self.server
        if running_server:
            running_server.close()
        self.is_running = False
        # Reset the flag so that the startup events (starting the queue)
        # happen the next time the app is launched.
        self.app.startup_events_triggered = False
        if verbose:
            print(f"Closing server running on port: {self.server_port}")
    except (AttributeError, OSError):
        # Nothing to close: the app was never launched (missing attributes)
        # or the server is already gone.
        pass
def block_thread(self) -> None:
    """Block the calling (main) thread until the user interrupts the process."""
    try:
        # Sleep in short intervals so Ctrl+C is handled promptly.
        while True:
            time.sleep(0.1)
    except (KeyboardInterrupt, OSError):
        print("Keyboard interruption in main thread... closing server.")
        running_server = self.server
        if running_server:
            running_server.close()
        # Tear down any share tunnels that are still open.
        for active_tunnel in CURRENT_TUNNELS:
            active_tunnel.kill()
def attach_load_events(self):
    """Add a load event for every component whose initial value should be randomized."""
    root = Context.root_block
    if not root:
        return
    for block in root.blocks.values():
        is_loadable = (
            isinstance(block, components.IOComponent)
            and block.load_event_to_attach
        )
        if not is_loadable:
            continue
        load_fn, every = block.load_event_to_attach
        # Use set_event_trigger to avoid ambiguity between load class/instance method
        dependency = self.set_event_trigger(
            "load",
            load_fn,
            None,
            block,
            no_target=True,
            # If every is None, for sure skip the queue
            # else, let the enable_queue parameter take precedence
            # this will raise a nice error message is every is used
            # without queue
            queue=False if every is None else None,
            every=every,
        )[0]
        block.load_event = dependency
def startup_events(self):
    """Events that should be run when the app containing this block starts up."""
    if not self.enable_queue:
        return
    utils.run_coro_in_background(self._queue.start, self.ssl_verify)
    # Clear the stopped flag so processing can resume in case the queue
    # was previously stopped.
    self._queue.stopped = False
    utils.run_coro_in_background(self.create_limiter)
def queue_enabled_for_fn(self, fn_index: int):
    """Return whether the queue applies to the dependency at `fn_index`.

    A per-dependency `queue` of None defers to the app-wide `enable_queue`
    setting; an explicit True/False wins.
    """
    per_fn_setting = self.dependencies[fn_index]["queue"]
    return self.enable_queue if per_fn_setting is None else per_fn_setting