Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TheLastBen
GitHub Repository: TheLastBen/fast-stable-diffusion
Path: blob/main/Dreambooth/blocks.py
540 views
1
from __future__ import annotations
2
3
import copy
4
import getpass
5
import inspect
6
import json
7
import os
8
import pkgutil
9
import random
10
import sys
11
import time
12
import warnings
13
import webbrowser
14
from types import ModuleType
15
from typing import (
16
TYPE_CHECKING,
17
Any,
18
AnyStr,
19
Callable,
20
Dict,
21
Iterator,
22
List,
23
Optional,
24
Set,
25
Tuple,
26
)
27
28
import anyio
29
import requests
30
from anyio import CapacityLimiter
31
32
from gradio import (
33
components,
34
encryptor,
35
external,
36
networking,
37
queue,
38
routes,
39
strings,
40
utils,
41
)
42
from gradio.context import Context
43
from gradio.deprecation import check_deprecated_parameters
44
from gradio.documentation import (
45
document,
46
document_component_api,
47
set_documentation_group,
48
)
49
from gradio.exceptions import DuplicateBlockError, InvalidApiName
50
from gradio.utils import (
51
check_function_inputs_match,
52
component_or_layout_class,
53
delete_none,
54
get_cancel_function,
55
get_continuous_fn,
56
)
57
58
set_documentation_group("blocks")
59
60
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
61
import comet_ml
62
import mlflow
63
import wandb
64
from fastapi.applications import FastAPI
65
66
from gradio.components import Component, IOComponent
67
68
69
class Block:
    """Base class for everything placed inside a Blocks app (components and
    layouts). Handles id assignment, registration with the active Blocks
    context, visibility/style state, and wiring of event triggers."""

    def __init__(
        self,
        *,
        render: bool = True,
        elem_id: str | None = None,
        visible: bool = True,
        root_url: str | None = None,  # URL that is prepended to all file paths
        **kwargs,
    ):
        # Each block gets a unique, monotonically increasing id from the
        # global Context counter.
        self._id = Context.id
        Context.id += 1
        self.visible = visible
        self.elem_id = elem_id
        self.root_url = root_url
        self._style = {}
        if render:
            self.render()
        # Warns about (and absorbs) any deprecated keyword arguments.
        check_deprecated_parameters(self.__class__.__name__, **kwargs)

    def render(self):
        """
        Adds self into appropriate BlockContext

        Raises:
            DuplicateBlockError: if this block id is already registered in the
                current root Blocks.
        """
        if Context.root_block is not None and self._id in Context.root_block.blocks:
            raise DuplicateBlockError(
                f"A block with id: {self._id} has already been rendered in the current Blocks."
            )
        if Context.block is not None:
            Context.block.add(self)
        if Context.root_block is not None:
            Context.root_block.blocks[self._id] = self
            # Components that create temp files expose temp_dir; track those
            # dirs on the root so they can be served/cleaned up centrally.
            if hasattr(self, "temp_dir"):
                Context.root_block.temp_dirs.add(self.temp_dir)
        return self

    def unrender(self):
        """
        Removes self from BlockContext if it has been rendered (otherwise does nothing).
        Removes self from the layout and collection of blocks, but does not delete any event triggers.
        """
        if Context.block is not None:
            try:
                Context.block.children.remove(self)
            except ValueError:
                # Not a child of the current context — nothing to remove.
                pass
        if Context.root_block is not None:
            try:
                del Context.root_block.blocks[self._id]
            except KeyError:
                # Never registered with the root — nothing to remove.
                pass
        return self

    def get_block_name(self) -> str:
        """
        Gets block's class name.

        If it is template component it gets the parent's class name.

        @return: class name
        """
        return (
            self.__class__.__base__.__name__.lower()
            if hasattr(self, "is_template")
            else self.__class__.__name__.lower()
        )

    def set_event_trigger(
        self,
        event_name: str,
        fn: Callable | None,
        inputs: Component | List[Component] | Set[Component] | None,
        outputs: Component | List[Component] | None,
        preprocess: bool = True,
        postprocess: bool = True,
        scroll_to_output: bool = False,
        show_progress: bool = True,
        api_name: AnyStr | None = None,
        js: str | None = None,
        no_target: bool = False,
        queue: bool | None = None,
        batch: bool = False,
        max_batch_size: int = 4,
        cancels: List[int] | None = None,
        every: float | None = None,
    ) -> Dict[str, Any]:
        """
        Adds an event to the component's dependencies.
        Parameters:
            event_name: event name
            fn: Callable function
            inputs: input list (a set means the fn receives a single dict keyed by component)
            outputs: output list
            preprocess: whether to run the preprocess methods of components
            postprocess: whether to run the postprocess methods of components
            scroll_to_output: whether to scroll to output of dependency on trigger
            show_progress: whether to show progress animation while running.
            api_name: Defining this parameter exposes the endpoint in the api docs
            js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
            no_target: if True, sets "targets" to [], used for Blocks "load" event
            queue: whether this event should use the queue (None defers to app default)
            batch: whether this function takes in a batch of inputs
            max_batch_size: the maximum batch size to send to the function
            cancels: a list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
            every: run the function repeatedly at this interval (seconds) while the event is active
        Returns: the dependency dictionary that was appended to the root Blocks
        Raises:
            AttributeError: if called outside of a Blocks context.
            ValueError: if `every` is non-positive, or both `every` and `batch` are set.
        """
        # Support for singular parameter
        if isinstance(inputs, set):
            # A set of inputs means the fn takes a single {component: value}
            # dict; sort by id for a deterministic order.
            inputs_as_dict = True
            inputs = sorted(inputs, key=lambda x: x._id)
        else:
            inputs_as_dict = False
            if inputs is None:
                inputs = []
            elif not isinstance(inputs, list):
                inputs = [inputs]

        if isinstance(outputs, set):
            outputs = sorted(outputs, key=lambda x: x._id)
        else:
            if outputs is None:
                outputs = []
            elif not isinstance(outputs, list):
                outputs = [outputs]

        # Cancel-functions are generated with a fixed signature, so the
        # arity check only applies to user-supplied fns.
        if fn is not None and not cancels:
            check_function_inputs_match(fn, inputs, inputs_as_dict)

        if Context.root_block is None:
            raise AttributeError(
                f"{event_name}() and other events can only be called within a Blocks context."
            )
        if every is not None and every <= 0:
            raise ValueError("Parameter every must be positive or None")
        if every and batch:
            raise ValueError(
                f"Cannot run {event_name} event in a batch and every {every} seconds. "
                "Either batch is True or every is non-zero but not both."
            )

        if every:
            # Wrap fn so it re-runs on the given interval.
            fn = get_continuous_fn(fn, every)

        Context.root_block.fns.append(
            BlockFunction(fn, inputs, outputs, preprocess, postprocess, inputs_as_dict)
        )
        if api_name is not None:
            # De-duplicate api names across the whole app, warning on rename.
            api_name_ = utils.append_unique_suffix(
                api_name, [dep["api_name"] for dep in Context.root_block.dependencies]
            )
            if not (api_name == api_name_):
                warnings.warn(
                    "api_name {} already exists, using {}".format(api_name, api_name_)
                )
            api_name = api_name_

        dependency = {
            "targets": [self._id] if not no_target else [],
            "trigger": event_name,
            "inputs": [block._id for block in inputs],
            "outputs": [block._id for block in outputs],
            "backend_fn": fn is not None,
            "js": js,
            "queue": False if fn is None else queue,
            "api_name": api_name,
            "scroll_to_output": scroll_to_output,
            "show_progress": show_progress,
            "every": every,
            "batch": batch,
            "max_batch_size": max_batch_size,
            "cancels": cancels or [],
        }
        if api_name is not None:
            # Input/output schemas for the auto-generated API docs page.
            dependency["documentation"] = [
                [
                    document_component_api(component.__class__, "input")
                    for component in inputs
                ],
                [
                    document_component_api(component.__class__, "output")
                    for component in outputs
                ],
            ]
        Context.root_block.dependencies.append(dependency)
        return dependency

    def get_config(self):
        """Return the serializable props shared by every block."""
        return {
            "visible": self.visible,
            "elem_id": self.elem_id,
            "style": self._style,
            "root_url": self.root_url,
        }

    @classmethod
    def get_specific_update(cls, generic_update):
        """Convert a generic `gr.update(...)` dict into the update dict of
        this specific component class (mutates/re-dispatches via cls.update)."""
        del generic_update["__type__"]
        generic_update = cls.update(**generic_update)
        return generic_update
267
268
269
class BlockContext(Block):
    """A Block that can contain child blocks (i.e. a layout such as Row,
    Column, or Tabs). Used as a context manager: blocks instantiated inside
    the `with` body become its children."""

    def __init__(
        self,
        visible: bool = True,
        render: bool = True,
        **kwargs,
    ):
        """
        Parameters:
            visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
            render: If False, this will not be included in the Blocks config file at all.
        """
        self.children = []
        super().__init__(visible=visible, render=render, **kwargs)

    def __enter__(self):
        # Make this block the current parent for anything created inside
        # the `with` body; restore the previous parent on exit.
        self.parent = Context.block
        Context.block = self
        return self

    def add(self, child):
        """Attach a child block to this context."""
        child.parent = self
        self.children.append(child)

    def fill_expected_parents(self):
        """Wrap children whose class declares an `expected_parent` layout in a
        freshly created (unrendered) parent of that class, when this context
        is not already of the expected type. Consecutive such children share
        one pseudo-parent."""
        children = []
        pseudo_parent = None
        for child in self.children:
            expected_parent = getattr(child.__class__, "expected_parent", False)
            if not expected_parent or isinstance(self, expected_parent):
                # No wrapping needed; also terminates any pseudo-parent run.
                pseudo_parent = None
                children.append(child)
            else:
                if pseudo_parent is not None and isinstance(
                    pseudo_parent, expected_parent
                ):
                    # Continue the current pseudo-parent run.
                    pseudo_parent.children.append(child)
                else:
                    # Start a new pseudo-parent and register it with the root.
                    pseudo_parent = expected_parent(render=False)
                    children.append(pseudo_parent)
                    pseudo_parent.children = [child]
                    Context.root_block.blocks[pseudo_parent._id] = pseudo_parent
                child.parent = pseudo_parent
        self.children = children

    def __exit__(self, *args):
        if getattr(self, "allow_expected_parents", True):
            self.fill_expected_parents()
        Context.block = self.parent

    def postprocess(self, y):
        """
        Any postprocessing needed to be performed on a block context.
        """
        return y
324
325
326
class BlockFunction:
    """Record of one event-handler function plus how it should be invoked."""

    def __init__(
        self,
        fn: Optional[Callable],
        inputs: List[Component],
        outputs: List[Component],
        preprocess: bool,
        postprocess: bool,
        inputs_as_dict: bool,
    ):
        self.fn = fn
        self.inputs = inputs
        self.outputs = outputs
        self.preprocess = preprocess
        self.postprocess = postprocess
        # Cumulative timing statistics, updated after each run.
        self.total_runtime = 0
        self.total_runs = 0
        self.inputs_as_dict = inputs_as_dict

    def __str__(self):
        if self.fn is None:
            fn_name = None
        else:
            fn_name = getattr(self.fn, "__name__", "fn")
        summary = {
            "fn": fn_name,
            "preprocess": self.preprocess,
            "postprocess": self.postprocess,
        }
        return str(summary)

    def __repr__(self):
        return self.__str__()
358
359
360
class class_or_instancemethod(classmethod):
    """Descriptor that behaves like a classmethod when accessed on the class,
    but like a regular (instance-bound) method when accessed on an instance."""

    def __get__(self, instance, type_):
        if instance is None:
            # Class access: fall back to normal classmethod binding.
            return super().__get__(instance, type_)
        # Instance access: bind the raw function to the instance.
        return self.__func__.__get__(instance, type_)
364
365
366
@document()
def update(**kwargs) -> dict:
    """
    Updates component properties.
    This is a shorthand for using the update method on a component.
    For example, rather than using gr.Number.update(...) you can just use gr.update(...).
    Note that your editor's autocompletion will suggest proper parameters
    if you use the update method on the component.

    Demos: blocks_essay, blocks_update, blocks_essay_update

    Parameters:
        kwargs: Keyword arguments used to update the component's properties.
    Example:
        # Blocks Example
        import gradio as gr
        with gr.Blocks() as demo:
            radio = gr.Radio([1, 2, 4], label="Set the value of the number")
            number = gr.Number(value=2, interactive=True)
            radio.change(fn=lambda value: gr.update(value=value), inputs=radio, outputs=number)
        demo.launch()
        # Interface example
        import gradio as gr
        def change_textbox(choice):
            if choice == "short":
                return gr.Textbox.update(lines=2, visible=True)
            elif choice == "long":
                return gr.Textbox.update(lines=8, visible=True)
            else:
                return gr.Textbox.update(visible=False)
        gr.Interface(
            change_textbox,
            gr.Radio(
                ["short", "long", "none"], label="What kind of essay would you like to write?"
            ),
            gr.Textbox(lines=2),
            live=True,
        ).launch()
    """
    # Marker consumed later (e.g. by postprocess_update_dict) to recognize
    # this dict as a generic component update.
    kwargs["__type__"] = "generic_update"
    return kwargs
407
408
409
def skip() -> dict:
    """Return a no-op update, leaving the corresponding output unchanged."""
    no_op_update = update()
    return no_op_update
411
412
413
def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool = True):
    """
    Converts a dictionary of updates into a format that can be sent to the frontend.
    E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
    Into -> {"__type__": "update", "value": 2.0, "mode": "static"}

    Parameters:
        block: The Block that is being updated with this update dictionary.
        update_dict: The original update dictionary
        postprocess: Whether to postprocess the "value" key of the update dictionary.
    """
    # Let the specific component class interpret the generic update.
    specific = block.get_specific_update(update_dict)
    # NO_VALUE is a sentinel meaning "leave the value untouched".
    if specific.get("value") is components._Keywords.NO_VALUE:
        specific.pop("value")
    specific = delete_none(specific, skip_value=True)
    needs_value_postprocess = postprocess and "value" in specific
    if needs_value_postprocess:
        specific["value"] = block.postprocess(specific["value"])
    return specific
431
432
433
def convert_component_dict_to_list(outputs_ids: List[int], predictions: Dict) -> List:
    """
    Converts a dictionary of component updates into a list of updates in the order of
    the outputs_ids and including every output component.
    E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
    Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
    """
    block_keyed = [isinstance(key, Block) for key in predictions.keys()]
    if all(block_keyed):
        # Every key is a component: reorder values by output position,
        # filling unmentioned outputs with no-op updates.
        reordered = [skip() for _ in outputs_ids]
        for component, value in predictions.items():
            if component._id not in outputs_ids:
                raise ValueError(
                    f"Returned component {component} not specified as output of function."
                )
            reordered[outputs_ids.index(component._id)] = value
        predictions = utils.resolve_singleton(reordered)
    elif any(block_keyed):
        # Mixing component keys with plain keys is ambiguous — reject.
        raise ValueError(
            "Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order."
        )
    return predictions
456
457
458
@document("load")
459
class Blocks(BlockContext):
460
"""
461
Blocks is Gradio's low-level API that allows you to create more custom web
462
applications and demos than Interfaces (yet still entirely in Python).
463
464
465
Compared to the Interface class, Blocks offers more flexibility and control over:
466
(1) the layout of components (2) the events that
467
trigger the execution of functions (3) data flows (e.g. inputs can trigger outputs,
468
which can trigger the next level of outputs). Blocks also offers ways to group
469
together related demos such as with tabs.
470
471
472
The basic usage of Blocks is as follows: create a Blocks object, then use it as a
473
context (with the "with" statement), and then define layouts, components, or events
474
within the Blocks context. Finally, call the launch() method to launch the demo.
475
476
Example:
477
import gradio as gr
478
def update(name):
479
return f"Welcome to Gradio, {name}!"
480
481
with gr.Blocks() as demo:
482
gr.Markdown("Start typing below and then click **Run** to see the output.")
483
with gr.Row():
484
inp = gr.Textbox(placeholder="What is your name?")
485
out = gr.Textbox()
486
btn = gr.Button("Run")
487
btn.click(fn=update, inputs=inp, outputs=out)
488
489
demo.launch()
490
Demos: blocks_hello, blocks_flipper, blocks_speech_text_sentiment, generate_english_german, sound_alert
491
Guides: blocks_and_event_listeners, controlling_layout, state_in_blocks, custom_CSS_and_JS, custom_interpretations_with_blocks, using_blocks_like_functions
492
"""
493
494
    def __init__(
        self,
        theme: str = "default",
        analytics_enabled: Optional[bool] = None,
        mode: str = "blocks",
        title: str = "Gradio",
        css: Optional[str] = None,
        **kwargs,
    ):
        """
        Parameters:
            theme: which theme to use - right now, only "default" is supported.
            analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
            mode: a human-friendly name for the kind of Blocks interface being created.
            title: The tab title to display when this is opened in a browser window.
            css: custom css or path to custom css file to apply to entire Blocks
        """
        # Cleanup shared parameters with Interface #TODO: is this part still necessary after Interface with Blocks?
        self.limiter = None
        self.save_to = None
        self.theme = theme
        self.requires_permissions = False  # TODO: needs to be implemented
        self.encrypt = False
        self.share = False
        self.enable_queue = None
        self.max_threads = 40
        self.show_error = True
        # `css` may be either a literal stylesheet or a path to one.
        if css is not None and os.path.exists(css):
            with open(css) as css_file:
                self.css = css_file.read()
        else:
            self.css = css

        # For analytics_enabled and allow_flagging: (1) first check for
        # parameter, (2) check for env variable, (3) default to True/"manual"
        self.analytics_enabled = (
            analytics_enabled
            if analytics_enabled is not None
            else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
        )

        super().__init__(render=False, **kwargs)
        # Registry of every block in the app, keyed by block id.
        self.blocks: Dict[int, Block] = {}
        # Event-handler functions, parallel to self.dependencies.
        self.fns: List[BlockFunction] = []
        self.dependencies = []
        self.mode = mode

        # Launch state; populated by launch().
        self.is_running = False
        self.local_url = None
        self.share_url = None
        self.width = None
        self.height = None
        self.api_open = True

        self.ip_address = None
        self.is_space = True if os.getenv("SYSTEM") == "spaces" else False
        self.favicon_path = None
        self.auth = None
        self.dev_mode = True
        # Random id so the frontend can detect server restarts.
        self.app_id = random.getrandbits(64)
        self.temp_dirs = set()
        self.title = title
        self.show_api = True

        if self.analytics_enabled:
            self.ip_address = utils.get_local_ip_address()
            data = {
                "mode": self.mode,
                "ip_address": self.ip_address,
                "custom_css": self.css is not None,
                "theme": self.theme,
                "version": pkgutil.get_data(__name__, "version.txt")
                .decode("ascii")
                .strip(),
            }
            utils.initiated_analytics(data)
570
571
@classmethod
572
def from_config(
573
cls, config: dict, fns: List[Callable], root_url: str | None = None
574
) -> Blocks:
575
"""
576
Factory method that creates a Blocks from a config and list of functions.
577
578
Parameters:
579
config: a dictionary containing the configuration of the Blocks.
580
fns: a list of functions that are used in the Blocks. Must be in the same order as the dependencies in the config.
581
root_url: an optional root url to use for the components in the Blocks. Allows serving files from an external URL.
582
"""
583
config = copy.deepcopy(config)
584
components_config = config["components"]
585
original_mapping: Dict[int, Block] = {}
586
587
def get_block_instance(id: int) -> Block:
588
for block_config in components_config:
589
if block_config["id"] == id:
590
break
591
else:
592
raise ValueError("Cannot find block with id {}".format(id))
593
cls = component_or_layout_class(block_config["type"])
594
block_config["props"].pop("type", None)
595
block_config["props"].pop("name", None)
596
style = block_config["props"].pop("style", None)
597
if block_config["props"].get("root_url") is None and root_url:
598
block_config["props"]["root_url"] = root_url + "/"
599
block = cls(**block_config["props"])
600
if style:
601
block.style(**style)
602
return block
603
604
def iterate_over_children(children_list):
605
for child_config in children_list:
606
id = child_config["id"]
607
block = get_block_instance(id)
608
original_mapping[id] = block
609
610
children = child_config.get("children")
611
if children is not None:
612
with block:
613
iterate_over_children(children)
614
615
with Blocks(theme=config["theme"], css=config["theme"]) as blocks:
616
# ID 0 should be the root Blocks component
617
original_mapping[0] = Context.root_block or blocks
618
619
iterate_over_children(config["layout"]["children"])
620
621
first_dependency = None
622
623
# add the event triggers
624
for dependency, fn in zip(config["dependencies"], fns):
625
targets = dependency.pop("targets")
626
trigger = dependency.pop("trigger")
627
dependency.pop("backend_fn")
628
dependency.pop("documentation", None)
629
dependency["inputs"] = [
630
original_mapping[i] for i in dependency["inputs"]
631
]
632
dependency["outputs"] = [
633
original_mapping[o] for o in dependency["outputs"]
634
]
635
dependency.pop("status_tracker", None)
636
dependency["preprocess"] = False
637
dependency["postprocess"] = False
638
639
for target in targets:
640
dependency = original_mapping[target].set_event_trigger(
641
event_name=trigger, fn=fn, **dependency
642
)
643
if first_dependency is None:
644
first_dependency = dependency
645
646
# Allows some use of Interface-specific methods with loaded Spaces
647
blocks.predict = [fns[0]]
648
blocks.input_components = [
649
Context.root_block.blocks[i] for i in first_dependency["inputs"]
650
]
651
blocks.output_components = [
652
Context.root_block.blocks[o] for o in first_dependency["outputs"]
653
]
654
655
if config.get("mode", "blocks") == "interface":
656
blocks.__name__ = "Interface"
657
blocks.mode = "interface"
658
blocks.api_mode = True
659
660
return blocks
661
662
def __str__(self):
663
return self.__repr__()
664
665
def __repr__(self):
666
num_backend_fns = len([d for d in self.dependencies if d["backend_fn"]])
667
repr = f"Gradio Blocks instance: {num_backend_fns} backend functions"
668
repr += "\n" + "-" * len(repr)
669
for d, dependency in enumerate(self.dependencies):
670
if dependency["backend_fn"]:
671
repr += f"\nfn_index={d}"
672
repr += "\n inputs:"
673
for input_id in dependency["inputs"]:
674
block = self.blocks[input_id]
675
repr += "\n |-{}".format(str(block))
676
repr += "\n outputs:"
677
for output_id in dependency["outputs"]:
678
block = self.blocks[output_id]
679
repr += "\n |-{}".format(str(block))
680
return repr
681
682
    def render(self):
        """Merge this Blocks into the currently-active root Blocks: copies
        over blocks, fns, and dependencies (re-indexing cancel references and
        de-duplicating api names), then attaches children to the current
        layout context.

        Raises:
            DuplicateBlockError: if this Blocks, or any block inside it, has
                already been rendered in the current root Blocks.
        """
        if Context.root_block is not None:
            if self._id in Context.root_block.blocks:
                raise DuplicateBlockError(
                    f"A block with id: {self._id} has already been rendered in the current Blocks."
                )
            if not set(Context.root_block.blocks).isdisjoint(self.blocks):
                raise DuplicateBlockError(
                    "At least one block in this Blocks has already been rendered."
                )

            Context.root_block.blocks.update(self.blocks)
            Context.root_block.fns.extend(self.fns)
            # Our dependency indices shift by however many the root already has.
            dependency_offset = len(Context.root_block.dependencies)
            for i, dependency in enumerate(self.dependencies):
                api_name = dependency["api_name"]
                if api_name is not None:
                    api_name_ = utils.append_unique_suffix(
                        api_name,
                        [dep["api_name"] for dep in Context.root_block.dependencies],
                    )
                    if not (api_name == api_name_):
                        warnings.warn(
                            "api_name {} already exists, using {}".format(
                                api_name, api_name_
                            )
                        )
                    dependency["api_name"] = api_name_
                # Re-index cancel targets into the root's dependency list.
                dependency["cancels"] = [
                    c + dependency_offset for c in dependency["cancels"]
                ]
                # Recreate the cancel function so that it has the latest
                # dependency fn indices. This is necessary to properly cancel
                # events in the backend
                if dependency["cancels"]:
                    updated_cancels = [
                        Context.root_block.dependencies[i]
                        for i in dependency["cancels"]
                    ]
                    new_fn = BlockFunction(
                        get_cancel_function(updated_cancels)[0],
                        [],
                        [],
                        False,
                        True,
                        False,
                    )
                    Context.root_block.fns[dependency_offset + i] = new_fn
                Context.root_block.dependencies.append(dependency)
            Context.root_block.temp_dirs = Context.root_block.temp_dirs | self.temp_dirs

        if Context.block is not None:
            Context.block.children.extend(self.children)
        return self
736
737
def is_callable(self, fn_index: int = 0) -> bool:
738
"""Checks if a particular Blocks function is callable (i.e. not stateful or a generator)."""
739
block_fn = self.fns[fn_index]
740
dependency = self.dependencies[fn_index]
741
742
if inspect.isasyncgenfunction(block_fn.fn):
743
return False
744
if inspect.isgeneratorfunction(block_fn.fn):
745
raise False
746
for input_id in dependency["inputs"]:
747
block = self.blocks[input_id]
748
if getattr(block, "stateful", False):
749
return False
750
for output_id in dependency["outputs"]:
751
block = self.blocks[output_id]
752
if getattr(block, "stateful", False):
753
return False
754
755
return True
756
757
    def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
        """
        Allows Blocks objects to be called as functions. Supply the parameters to the
        function as positional arguments. To choose which function to call, use the
        fn_index parameter, which must be a keyword argument.

        Parameters:
            *inputs: the parameters to pass to the function
            fn_index: the index of the function to call (defaults to 0, which for Interfaces, is the default prediction function)
            api_name: The api_name of the dependency to call. Will take precedence over fn_index.
        Raises:
            InvalidApiName: if api_name is given but matches no dependency.
            ValueError: if the selected function is stateful or a generator.
        """
        if api_name is not None:
            # Resolve api_name to a dependency index (overrides fn_index).
            fn_index = next(
                (
                    i
                    for i, d in enumerate(self.dependencies)
                    if d.get("api_name") == api_name
                ),
                None,
            )
            if fn_index is None:
                raise InvalidApiName(f"Cannot find a function with api_name {api_name}")
        if not (self.is_callable(fn_index)):
            raise ValueError(
                "This function is not callable because it is either stateful or is a generator. Please use the .launch() method instead to create an interactive user interface."
            )

        inputs = list(inputs)
        processed_inputs = self.serialize_data(fn_index, inputs)
        batch = self.dependencies[fn_index]["batch"]
        if batch:
            # Batched fns expect lists of samples; wrap each arg as a batch of 1.
            processed_inputs = [[inp] for inp in processed_inputs]

        outputs = utils.synchronize_async(self.process_api, fn_index, processed_inputs)
        outputs = outputs["data"]

        if batch:
            # Unwrap the single-sample batch.
            outputs = [out[0] for out in outputs]

        processed_outputs = self.deserialize_data(fn_index, outputs)
        processed_outputs = utils.resolve_singleton(processed_outputs)

        return processed_outputs
800
801
    async def call_function(
        self,
        fn_index: int,
        processed_input: List[Any],
        iterator: Iterator[Any] | None = None,
    ):
        """Calls and times function with given index and preprocessed input.

        Parameters:
            fn_index: index of the function in self.fns / self.dependencies.
            processed_input: already-preprocessed positional arguments.
            iterator: in-progress iterator from a previous call, for
                generator functions; None to start a fresh call.
        Returns:
            dict with keys "prediction", "duration", "is_generating",
            and "iterator".
        """
        block_fn = self.fns[fn_index]
        is_generating = False
        start = time.time()

        if block_fn.inputs_as_dict:
            # The fn was registered with a set of inputs: pass a single
            # {component: value} dict instead of positional args.
            processed_input = [
                {
                    input_component: data
                    for input_component, data in zip(block_fn.inputs, processed_input)
                }
            ]

        if iterator is None:  # If not a generator function that has already run
            if inspect.iscoroutinefunction(block_fn.fn):
                prediction = await block_fn.fn(*processed_input)
            else:
                # Run sync fns in a worker thread, bounded by self.limiter.
                prediction = await anyio.to_thread.run_sync(
                    block_fn.fn, *processed_input, limiter=self.limiter
                )

        if inspect.isasyncgenfunction(block_fn.fn):
            raise ValueError("Gradio does not support async generators.")
        if inspect.isgeneratorfunction(block_fn.fn):
            if not self.enable_queue:
                raise ValueError("Need to enable queue to use generators.")
            try:
                if iterator is None:
                    # First call: the "prediction" is actually the generator.
                    iterator = prediction
                # Advance the generator by one step in a worker thread.
                prediction = await anyio.to_thread.run_sync(
                    utils.async_iteration, iterator, limiter=self.limiter
                )
                is_generating = True
            except StopAsyncIteration:
                # Generator exhausted: emit FINISHED_ITERATING sentinels,
                # one per output, and drop the iterator.
                n_outputs = len(self.dependencies[fn_index].get("outputs"))
                prediction = (
                    components._Keywords.FINISHED_ITERATING
                    if n_outputs == 1
                    else (components._Keywords.FINISHED_ITERATING,) * n_outputs
                )
                iterator = None

        duration = time.time() - start

        return {
            "prediction": prediction,
            "duration": duration,
            "is_generating": is_generating,
            "iterator": iterator,
        }
857
858
def serialize_data(self, fn_index: int, inputs: List[Any]) -> List[Any]:
859
dependency = self.dependencies[fn_index]
860
processed_input = []
861
862
for i, input_id in enumerate(dependency["inputs"]):
863
block: IOComponent = self.blocks[input_id]
864
serialized_input = block.serialize(inputs[i])
865
processed_input.append(serialized_input)
866
867
return processed_input
868
869
def deserialize_data(self, fn_index: int, outputs: List[Any]) -> List[Any]:
870
dependency = self.dependencies[fn_index]
871
predictions = []
872
873
for o, output_id in enumerate(dependency["outputs"]):
874
block: IOComponent = self.blocks[output_id]
875
deserialized = block.deserialize(outputs[o])
876
predictions.append(deserialized)
877
878
return predictions
879
880
def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
881
block_fn = self.fns[fn_index]
882
dependency = self.dependencies[fn_index]
883
884
if block_fn.preprocess:
885
processed_input = []
886
for i, input_id in enumerate(dependency["inputs"]):
887
block: IOComponent = self.blocks[input_id]
888
if getattr(block, "stateful", False):
889
processed_input.append(state.get(input_id))
890
else:
891
processed_input.append(block.preprocess(inputs[i]))
892
else:
893
processed_input = inputs
894
return processed_input
895
896
    def postprocess_data(
        self, fn_index: int, predictions: List[Any], state: Dict[int, Any]
    ):
        """Convert raw fn results into frontend-ready output values.

        Handles: dict-of-components returns, single-output normalization,
        stateful outputs (stored into `state`, None sent to frontend),
        generic update dicts, and per-component postprocess().
        """
        block_fn = self.fns[fn_index]
        dependency = self.dependencies[fn_index]
        batch = dependency["batch"]

        if type(predictions) is dict and len(predictions) > 0:
            # Fn returned {component: value}; reorder into output order.
            predictions = convert_component_dict_to_list(
                dependency["outputs"], predictions
            )

        if len(dependency["outputs"]) == 1 and not (batch):
            # Normalize a single return value into a 1-tuple.
            predictions = (predictions,)

        output = []
        for i, output_id in enumerate(dependency["outputs"]):
            if predictions[i] is components._Keywords.FINISHED_ITERATING:
                # Generator exhausted for this output: send nothing.
                output.append(None)
                continue
            block = self.blocks[output_id]
            if getattr(block, "stateful", False):
                # Stateful outputs update session state, not the frontend.
                if not utils.is_update(predictions[i]):
                    state[output_id] = predictions[i]
                output.append(None)
            else:
                prediction_value = predictions[i]
                if utils.is_update(prediction_value):
                    prediction_value = postprocess_update_dict(
                        block=block,
                        update_dict=prediction_value,
                        postprocess=block_fn.postprocess,
                    )
                elif block_fn.postprocess:
                    prediction_value = block.postprocess(prediction_value)
                output.append(prediction_value)
        return output
933
934
    async def process_api(
        self,
        fn_index: int,
        inputs: List[Any],
        username: str | None = None,
        state: Dict[int, Any] | List[Dict[int, Any]] | None = None,
        iterators: Dict[int, Any] | None = None,
    ) -> Dict[str, Any]:
        """
        Processes API calls from the frontend. First preprocesses the data,
        then runs the relevant function, then postprocesses the output.
        Parameters:
            fn_index: Index of function to run.
            inputs: input data received from the frontend
            username: name of user if authentication is set up (not used)
            state: data stored from stateful components for session (key is input block id)
            iterators: the in-progress iterators for each generator function (key is function index)
        Returns:
            dict with the postprocessed "data" plus generator/timing metadata.
        """
        block_fn = self.fns[fn_index]
        batch = self.dependencies[fn_index]["batch"]

        if batch:
            max_batch_size = self.dependencies[fn_index]["max_batch_size"]
            batch_sizes = [len(inp) for inp in inputs]
            batch_size = batch_sizes[0]
            if inspect.isasyncgenfunction(block_fn.fn) or inspect.isgeneratorfunction(
                block_fn.fn
            ):
                raise ValueError("Gradio does not support generators in batch mode.")
            if not all(x == batch_size for x in batch_sizes):
                raise ValueError(
                    f"All inputs to a batch function must have the same length but instead have sizes: {batch_sizes}."
                )
            if batch_size > max_batch_size:
                raise ValueError(
                    f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
                )

            # inputs arrives as one list per component; zip(*) regroups it
            # into one list per sample for preprocessing, and the second
            # zip(*) transposes back into per-component lists for the fn.
            inputs = [self.preprocess_data(fn_index, i, state) for i in zip(*inputs)]
            result = await self.call_function(fn_index, zip(*inputs), None)
            preds = result["prediction"]
            data = [self.postprocess_data(fn_index, o, state) for o in zip(*preds)]
            data = list(zip(*data))
            is_generating, iterator = None, None
        else:
            inputs = self.preprocess_data(fn_index, inputs, state)
            # Resume an in-progress generator for this fn, if any.
            iterator = iterators.get(fn_index, None) if iterators else None
            result = await self.call_function(fn_index, inputs, iterator)
            data = self.postprocess_data(fn_index, result["prediction"], state)
            is_generating, iterator = result["is_generating"], result["iterator"]

        # Accumulate timing stats for the average_duration metric.
        block_fn.total_runtime += result["duration"]
        block_fn.total_runs += 1

        return {
            "data": data,
            "is_generating": is_generating,
            "iterator": iterator,
            "duration": result["duration"],
            "average_duration": block_fn.total_runtime / block_fn.total_runs,
        }
996
997
async def create_limiter(self):
    """(Re)build the capacity limiter bounding concurrent threaded work.

    When ``max_threads`` is still at 40 (starlette's own default), no
    explicit limiter is created and anyio's default limiter applies.
    """
    if self.max_threads == 40:
        self.limiter = None
    else:
        self.limiter = CapacityLimiter(total_tokens=self.max_threads)
def get_config(self):
    """Return the minimal component config for a Blocks, which the frontend
    renders as a column container."""
    return {"type": "column"}
def get_config_file(self):
    """Serialize this Blocks app into the JSON-able config the frontend
    consumes: top-level settings, per-component props, the nested layout
    tree, and the event dependencies."""
    config = {
        "version": routes.VERSION,
        "mode": self.mode,
        "dev_mode": self.dev_mode,
        "components": [],
        "theme": self.theme,
        "css": self.css,
        "title": self.title or "Gradio",
        "is_space": self.is_space,
        "enable_queue": getattr(self, "enable_queue", False),  # launch attributes
        "show_error": getattr(self, "show_error", False),
        "show_api": self.show_api,
        "is_colab": utils.colab_check(),
    }

    def get_layout(block):
        # Leaf blocks serialize to just their id; containers recurse
        # over their children.
        if not isinstance(block, BlockContext):
            return {"id": block._id}
        return {
            "id": block._id,
            "children": [get_layout(child) for child in block.children],
        }

    config["layout"] = get_layout(self)
    config["components"] = [
        {
            "id": _id,
            "type": block.get_block_name(),
            "props": utils.delete_none(block.get_config())
            if hasattr(block, "get_config")
            else {},
        }
        for _id, block in self.blocks.items()
    ]
    config["dependencies"] = self.dependencies
    return config
def __enter__(self):
    """Enter the Blocks context: record the enclosing block and make this
    one the active target for newly-rendered components."""
    outermost = Context.block is None
    if outermost:
        # First Blocks entered becomes the root of the whole app.
        Context.root_block = self
    self.parent = Context.block
    Context.block = self
    return self
def __exit__(self, *args):
    """Exit the Blocks context: fold this block's children into its parent,
    or — when this is the outermost Blocks — finalize the config and app."""
    super().fill_expected_parents()
    # Restore the block that was active before __enter__.
    Context.block = self.parent
    # Configure the load events before root_block is reset
    self.attach_load_events()
    if self.parent is None:
        # Outermost Blocks is closing: clear the root pointer.
        Context.root_block = None
    else:
        # Nested Blocks: hand all collected children to the enclosing block.
        self.parent.children.extend(self.children)
    # Config and FastAPI app are (re)built even for nested Blocks.
    self.config = self.get_config_file()
    self.app = routes.App.create_app(self)
@class_or_instancemethod
def load(
    self_or_cls,
    fn: Optional[Callable] = None,
    inputs: Optional[List[Component]] = None,
    outputs: Optional[List[Component]] = None,
    *,
    name: Optional[str] = None,
    src: Optional[str] = None,
    api_key: Optional[str] = None,
    alias: Optional[str] = None,
    _js: Optional[str] = None,
    every: None | int = None,
    **kwargs,
) -> Blocks | Dict[str, Any] | None:
    """
    For reverse compatibility reasons, this is both a class method and an instance
    method, the two of which, confusingly, do two completely different things.


    Class method: loads a demo from a Hugging Face Spaces repo and creates it locally and returns a block instance. Equivalent to gradio.Interface.load()


    Instance method: adds event that runs as soon as the demo loads in the browser. Example usage below.
    Parameters:
        name: Class Method - the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
        src: Class Method - the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
        api_key: Class Method - optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
        alias: Class Method - optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
        fn: Instance Method - Callable function
        inputs: Instance Method - input list
        outputs: Instance Method - output list
        every: Instance Method - Run this event 'every' number of seconds. Interpreted in seconds. Queue must be enabled.
    Example:
        import gradio as gr
        import datetime
        with gr.Blocks() as demo:
            def get_time():
                return datetime.datetime.now().time()
            dt = gr.Textbox(label="Current time")
            demo.load(get_time, inputs=None, outputs=dt)
        demo.launch()
    """
    # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
    called_on_instance = not isinstance(self_or_cls, type)
    if called_on_instance:
        # Instance method: register an event that fires when the page loads.
        return self_or_cls.set_event_trigger(
            event_name="load",
            fn=fn,
            inputs=inputs,
            outputs=outputs,
            js=_js,
            no_target=True,
            every=every,
        )
    # Class method: fetch a model/Space and build a local Blocks from it.
    if name is None:
        raise ValueError(
            "Blocks.load() requires passing parameters as keyword arguments"
        )
    return external.load_blocks_from_repo(name, src, api_key, alias, **kwargs)
def clear(self):
    """Resets the layout of the Blocks object."""
    # Drop all registered components, functions, and event wiring.
    self.blocks = {}
    self.fns = []
    self.dependencies = []
    self.children = []
    # Returned for fluent chaining (e.g. demo.clear().launch()).
    return self
@document()
def queue(
    self,
    concurrency_count: int = 1,
    status_update_rate: float | str = "auto",
    client_position_to_load_data: int = 30,
    default_enabled: bool = True,
    api_open: bool = True,
    max_size: Optional[int] = None,
):
    """
    You can control the rate of processed requests by creating a queue. This will allow you to set the number of requests to be processed at one time, and will let users know their position in the queue.
    Parameters:
        concurrency_count: Number of worker threads that will be processing requests concurrently.
        status_update_rate: If "auto", Queue will send status estimations to all clients whenever a job is finished. Otherwise Queue will send status at regular intervals set by this parameter as the number of seconds.
        client_position_to_load_data: Once a client's position in Queue is less that this value, the Queue will collect the input data from the client. You may make this smaller if clients can send large volumes of data, such as video, since the queued data is stored in memory.
        default_enabled: If True, all event listeners will use queueing by default.
        api_open: If True, the REST routes of the backend will be open, allowing requests made directly to those endpoints to skip the queue.
        max_size: The maximum number of events the queue will store at any given moment.
    Example:
        demo = gr.Interface(gr.Textbox(), gr.Image(), image_generator)
        demo.queue(concurrency_count=3)
        demo.launch()
    """
    self.enable_queue = default_enabled
    self.api_open = api_open
    # "auto" => push live status updates whenever a job finishes;
    # otherwise status is sent on a fixed interval (in seconds).
    self._queue = queue.Queue(
        live_updates=status_update_rate == "auto",
        concurrency_count=concurrency_count,
        data_gathering_start=client_position_to_load_data,
        update_intervals=status_update_rate if status_update_rate != "auto" else 1,
        max_size=max_size,
    )
    # Rebuild the serialized config so the frontend picks up queue settings.
    self.config = self.get_config_file()
    # Returned for fluent chaining (demo.queue(...).launch()).
    return self
def launch(
    self,
    inline: Optional[bool] = None,
    inbrowser: bool = False,
    share: Optional[bool] = None,
    debug: bool = False,
    enable_queue: Optional[bool] = None,
    max_threads: int = 40,
    auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
    auth_message: Optional[str] = None,
    prevent_thread_lock: bool = False,
    show_error: bool = False,
    server_name: Optional[str] = None,
    server_port: Optional[int] = None,
    show_tips: bool = False,
    height: int = 500,
    width: int | str = "100%",
    encrypt: bool = False,
    favicon_path: Optional[str] = None,
    ssl_keyfile: Optional[str] = None,
    ssl_certfile: Optional[str] = None,
    ssl_keyfile_password: Optional[str] = None,
    quiet: bool = False,
    show_api: bool = True,
    _frontend: bool = True,
) -> Tuple[FastAPI, str, str]:
    """
    Launches a simple web server that serves the demo. Can also be used to create a
    public link used by anyone to access the demo from their browser by setting share=True.

    Parameters:
        inline: whether to display in the interface inline in an iframe. Defaults to True in python notebooks; False otherwise.
        inbrowser: whether to automatically launch the interface in a new tab on the default browser.
        share: whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. If not provided, it is set to False by default every time, except when running in Google Colab. When localhost is not accessible (e.g. Google Colab), setting share=False is not supported.
        debug: if True, blocks the main thread from running. If running in Google Colab, this is needed to print the errors in the cell output.
        auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
        auth_message: If provided, HTML message provided on login page.
        prevent_thread_lock: If True, the interface will NOT block the main thread while the server is running.
        show_error: If True, any errors in the interface will be displayed in an alert modal and printed in the browser console log
        server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT. If None, will search for an available port starting at 7860.
        server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME. If None, will use "127.0.0.1".
        show_tips: if True, will occasionally show tips about new Gradio features
        enable_queue: DEPRECATED (use .queue() method instead.) if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
        max_threads: allow up to `max_threads` to be processed in parallel. The default is inherited from the starlette library (currently 40).
        width: The width in pixels of the iframe element containing the interface (used if inline=True)
        height: The height in pixels of the iframe element containing the interface (used if inline=True)
        encrypt: If True, flagged data will be encrypted by key provided by creator at launch
        favicon_path: If a path to a file (.png, .gif, or .ico) is provided, it will be used as the favicon for the web page.
        ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
        ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
        ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
        quiet: If True, suppresses most print statements.
        show_api: If True, shows the api docs in the footer of the app. Default True. If the queue is enabled, then api_open parameter of .queue() will determine if the api docs are shown, independent of the value of show_api.
    Returns:
        app: FastAPI app object that is running the demo
        local_url: Locally accessible link to the demo
        share_url: Publicly accessible link to the demo (if share=True, otherwise None)
    Example:
        import gradio as gr
        def reverse(text):
            return text[::-1]
        demo = gr.Interface(reverse, "text", "text")
        demo.launch(share=True, auth=("username", "password"))
    """
    # A real launch always leaves dev mode.
    self.dev_mode = False
    # Normalize a single (username, password) pair into a one-element list
    # so downstream code can always iterate over credential tuples.
    if (
        auth
        and not callable(auth)
        and not isinstance(auth[0], tuple)
        and not isinstance(auth[0], list)
    ):
        auth = [auth]
    self.auth = auth
    self.auth_message = auth_message
    self.show_tips = show_tips
    self.show_error = show_error
    self.height = height
    self.width = width
    self.favicon_path = favicon_path
    if enable_queue is not None:
        self.enable_queue = enable_queue
        warnings.warn(
            "The `enable_queue` parameter has been deprecated. Please use the `.queue()` method instead.",
            DeprecationWarning,
        )

    # Coerce enable_queue to a strict bool: Spaces default to queueing on
    # (anything but an explicit False), elsewhere it defaults to off.
    if self.is_space:
        self.enable_queue = self.enable_queue is not False
    else:
        self.enable_queue = self.enable_queue is True
    if self.enable_queue and not hasattr(self, "_queue"):
        # Queue requested but never configured: build one with defaults.
        self.queue()
    self.show_api = self.api_open if self.enable_queue else show_api

    # Cancellation and batching both require the queue; fail fast otherwise.
    for dep in self.dependencies:
        for i in dep["cancels"]:
            if not self.queue_enabled_for_fn(i):
                raise ValueError(
                    "In order to cancel an event, the queue for that event must be enabled! "
                    "You may get this error by either 1) passing a function that uses the yield keyword "
                    "into an interface without enabling the queue or 2) defining an event that cancels "
                    "another event without enabling the queue. Both can be solved by calling .queue() "
                    "before .launch()"
                )
        if dep["batch"] and (
            dep["queue"] is False
            or (dep["queue"] is None and not self.enable_queue)
        ):
            raise ValueError("In order to use batching, the queue must be enabled.")

    self.config = self.get_config_file()
    self.encrypt = encrypt
    # Thread budget must cover both the queue's workers and direct requests.
    self.max_threads = max(
        self._queue.max_thread_count if self.enable_queue else 0, max_threads
    )
    if self.encrypt:
        self.encryption_key = encryptor.get_key(
            getpass.getpass("Enter key for encryption: ")
        )

    if self.is_running:
        # Server already up (e.g. re-launch in a notebook): just swap the app.
        self.server_app.launchable = self
        if not (quiet):
            print(
                "Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n----"
            )
    else:
        server_name, server_port, local_url, app, server = networking.start_server(
            self,
            server_name,
            server_port,
            ssl_keyfile,
            ssl_certfile,
            ssl_keyfile_password,
        )
        self.server_name = server_name
        self.local_url = local_url
        self.server_port = server_port
        self.server_app = app
        self.server = server
        self.is_running = True
        self.is_colab = utils.colab_check()
        self.protocol = (
            "https"
            if self.local_url.startswith("https") or self.is_colab
            else "http"
        )

        if self.enable_queue:
            self._queue.set_url(self.local_url)

        # Cannot run async functions in background other than app's scope.
        # Workaround by triggering the app endpoint
        requests.get(f"{self.local_url}startup-events")

    if self.enable_queue:
        if self.auth is not None or self.encrypt:
            raise ValueError(
                "Cannot queue with encryption or authentication enabled."
            )
    utils.launch_counter()

    # share defaults to True only when queueing inside Colab; otherwise False.
    self.share = (
        share
        if share is not None
        else True
        if self.is_colab and self.enable_queue
        else False
    )

    # If running in a colab or not able to access localhost,
    # a shareable link must be created.
    if _frontend and (not networking.url_ok(self.local_url)) and (not self.share):
        raise ValueError(
            "When localhost is not accessible, a shareable link must be created. Please set share=True."
        )

    if self.is_colab:
        if not quiet:
            if debug:
                print(strings.en["COLAB_DEBUG_TRUE"])
            else:
                print(strings.en["COLAB_DEBUG_FALSE"])
            if not self.share:
                print(strings.en["COLAB_BETA"].format(self.server_port))
        if self.enable_queue and not self.share:
            raise ValueError(
                "When using queueing in Colab, a shareable link must be created. Please set share=True."
            )
    else:
        if not self.share:
            # NOTE(review): prints an https URL without the port even though the
            # local server may be plain http — looks intentional in this fork
            # (reverse-proxied setups); confirm before changing.
            print(f'Running on local URL: https://{self.server_name}')

    if self.share:
        if self.is_space:
            raise RuntimeError("Share is not supported when you are in Spaces")
        try:
            # Only create a tunnel once; re-launches reuse the existing URL.
            if self.share_url is None:
                share_url = networking.setup_tunnel(self.server_port, None)
                self.share_url = share_url
            print(strings.en["SHARE_LINK_DISPLAY"].format(self.share_url))
            if not (quiet):
                print('\u2714 Connected')
        except RuntimeError:
            # Tunnel setup failed: fall back to local-only serving.
            if self.analytics_enabled:
                utils.error_analytics(self.ip_address, "Not able to set up tunnel")
            self.share_url = None
            self.share = False
            print(strings.en["COULD_NOT_GET_SHARE_LINK"])
    else:
        if not (quiet):
            print('\u2714 Connected')
        self.share_url = None

    if inbrowser:
        link = self.share_url if self.share else self.local_url
        webbrowser.open(link)

    # Check if running in a Python notebook in which case, display inline
    if inline is None:
        inline = utils.ipython_check() and (auth is None)
    if inline:
        if auth is not None:
            print(
                "Warning: authentication is not supported inline. Please"
                "click the link to access the interface in a new tab."
            )
        try:
            from IPython.display import HTML, Javascript, display  # type: ignore

            if self.share:
                # Wait for the tunnel endpoint to come up before embedding it.
                while not networking.url_ok(self.share_url):
                    time.sleep(0.25)
                display(
                    HTML(
                        f'<div><iframe src="{self.share_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
                    )
                )
            elif self.is_colab:
                # modified from /usr/local/lib/python3.7/dist-packages/google/colab/output/_util.py within Colab environment
                code = """(async (port, path, width, height, cache, element) => {
                    if (!google.colab.kernel.accessAllowed && !cache) {
                        return;
                    }
                    element.appendChild(document.createTextNode(''));
                    const url = await google.colab.kernel.proxyPort(port, {cache});

                    const external_link = document.createElement('div');
                    external_link.innerHTML = `
                        <div style="font-family: monospace; margin-bottom: 0.5rem">
                            Running on <a href=${new URL(path, url).toString()} target="_blank">
                                https://localhost:${port}${path}
                            </a>
                        </div>
                    `;
                    element.appendChild(external_link);

                    const iframe = document.createElement('iframe');
                    iframe.src = new URL(path, url).toString();
                    iframe.height = height;
                    iframe.allow = "autoplay; camera; microphone; clipboard-read; clipboard-write;"
                    iframe.width = width;
                    iframe.style.border = 0;
                    element.appendChild(iframe);
                })""" + "({port}, {path}, {width}, {height}, {cache}, window.element)".format(
                    port=json.dumps(self.server_port),
                    path=json.dumps("/"),
                    width=json.dumps(self.width),
                    height=json.dumps(self.height),
                    cache=json.dumps(False),
                )

                display(Javascript(code))
            else:
                display(
                    HTML(
                        f'<div><iframe src="{self.local_url}" width="{self.width}" height="{self.height}" allow="autoplay; camera; microphone; clipboard-read; clipboard-write;" frameborder="0" allowfullscreen></iframe></div>'
                    )
                )
        except ImportError:
            # IPython not installed: silently skip inline display.
            pass

    if getattr(self, "analytics_enabled", False):
        data = {
            "launch_method": "browser" if inbrowser else "inline",
            "is_google_colab": self.is_colab,
            "is_sharing_on": self.share,
            "share_url": self.share_url,
            "ip_address": self.ip_address,
            "enable_queue": self.enable_queue,
            "show_tips": self.show_tips,
            "server_name": server_name,
            "server_port": server_port,
            "is_spaces": self.is_space,
            "mode": self.mode,
        }
        utils.launch_analytics(data)

    utils.show_tip(self)

    # Block main thread if debug==True
    if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
        self.block_thread()
    # Block main thread if running in a script to stop script from exiting
    is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))

    if not prevent_thread_lock and not is_in_interactive_mode:
        self.block_thread()

    return self.server_app, self.local_url, self.share_url
def integrate(
    self,
    comet_ml: comet_ml.Experiment = None,
    wandb: ModuleType("wandb") = None,
    mlflow: ModuleType("mlflow") = None,
) -> None:
    """
    A catch-all method for integrating with other libraries. This method should be run after launch()
    Parameters:
        comet_ml: If a comet_ml Experiment object is provided, will integrate with the experiment and appear on Comet dashboard
        wandb: If the wandb module is provided, will integrate with it and appear on WandB dashboard
        mlflow: If the mlflow module is provided, will integrate with the experiment and appear on ML Flow dashboard
    """
    analytics_integration = ""
    if comet_ml is not None:
        analytics_integration = "CometML"
        comet_ml.log_other("Created from", "Gradio")
        # Prefer the public share link over the local URL when one exists.
        url = self.share_url if self.share_url is not None else self.local_url
        comet_ml.log_text("gradio: " + url)
        comet_ml.end()
    if wandb is not None:
        analytics_integration = "WandB"
        if self.share_url is None:
            # WandB embeds an iframe, which needs a publicly reachable URL.
            print(
                "The WandB integration requires you to "
                "`launch(share=True)` first."
            )
        else:
            panel = wandb.Html(
                '<iframe src="'
                + self.share_url
                + '" width="'
                + str(self.width)
                + '" height="'
                + str(self.height)
                + '" frameBorder="0"></iframe>'
            )
            wandb.log({"Gradio panel": panel})
    if mlflow is not None:
        analytics_integration = "MLFlow"
        if self.share_url is not None:
            mlflow.log_param("Gradio Interface Share Link", self.share_url)
        else:
            mlflow.log_param("Gradio Interface Local Link", self.local_url)
    if self.analytics_enabled and analytics_integration:
        data = {"integration": analytics_integration}
        utils.integration_analytics(data)
def close(self, verbose: bool = True) -> None:
    """
    Closes the Interface that was launched and frees the port.
    """
    try:
        # Shut the queue down first so no new jobs get scheduled, then
        # stop the HTTP server itself.
        if self.enable_queue:
            self._queue.close()
        self.server.close()
        self.is_running = False
        if verbose:
            print("Closing server running on port: {}".format(self.server_port))
    except (AttributeError, OSError):
        # Nothing was running (attributes missing) or the socket was
        # already gone — closing is best-effort, so swallow it.
        pass
def block_thread(
    self,
) -> None:
    """Block main thread until interrupted by user."""
    try:
        # Sleep in short increments so Ctrl+C is noticed promptly.
        while True:
            time.sleep(0.1)
    except (KeyboardInterrupt, OSError):
        # NOTE(review): OSError is treated like an interrupt here —
        # presumably for environments where the sleeping thread's signal
        # surfaces as an OS error; confirm before narrowing.
        print("Keyboard interruption in main thread... closing server.")
        self.server.close()
def attach_load_events(self):
    """Add a load event for every component whose initial value should be randomized."""
    for component in Context.root_block.blocks.values():
        # Only IOComponents that opted in get a page-load trigger.
        if not isinstance(component, components.IOComponent):
            continue
        if not component.attach_load_event:
            continue
        # Use set_event_trigger to avoid ambiguity between load class/instance method
        self.set_event_trigger(
            "load",
            component.load_fn,
            None,
            component,
            no_target=True,
            queue=False,
        )
def startup_events(self):
    """Events that should be run when the app containing this block starts up."""
    if not self.enable_queue:
        return
    # Kick off the queue worker and the capacity limiter in the background.
    utils.run_coro_in_background(self._queue.start)
    utils.run_coro_in_background(self.create_limiter)
def queue_enabled_for_fn(self, fn_index: int):
    """Whether events for ``fn_index`` run through the queue; a dependency
    that did not specify falls back to the app-wide setting."""
    dep_queue = self.dependencies[fn_index]["queue"]
    return self.enable_queue if dep_queue is None else dep_queue