Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_models_diffuser_to_diffusers.py
1440 views
1
import json
2
import os
3
4
import torch
5
6
from diffusers import UNet1DModel
7
8
9
# Create, at import time, every output directory that the conversion functions
# below write their converted weights and config.json into.
for _output_dir in (
    "hub/hopper-medium-v2/unet/hor32",
    "hub/hopper-medium-v2/unet/hor128",
    "hub/hopper-medium-v2/value_function",
):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir
13
14
15
def unet(hor, checkpoint_dir="/Users/bglickenhaus/Documents/diffuser"):
    """Convert a diffuser temporal U-Net checkpoint into a diffusers ``UNet1DModel``.

    Loads ``temporal_unet-hopper-mediumv2-hor{hor}.torch`` from *checkpoint_dir*,
    copies its weights onto a freshly built ``UNet1DModel`` by pairing state-dict
    entries positionally, and writes the converted weights and the model config to
    ``hub/hopper-medium-v2/unet/hor{hor}/``.

    Args:
        hor: Planning horizon of the checkpoint; only 32 and 128 are supported.
        checkpoint_dir: Directory containing the original checkpoint file. Defaults
            to the location the script was originally written against.

    Raises:
        ValueError: If *hor* is not 32 or 128, or if the checkpoint and the
            diffusers model have a different number of state-dict entries.
    """
    # Validate before doing any I/O so a bad horizon fails with a clear message
    # instead of an UnboundLocalError further down.
    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
    else:
        raise ValueError(f"Unsupported horizon {hor!r}; expected 32 or 128.")

    # NOTE(review): torch.load unpickles arbitrary objects -- only run this on trusted,
    # locally produced checkpoints. This file holds a whole pickled module (we call
    # .state_dict() on it); on torch versions where weights_only defaults to True the
    # load may need weights_only=False -- confirm for the installed torch version.
    model = torch.load(os.path.join(checkpoint_dir, f"temporal_unet-hopper-mediumv2-hor{hor}.torch"))
    state_dict = model.state_dict()

    config = dict(
        down_block_types=down_block_types,
        block_out_channels=block_out_channels,
        up_block_types=up_block_types,
        layers_per_block=1,
        use_timestep_embedding=True,
        out_block_type="OutConv1DBlock",
        norm_num_groups=8,
        downsample_each_block=False,
        in_channels=14,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        flip_sin_to_cos=False,
        freq_shift=1,
        sample_size=65536,
        mid_block_type="MidResTemporalBlock1D",
        act_fn="mish",
    )
    hf_unet = UNet1DModel(**config)
    hf_state_dict = hf_unet.state_dict()
    print(f"length of state dict: {len(state_dict)}")
    print(f"length of value function dict: {len(hf_state_dict)}")

    # zip() would silently drop trailing entries on a count mismatch; fail loudly.
    if len(state_dict) != len(hf_state_dict):
        raise ValueError(
            f"State dict size mismatch: checkpoint has {len(state_dict)} entries, "
            f"UNet1DModel expects {len(hf_state_dict)}."
        )
    # Keys are paired purely by position, which assumes both models register their
    # parameters in the same order. Building a new dict (rather than renaming in
    # place with pop/assign) avoids overwriting an entry whose source key happens to
    # equal a different target key.
    converted = {hf_key: state_dict[key] for key, hf_key in zip(state_dict, hf_state_dict)}
    hf_unet.load_state_dict(converted)

    out_dir = f"hub/hopper-medium-v2/unet/hor{hor}"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(hf_unet.state_dict(), f"{out_dir}/diffusion_pytorch_model.bin")
    with open(f"{out_dir}/config.json", "w") as f:
        json.dump(config, f)
57
58
59
def value_function(
    checkpoint_path="/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch",
):
    """Convert a diffuser value-function checkpoint into a diffusers ``UNet1DModel``.

    Loads the state dict stored at *checkpoint_path*, copies its tensors onto a
    ``UNet1DModel`` configured with a ``ValueFunction`` output block by pairing
    state-dict entries positionally, and writes the converted weights and config to
    ``hub/hopper-medium-v2/value_function/``.

    Args:
        checkpoint_path: Path to the original checkpoint. Defaults to the hor32
            value-function file the script was originally written against.

    Raises:
        ValueError: If the checkpoint and the diffusers model have a different
            number of state-dict entries.
    """
    config = dict(
        in_channels=14,
        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        up_block_types=(),
        out_block_type="ValueFunction",
        mid_block_type="ValueFunctionMidBlock1D",
        block_out_channels=(32, 64, 128, 256),
        layers_per_block=1,
        downsample_each_block=True,
        sample_size=65536,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        use_timestep_embedding=True,
        flip_sin_to_cos=False,
        freq_shift=1,
        norm_num_groups=8,
        act_fn="mish",
    )

    # NOTE(review): torch.load unpickles arbitrary objects -- only run this on trusted,
    # locally produced checkpoints. Unlike the U-Net checkpoint, this file is used
    # directly as a key -> tensor mapping.
    state_dict = torch.load(checkpoint_path)
    hf_value_function = UNet1DModel(**config)
    hf_state_dict = hf_value_function.state_dict()
    print(f"length of state dict: {len(state_dict)}")
    print(f"length of value function dict: {len(hf_state_dict)}")

    # zip() would silently drop trailing entries on a count mismatch; fail loudly.
    if len(state_dict) != len(hf_state_dict):
        raise ValueError(
            f"State dict size mismatch: checkpoint has {len(state_dict)} entries, "
            f"UNet1DModel expects {len(hf_state_dict)}."
        )
    # Keys are paired purely by position (assumes identical parameter registration
    # order). A fresh dict avoids the in-place pop/assign rename overwriting an entry
    # whose source key equals a different target key, and leaves the loaded object
    # unmodified.
    converted = {hf_key: state_dict[key] for key, hf_key in zip(state_dict, hf_state_dict)}
    hf_value_function.load_state_dict(converted)

    out_dir = "hub/hopper-medium-v2/value_function"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(hf_value_function.state_dict(), f"{out_dir}/diffusion_pytorch_model.bin")
    with open(f"{out_dir}/config.json", "w") as f:
        json.dump(config, f)
95
96
97
if __name__ == "__main__":
    # Convert the horizon-32 U-Net, then the value function. The horizon-128 U-Net
    # is not converted by default; add a unet(128) call here to include it.
    unet(32)
    value_function()
101
102