Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/one_step_unet.py
1448 views
1
#!/usr/bin/env python3
2
import torch
3
4
from diffusers import DiffusionPipeline
5
6
7
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    """Minimal pipeline running one UNet forward pass and one scheduler step.

    The pipeline samples random noise, runs the UNet and the scheduler once at
    a fixed timestep, and returns a tensor of ones shaped like the scheduler
    output. It exercises the ``unet``/``scheduler`` plumbing rather than
    producing a meaningful image.
    """

    def __init__(self, unet, scheduler):
        """Register ``unet`` and ``scheduler`` as pipeline components.

        Args:
            unet: Denoising model. Read for ``in_channels``, ``sample_size``,
                ``device`` and ``dtype``, and called as ``unet(sample, t)``
                returning an object with a ``.sample`` attribute.
            scheduler: Object whose ``step(model_output, timestep, sample)``
                returns an object with a ``.prev_sample`` attribute.
        """
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self):
        """Run a single denoising step on random noise.

        Returns:
            torch.Tensor: ones with the shape, dtype and device of the
            scheduler's ``prev_sample``. Because the result is computed as
            ``x - x + 1``, any NaN/inf in ``x`` propagates into it.
        """
        # Sample the noise on the UNet's own device/dtype; sampling on the
        # default CPU device/float32 breaks after ``pipe.to("cuda")`` or
        # when the UNet is loaded in half precision.
        image = torch.randn(
            (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            device=self.unet.device,
            dtype=self.unet.dtype,
        )
        timestep = 1

        model_output = self.unet(image, timestep).sample
        scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample

        # Return ones shaped like the output while keeping the result tied to
        # the computed sample (x - x is 0 for finite x, NaN otherwise).
        result = scheduler_output - scheduler_output + torch.ones_like(scheduler_output)

        return result
25
26