Path: blob/main/examples/community/one_step_unet.py
1448 views
#!/usr/bin/env python3
import torch

from diffusers import DiffusionPipeline


class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    """Minimal pipeline that runs one UNet forward pass and one scheduler step.

    The returned tensor is all ones (shaped like the scheduler output), so a
    caller can check that the pipeline is wired up correctly (modules
    registered, forward pass and scheduler step succeed) without depending on
    the random values the model produces.
    """

    def __init__(self, unet, scheduler):
        """Register ``unet`` and ``scheduler`` as pipeline modules.

        Args:
            unet: Denoising model. ``__call__`` reads ``config.in_channels`` and
                ``config.sample_size`` from it and uses the ``.sample`` field of
                its output.
            scheduler: Scheduler whose ``step(model_output, timestep, sample)``
                result must expose ``.prev_sample``.
        """
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)

    def __call__(self):
        """Run one denoising step on random noise and return a tensor of ones.

        Returns:
            torch.Tensor: ones with the same shape, dtype and device as the
            scheduler's ``prev_sample``.
        """
        # Read the input shape from ``unet.config``. Accessing config
        # attributes directly on the model (``self.unet.in_channels``) is
        # deprecated in diffusers and triggers a FutureWarning.
        in_channels = self.unet.config.in_channels
        sample_size = self.unet.config.sample_size
        image = torch.randn((1, in_channels, sample_size, sample_size))
        timestep = 1

        model_output = self.unet(image, timestep).sample
        scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample

        # ``x - x + 1`` rather than a bare ``ones_like`` keeps the result
        # computed from the scheduler output: a NaN/inf there propagates as NaN
        # instead of being silently masked (presumably intentional for tests).
        result = scheduler_output - scheduler_output + torch.ones_like(scheduler_output)

        return result