CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
huggingface

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: huggingface/notebooks
Path: blob/main/diffusers/prompt_scheduling_callback.ipynb
Views: 2535
Kernel: Unknown Kernel

Prompt scheduling callback

Prompt scheduling callback allows changing prompts during a generation, like prompt editing in A1111.

pip install diffusers torch accelerate transformers
Requirement already satisfied: diffusers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (0.31.0) Requirement already satisfied: torch in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (2.2.1+cu121) Requirement already satisfied: accelerate in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (1.1.1) Requirement already satisfied: transformers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (4.46.2) Requirement already satisfied: importlib-metadata in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (8.5.0) Requirement already satisfied: filelock in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (3.16.1) Requirement already satisfied: huggingface-hub>=0.23.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.26.2) Requirement already satisfied: numpy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (1.26.4) Requirement already satisfied: regex!=2019.12.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2024.11.6) Requirement already satisfied: requests in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2.32.3) Requirement already satisfied: safetensors>=0.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.4.5) Requirement already satisfied: Pillow in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (11.0.0) Requirement already satisfied: typing-extensions>=4.8.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (4.12.2) Requirement already satisfied: sympy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (1.13.3) Requirement already satisfied: networkx in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.4.2) Requirement already satisfied: jinja2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.1.4) Requirement already satisfied: fsspec in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2024.10.0) Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105) Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105) Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105) Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (8.9.2.26) Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.3.1) Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (11.0.2.54) Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (10.3.2.106) Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (11.4.5.107) Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.0.106) Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2.19.3) Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105) Requirement already satisfied: triton==2.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2.2.0) Requirement already satisfied: nvidia-nvjitlink-cu12 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.6.77) Requirement already satisfied: packaging>=20.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (24.1) Requirement already satisfied: psutil in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (6.1.0) Requirement already satisfied: pyyaml in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (6.0.2) Requirement already satisfied: tokenizers<0.21,>=0.20 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (0.20.3) Requirement already satisfied: tqdm>=4.27 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (4.66.6) Requirement already satisfied: zipp>=3.20 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from importlib-metadata->diffusers) (3.21.0) Requirement already satisfied: MarkupSafe>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jinja2->torch) (3.0.2) Requirement already satisfied: charset-normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (3.4.0) Requirement already satisfied: idna<4,>=2.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (3.10) Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (2.2.3) Requirement already satisfied: certifi>=2017.4.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (2024.8.30) Requirement already satisfied: mpmath<1.4,>=1.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from sympy->torch) (1.3.0) Note: you may need to restart the kernel to use updated packages.
from diffusers import StableDiffusionPipeline from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks from diffusers.configuration_utils import register_to_config import torch from typing import Any, Dict, Optional pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained( "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True, ).to("cuda") pipeline.safety_checker = None pipeline.requires_safety_checker = False class SDPromptScheduleCallback(PipelineCallback): @register_to_config def __init__( self, prompt: str, negative_prompt: Optional[str] = None, num_images_per_prompt: int = 1, cutoff_step_ratio=1.0, cutoff_step_index=None, ): super().__init__( cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index ) tensor_inputs = ["prompt_embeds"] def callback_fn( self, pipeline, step_index, timestep, callback_kwargs ) -> Dict[str, Any]: cutoff_step_ratio = self.config.cutoff_step_ratio cutoff_step_index = self.config.cutoff_step_index # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio cutoff_step = ( cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) ) if step_index == cutoff_step: prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt( prompt=self.config.prompt, negative_prompt=self.config.negative_prompt, device=pipeline._execution_device, num_images_per_prompt=self.config.num_images_per_prompt, do_classifier_free_guidance=pipeline.do_classifier_free_guidance, ) if pipeline.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) callback_kwargs[self.tensor_inputs[0]] = prompt_embeds return callback_kwargs
callback = MultiPipelineCallbacks( [ SDPromptScheduleCallback( prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski", negative_prompt="Deformed, ugly, bad anatomy", cutoff_step_ratio=0.25, ) ] ) image = pipeline( prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski", negative_prompt="Deformed, ugly, bad anatomy", callback_on_step_end=callback, callback_on_step_end_tensor_inputs=["prompt_embeds"], ).images[0] torch.cuda.empty_cache() image.save('image2.png')