Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/diffusers/stable_diffusion_textual_inversion_library_navigator.ipynb
Views: 2535
Kernel: Python 3
Stable Diffusion Textual Inversion - Concept Library navigation and usage
Navigate through the public library of concepts and use Stable Diffusion with custom concepts, powered by the 🤗 Hugging Face 🧨 Diffusers library.
Using just 3-5 images, you can teach new concepts to Stable Diffusion and personalize the model on your own images.
If you would like to teach Stable Diffusion your own concepts, check out the training notebook.
Initial setup
In [ ]:
#@title Install the required libs !pip install -qq diffusers==0.4.1 transformers ftfy gradio wget
In [ ]:
#@title Login to the Hugging Face Hub
#@markdown If you haven't yet, [you have to first acknowledge and agree to the model LICENSE before using it](https://huggingface.co/CompVis/stable-diffusion-v1-4)

from huggingface_hub import notebook_login

# Authenticate this notebook session against the Hugging Face Hub.
# NOTE(review): presumably required so the license-gated Stable Diffusion
# weights referenced in the note above can be downloaded later -- confirm.
notebook_login()
In [ ]:
#@title Prepare the Concepts Library to be used

import requests
import os
import gradio as gr
import wget
import torch
from diffusers import StableDiffusionPipeline
from huggingface_hub import HfApi
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm.notebook import tqdm

api = HfApi()
models_list = api.list_models(author="sd-concepts-library")
models = []

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")


def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    """Register a learned concept embedding in the tokenizer and text encoder.

    Args:
        learned_embeds_path: Path to a file readable by ``torch.load`` that
            holds a mapping from the trained token to its embedding.
        text_encoder: Model whose input-embedding matrix receives the new row.
        tokenizer: Tokenizer that receives the new token.
        token: Name to register the concept under. When ``None`` the token
            stored in the embeddings file is used.

    Returns:
        The token string that was actually added. It differs from the
        requested one when that name was already present in the tokenizer,
        in which case a ``-<n>`` suffix is inserted before the final
        character.
    """
    # NOTE(security): torch.load unpickles the file, and these files are
    # downloaded from community repositories. Only load embeddings from
    # sources you trust.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # The file is expected to hold a single entry: token -> embedding.
    trained_token = next(iter(loaded_learned_embeds))
    embeds = loaded_learned_embeds[trained_token]

    # Tensor.to() returns a new tensor, so the result must be kept for the
    # cast to the text encoder's dtype to take effect.
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    token = token if token is not None else trained_token
    base_token = token
    num_added_tokens = tokenizer.add_tokens(token)
    i = 1
    while num_added_tokens == 0:
        print(f"The tokenizer already contains the token {token}.")
        # Always derive the candidate from the originally requested name so
        # retries yield "<x-1>", "<x-2>", ... rather than "<x-1-2>".
        # Assumes the token ends with ">" -- the last character is dropped
        # and ">" re-appended.
        token = f"{base_token[:-1]}-{i}>"
        print(f"Attempting to add the token {token}.")
        num_added_tokens = tokenizer.add_tokens(token)
        i += 1

    # Grow the embedding matrix to cover the newly added token.
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Write the learned vector into the row of the new token.
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token


print("Setting up the public library")
for model in tqdm(models_list):
    model_content = {}
    model_id = model.modelId
    model_content["id"] = model_id
    embeds_url = f"https://huggingface.co/{model_id}/resolve/main/learned_embeds.bin"
    os.makedirs(model_id, exist_ok=True)
    if not os.path.exists(f"{model_id}/learned_embeds.bin"):
        try:
            wget.download(embeds_url, out=model_id)
        except Exception as err:
            # Best effort: a concept without downloadable embeddings is
            # skipped. Catching Exception (not a bare except) keeps
            # Ctrl-C / KeyboardInterrupt working during the long download loop.
            print(f"Skipping {model_id}: could not download learned_embeds.bin ({err})")
            continue

    token_identifier = f"https://huggingface.co/{model_id}/raw/main/token_identifier.txt"
    response = requests.get(token_identifier)
    token_name = response.text

    concept_type = f"https://huggingface.co/{model_id}/raw/main/type_of_concept.txt"
    response = requests.get(concept_type)
    concept_name = response.text
    model_content["concept_type"] = concept_name

    # Fetch up to four preview images; indices that are missing are skipped.
    images = []
    for i in range(4):
        url = f"https://huggingface.co/{model_id}/resolve/main/concept_images/{i}.jpeg"
        image_download = requests.get(url)
        url_code = image_download.status_code
        if url_code == 200:
            # The context manager closes the file even if the write fails.
            with open(f"{model_id}/{i}.jpeg", "wb") as file:
                file.write(image_download.content)
            images.append(f"{model_id}/{i}.jpeg")
    model_content["images"] = images

    learned_token = load_learned_embed_in_clip(
        f"{model_id}/learned_embeds.bin", pipe.text_encoder, pipe.tokenizer, token_name
    )
    model_content["token"] = learned_token
    models.append(model_content)
Go!
In [ ]:
#@title Run the app to navigate around [the Library](https://huggingface.co/sd-concepts-library)
#@markdown Click the `Running on public URL:` result to run the Gradio app

SELECT_LABEL = "Select concept"


def title_block(title, id):
    """Return a Markdown heading showing `title`, linked to the Hub repo `id`."""
    return gr.Markdown(f"### [`{title}`](https://huggingface.co/{id})")


def image_block(image_list, concept_type):
    """Return a two-column gallery of `image_list`, labelled with `concept_type`."""
    return gr.Gallery(
        label=concept_type, value=image_list, elem_id="gallery"
    ).style(grid=[2], height="auto")


def checkbox_block():
    """Return a container-less checkbox labelled with SELECT_LABEL."""
    checkbox = gr.Checkbox(label=SELECT_LABEL).style(container=False)
    return checkbox


def infer(text):
    """Run the pipeline on the prompt `text` and return the images as a list.

    Two images are requested per prompt, with 50 inference steps and a
    guidance scale of 7.5.
    """
    images_list = pipe(
        text,
        num_images_per_prompt=2,
        num_inference_steps=50,
        guidance_scale=7.5
    )
    return list(images_list["sample"])


css = '''
.gradio-container {font-family: 'IBM Plex Sans', sans-serif}
#top_title{margin-bottom: .5em}
#top_title h2{margin-bottom: 0; text-align: center}
#main_row{flex-wrap: wrap; gap: 1em; max-height: calc(100vh - 16em); overflow-y: scroll; flex-direction: row}
@media (min-width: 768px){#main_row > div{flex: 1 1 32%; margin-left: 0 !important}}
.gr-prose code::before, .gr-prose code::after {content: "" !important}
::-webkit-scrollbar {width: 10px}
::-webkit-scrollbar-track {background: #f1f1f1}
::-webkit-scrollbar-thumb {background: #888}
::-webkit-scrollbar-thumb:hover {background: #555}
.gr-button {white-space: nowrap}
.gr-button:focus {
  border-color: rgb(147 197 253 / var(--tw-border-opacity));
  outline: none;
  box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
  --tw-border-opacity: 1;
  --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
  --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
  --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
  --tw-ring-opacity: .5;
}
#prompt_input{flex: 1 3 auto}
#prompt_area{margin-bottom: .75em}
#prompt_area > div:first-child{flex: 1 3 auto}
'''
# The --tw-ring-shadow rule above uses `calc(3px + var(...))`: calc() needs an
# explicit operator between its operands, otherwise the declaration is invalid.

examples = ["a <cat-toy> in <madhubani-art> style", "a mecha robot in <line-art> style", "a piano being played by <bonzi>"]

# NOTE(review): the nesting of the layout below was reconstructed from a
# flattened copy of this cell -- confirm against the original notebook.
with gr.Blocks(css=css) as demo:
    # Plain dicts holding per-concept selection flags. Nothing in this cell
    # wires update_state to a component yet.
    state = {}
    checkbox_states = {}

    def update_state(i):
        """Flip the selection flag of concept `i` in both tracking dicts."""
        # A concept that was never toggled counts as unselected.
        selected = not checkbox_states.get(i, False)
        checkbox_states[i] = selected
        state[i] = selected

    gr.HTML('''
    <div style="text-align: center; max-width: 720px; margin: 0 auto;">
      <div
        style="
          display: inline-flex;
          align-items: center;
          gap: 0.8rem;
          font-size: 1.75rem;
        "
      >
        <svg
          width="0.65em"
          height="0.65em"
          viewBox="0 0 115 115"
          fill="none"
          xmlns="http://www.w3.org/2000/svg"
        >
          <rect width="23" height="23" fill="white"></rect>
          <rect y="69" width="23" height="23" fill="white"></rect>
          <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
          <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
          <rect x="46" width="23" height="23" fill="white"></rect>
          <rect x="46" y="69" width="23" height="23" fill="white"></rect>
          <rect x="69" width="23" height="23" fill="black"></rect>
          <rect x="69" y="69" width="23" height="23" fill="black"></rect>
          <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
          <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
          <rect x="115" y="46" width="23" height="23" fill="white"></rect>
          <rect x="115" y="115" width="23" height="23" fill="white"></rect>
          <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
          <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
          <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
          <rect x="92" y="69" width="23" height="23" fill="white"></rect>
          <rect x="69" y="46" width="23" height="23" fill="white"></rect>
          <rect x="69" y="115" width="23" height="23" fill="white"></rect>
          <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
          <rect x="46" y="46" width="23" height="23" fill="black"></rect>
          <rect x="46" y="115" width="23" height="23" fill="black"></rect>
          <rect x="46" y="69" width="23" height="23" fill="black"></rect>
          <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
          <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
          <rect x="23" y="69" width="23" height="23" fill="black"></rect>
        </svg>
        <h1 style="font-weight: 900; margin-bottom: 7px;">
          Stable Diffusion Conceptualizer
        </h1>
      </div>
      <p style="margin-bottom: 10px; font-size: 94%">
        Navigate through community created concepts and styles via Stable Diffusion Textual Inversion and pick yours for inference.
        To train your own concepts and contribute to the library <a style="text-decoration: underline" href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb">check out this notebook</a>.
      </p>
    </div>
    ''')

    with gr.Row():
        with gr.Column():
            gr.Markdown(''' ### Textual-Inversion trained [concepts library](https://huggingface.co/sd-concepts-library) navigator ''')
            with gr.Row(elem_id="main_row"):
                # One card per concept: linked title plus its preview gallery.
                image_blocks = []
                for i, model in enumerate(models):
                    with gr.Box().style(border=None):
                        title_block(model["token"], model["id"])
                        image_blocks.append(image_block(model["images"], model["concept_type"]))
        with gr.Box():
            with gr.Row(elem_id="prompt_area").style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt",
                    placeholder="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    elem_id="prompt_input"
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False
                )
                btn = gr.Button("Run", elem_id="run_btn").style(
                    margin=False,
                    rounded=(False, True, True, False)
                )
            with gr.Row().style():
                infer_outputs = gr.Gallery(show_label=False).style(grid=[2], height="512px")
            with gr.Row():
                gr.HTML("<p style=\"font-size: 85%;margin-top: .75em\">Prompting may not work as you are used to; <code>objects</code> may need the concept added at the end.</p>")
            with gr.Row():
                gr.Examples(examples=examples, fn=infer, inputs=[text], outputs=infer_outputs, cache_examples=False)

    # Clicking "Run" sends the prompt through infer and shows the images.
    inputs = [text]
    btn.click(
        infer,
        inputs=inputs,
        outputs=infer_outputs
    )

demo.launch(inline=False, debug=True)