# Source path: examples/research_projects/intel_opts/inference_bf16.py
"""Run Stable Diffusion inference in bfloat16 on CPU with Intel Extension for PyTorch.

Loads a pipeline from ``model_id``, optimizes its sub-modules with IPEX,
generates a batch of images for one prompt, and saves them as a single grid
image next to the model path.
"""
import os

import intel_extension_for_pytorch as ipex
import torch
from PIL import Image

from diffusers import StableDiffusionPipeline


def image_grid(imgs, rows, cols):
    """Tile ``rows * cols`` images into one RGB image, row-major.

    Image ``i`` is pasted into row ``i // cols`` and column ``i % cols``. The
    cell size is taken from the first image only, so the images are presumably
    all the same size.

    Args:
        imgs: Sequence of images exposing ``.size`` (PIL images here).
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        ValueError: if ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(f"expected rows * cols = {rows * cols} images, got {len(imgs)}")
    if not imgs:
        raise ValueError("image_grid needs at least one image")

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid


prompt = ["a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brighly buildings"]
# Derive the batch size from the output grid so the two cannot drift apart
# (image_grid rejects a mismatch between image count and rows * cols).
grid_rows, grid_cols = 2, 4
batch_size = grid_rows * grid_cols
prompt = prompt * batch_size

device = "cpu"
model_id = "path-to-your-trained-model"
model = StableDiffusionPipeline.from_pretrained(model_id)
model = model.to(device)

# The safety checker is an optional pipeline component and may be None
# (e.g. a pipeline saved without one); only transform it when present.
has_safety_checker = model.safety_checker is not None

# Switch sub-modules to the channels-last memory format.
model.unet = model.unet.to(memory_format=torch.channels_last)
model.vae = model.vae.to(memory_format=torch.channels_last)
model.text_encoder = model.text_encoder.to(memory_format=torch.channels_last)
if has_safety_checker:
    model.safety_checker = model.safety_checker.to(memory_format=torch.channels_last)

# Optimize each sub-module with IPEX for bfloat16, after putting it in eval mode.
model.unet = ipex.optimize(model.unet.eval(), dtype=torch.bfloat16, inplace=True)
model.vae = ipex.optimize(model.vae.eval(), dtype=torch.bfloat16, inplace=True)
model.text_encoder = ipex.optimize(model.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
if has_safety_checker:
    model.safety_checker = ipex.optimize(model.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)

# Generate with a fixed seed so runs are repeatable.
seed = 666
generator = torch.Generator(device).manual_seed(seed)
# torch.autocast("cpu", ...) replaces the deprecated torch.cpu.amp.autocast.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    images = model(prompt, guidance_scale=7.5, num_inference_steps=50, generator=generator).images

# Save all generated images as one grid next to the model path. normpath drops a
# trailing separator so "my-model/" yields "my-model.png", not "my-model/.png".
grid = image_grid(images, rows=grid_rows, cols=grid_cols)
grid.save(os.path.normpath(model_id) + ".png")