GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/speech_to_image_diffusion.py

import inspect
from typing import Callable, List, Optional, Union

import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class SpeechToImagePipeline(DiffusionPipeline):
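    """
    Pipeline for generating images from speech audio: the audio is transcribed with Whisper and the resulting
    transcript is used as the text prompt for Stable Diffusion image generation.
    """
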
    def __init__(
        self,
        speech_model: WhisperForConditionalGeneration,
        speech_processor: WhisperProcessor,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        if safety_checker is None:
            logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
47
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
48
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
49
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
50
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
51
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            speech_model=speech_model,
            speech_processor=speech_processor,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
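        # NOTE: `safety_checker` is accepted for interface compatibility but is not registered above, and
        # `__call__` never applies it (`nsfw_content_detected` is always None in the returned output).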
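
    # Sliced attention trades a small amount of speed for lower peak memory use; with `slice_size="auto"`,
    # the slice size is set to half the attention head dimension, so attention is computed in two steps.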
    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        if slice_size == "auto":
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        audio,
        sampling_rate=16_000,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
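        # transcribe the input audio with Whisper; the transcript becomes the text prompt for image generation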
        inputs = self.speech_processor.feature_extractor(
            audio, return_tensors="pt", sampling_rate=sampling_rate
        ).input_features.to(self.device)
        predicted_ids = self.speech_model.generate(inputs, max_length=480_000)

        prompt = self.speech_processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[
            0
        ]

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier-free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt`"
                    " matches the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier-free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays.
        # It's more efficient to move all timesteps to the correct device beforehand.
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be in [0, 1].
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier-free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

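        # scale and decode the image latents with the VAE (0.18215 is the Stable Diffusion latent scaling factor)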
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return image

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
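

# ----------------------------------------------------------------------------------------------
# Usage sketch (not part of the original pipeline file): a minimal example of how this community
# pipeline might be loaded through diffusers' `custom_pipeline` mechanism and called. The model
# IDs ("openai/whisper-small", "runwayml/stable-diffusion-v1-5") and the silent placeholder audio
# are illustrative assumptions; substitute real checkpoints and a 16 kHz speech waveform.
# ----------------------------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    # load the Whisper speech components separately and hand them to the custom pipeline
    speech_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    speech_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="speech_to_image_diffusion",
        speech_model=speech_model,
        speech_processor=speech_processor,
    )
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    # `audio` is expected to be a 1-D waveform sampled at `sampling_rate`; a silent 5-second clip
    # is used here purely as a placeholder input.
    audio = np.zeros(16_000 * 5, dtype=np.float32)
    output = pipe(audio=audio, sampling_rate=16_000, num_inference_steps=25)
    output.images[0].save("speech_to_image.png")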