Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/utils/__init__.py
694 views
1
import glob
2
import logging
3
import re
4
import time
5
from collections import defaultdict
6
import os
7
import sys
8
import shutil
9
import types
10
import numpy as np
11
import torch
12
import torch.nn.functional as F
13
import torch.distributed as dist
14
from torch import nn
15
16
17
def tensors_to_scalars(metrics):
    """Return a copy of ``metrics`` with tensors replaced by Python scalars.

    Each ``torch.Tensor`` value is converted with ``.item()``. Nested
    mappings that are ``dict`` instances (including subclasses such as
    ``OrderedDict`` or ``defaultdict``) are converted recursively into plain
    ``dict`` objects. All other values are copied through unchanged.

    Args:
        metrics: mapping of metric name -> value.

    Returns:
        A new ``dict``; the input is not modified.

    Raises:
        RuntimeError / ValueError (from ``Tensor.item``): if a tensor value
        holds more than one element.
    """
    new_metrics = {}
    for k, v in metrics.items():
        if isinstance(v, torch.Tensor):
            v = v.item()
        elif isinstance(v, dict):
            # isinstance (not `type(v) is dict`) so dict subclasses are recursed too.
            v = tensors_to_scalars(v)
        new_metrics[k] = v
    return new_metrics
26
27
28
class AvgrageMeter(object):
    """Keep a running weighted mean of the values passed to ``update``.

    The (misspelled) class name is part of the public interface and is kept.

    Attributes:
        sum: accumulated ``val * n`` over all updates.
        cnt: accumulated weight ``n`` over all updates.
        avg: ``sum / cnt`` as of the most recent update (0 after reset).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the running statistics."""
        self.avg, self.sum, self.cnt = 0, 0, 0

    def update(self, val, n=1):
        """Fold in ``val`` with weight ``n`` and recompute ``avg``."""
        weighted = val * n
        self.sum += weighted
        self.cnt += n
        self.avg = self.sum / self.cnt
42
43
44
def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Convert a list of 1d tensors into a padded 2d tensor.

    Args:
        values: non-empty list of 1d tensors; the output takes the dtype and
            device of ``values[0]``.
        pad_idx: fill value for padded positions.
        left_pad: if True, sequences are right-aligned (padding on the left);
            otherwise they are left-aligned (padding on the right).
        shift_right: if True, each sequence is shifted one step right: its
            last element is dropped and ``shift_id`` is written at its first
            position. Empty sequences stay all-padding.
        max_len: output width; defaults to the longest sequence. Must be at
            least as long as every sequence.
        shift_id: value placed at the start of each sequence when
            ``shift_right`` is set.

    Returns:
        Tensor of shape ``(len(values), size)``.

    Raises:
        AssertionError: if a sequence is longer than ``max_len``.
    """
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new_full((len(values), size), pad_idx)

    def copy_tensor(src, dst):
        # Fails when a sequence is longer than max_len (the slice gets truncated).
        assert dst.numel() == src.numel()
        if shift_right:
            if src.numel() == 0:
                # Nothing to shift; writing dst[0] would index an empty slice.
                return
            dst[1:] = src[:-1]
            dst[0] = shift_id
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res
60
61
62
def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Convert a list of 2d tensors into a padded 3d tensor.

    Each input is ``(T_i, C)`` with the same ``C`` as ``values[0]``; the
    result is ``(len(values), size, C)``. ``size`` is ``max_len`` when given,
    otherwise the largest ``T_i``. Rows are right-aligned when ``left_pad``
    is set, left-aligned otherwise, and padded with ``pad_idx``.

    With ``shift_right`` each sequence is moved down one row, dropping its
    last row; its first row is left as padding.
    """
    n_rows = max_len if max_len is not None else max(v.size(0) for v in values)
    feat_dim = values[0].shape[1]
    out = values[0].new_full((len(values), n_rows, feat_dim), pad_idx)

    # Iterating a tensor yields views, so writes into `slot` land in `out`.
    for slot, seq in zip(out, values):
        length = len(seq)
        target = slot[n_rows - length:] if left_pad else slot[:length]
        assert target.numel() == seq.numel()
        if shift_right:
            target[1:] = seq[:-1]
        else:
            target.copy_(seq)
    return out
77
78
79
def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
80
if len(batch) == 0:
81
return 0
82
if len(batch) == max_sentences:
83
return 1
84
if num_tokens > max_tokens:
85
return 1
86
return 0
87
88
89
def batch_by_size(
        indices, num_tokens_fn, max_tokens=None, max_sentences=None,
        required_batch_size_multiple=1, distributed=False
):
    """
    Group ``indices`` (kept in the given order) into size-bounded mini-batches.

    Batches may mix sequences of different lengths; all batches are returned
    together as a list of lists.

    Args:
        indices: sequence of dataset indices, or a generator of them (a
            generator is first materialized into an int64 numpy array).
        num_tokens_fn (callable): maps an index to its number of tokens.
        max_tokens (int, optional): cap on the padded token count of a batch,
            i.e. batch length times its longest item (default: no cap).
        max_sentences (int, optional): cap on items per batch (default: no cap).
        required_batch_size_multiple (int, optional): when a batch is closed,
            its size is rounded down to a multiple of N (unless it holds fewer
            than N items); the remainder carries over to the next batch.
        distributed: accepted for interface compatibility; not used here.

    Raises:
        AssertionError: if a single item has more tokens than ``max_tokens``.
    """
    token_cap = sys.maxsize if max_tokens is None else max_tokens
    sentence_cap = sys.maxsize if max_sentences is None else max_sentences
    multiple = required_batch_size_multiple

    if isinstance(indices, types.GeneratorType):
        # A generator has no len(); materialize it once.
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    batches = []
    current = []  # indices collected for the batch under construction
    lens = []     # token counts for `current` plus the item being considered
    longest = 0   # max of `lens`
    for idx in indices:
        item_tokens = num_tokens_fn(idx)
        lens.append(item_tokens)
        longest = max(longest, item_tokens)
        assert longest <= token_cap, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, longest, token_cap)
        )
        # Padded cost of the batch if `idx` were added to it.
        padded = (len(current) + 1) * longest

        if _is_batch_full(current, padded, token_cap, sentence_cap):
            whole = multiple * (len(current) // multiple)
            cut = max(whole, len(current) % multiple)
            batches.append(current[:cut])
            current = current[cut:]
            lens = lens[cut:]
            longest = max(lens) if lens else 0
        current.append(idx)

    if current:
        batches.append(current)
    return batches
143
144
145
def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Works along dim 1. Non-padding entries are numbered
    ``padding_idx + 1, padding_idx + 2, ...`` in order; padding entries
    become ``padding_idx``. Returns a long tensor shaped like ``tensor``.
    """
    # The int mask, type_as and final long() casts are deliberate: they keep
    # the op exportable to ONNX (no dtype kwarg on cumsum) and friendly to XLA.
    non_pad = tensor.ne(padding_idx).int()
    running = torch.cumsum(non_pad, dim=1).type_as(non_pad)
    positions = running * non_pad
    return positions.long() + padding_idx
158
159
160
def softmax(x, dim):
    """Softmax of ``x`` along ``dim``, with the result in float32.

    The input is cast to float32 before the operation, per the ``dtype``
    argument of ``torch.nn.functional.softmax``.
    """
    probs = F.softmax(x, dim=dim, dtype=torch.float32)
    return probs
162
163
164
def unpack_dict_to_list(samples):
    """Split a batched dict into a list of per-sample dicts.

    The batch size is ``samples['outputs'].size(0)``. For each index ``i``,
    every value ``v`` for which ``v[i]`` succeeds contributes ``v[i]`` under
    the same key. Values that cannot be indexed that way (e.g. ``None``,
    plain scalars, sequences shorter than ``i + 1``) are omitted from that
    sample.
    """
    bsz = samples.get('outputs').size(0)
    samples_ = []
    for i in range(bsz):
        res = {}
        for k, v in samples.items():
            try:
                res[k] = v[i]
            except Exception:
                # Deliberately best-effort: skip values that are not per-sample
                # indexable. `Exception` (not a bare except) lets
                # KeyboardInterrupt / SystemExit propagate.
                continue
        samples_.append(res)
    return samples_
176
177
178
def load_ckpt(cur_model, ckpt_base_dir, prefix_in_ckpt='model', force=True, strict=True):
    """Load weights into ``cur_model`` from a checkpoint file or directory.

    Args:
        cur_model: module whose ``load_state_dict`` receives the weights.
        ckpt_base_dir: path to a checkpoint file, or to a directory holding
            ``model_ckpt_steps_<N>.ckpt`` files. For a directory, the file
            with the largest step ``N`` (compared numerically) is used.
        prefix_in_ckpt: only entries of the checkpoint's ``"state_dict"``
            whose key starts with ``f"{prefix_in_ckpt}."`` are loaded, with
            that prefix stripped.
        force: if True, a missing checkpoint raises; otherwise a message is
            printed and nothing is loaded.
        strict: forwarded to ``load_state_dict``. When False, entries whose
            shape differs from the model's tensor of the same name are also
            reported and dropped before loading.

    Raises:
        AssertionError: if no checkpoint is found and ``force`` is True.
    """
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        checkpoint_paths = [ckpt_base_dir]
    else:
        base_dir = ckpt_base_dir
        # Parse the step from the file name only: the directory part may contain
        # regex/glob metacharacters or Windows separators.
        step_re = re.compile(r'model_ckpt_steps_(\d+)\.ckpt')
        pattern = os.path.join(glob.escape(base_dir), 'model_ckpt_steps_*.ckpt')
        candidates = []
        for path in glob.glob(pattern):
            match = step_re.fullmatch(os.path.basename(path))
            if match is not None:
                candidates.append((int(match.group(1)), path))
        checkpoint_paths = [path for _, path in sorted(candidates)]

    if not checkpoint_paths:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Explicit raise (not `assert False`) so it survives `python -O`.
            raise AssertionError(e_msg)
        print(e_msg)
        return

    checkpoint_path = checkpoint_paths[-1]
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    prefix = f'{prefix_in_ckpt}.'
    state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    if not strict:
        cur_model_state_dict = cur_model.state_dict()
        unmatched_keys = []
        for key, param in state_dict.items():
            if key in cur_model_state_dict:
                new_param = cur_model_state_dict[key]
                if new_param.shape != param.shape:
                    unmatched_keys.append(key)
                    print("| Unmatched keys: ", key, new_param.shape, param.shape)
        for key in unmatched_keys:
            del state_dict[key]
    cur_model.load_state_dict(state_dict, strict=strict)
    print(f"| load '{prefix_in_ckpt}' from '{checkpoint_path}'.")
210
211
212
def remove_padding(x, padding_idx=0):
    """Strip padding from a 1d sequence or a 2d array of frames.

    For a 1d input ``[T]``, elements equal to ``padding_idx`` are dropped.
    For a 2d input ``[T, H]``, rows whose absolute-value sum equals
    ``padding_idx`` are dropped. ``None`` is passed through unchanged.
    Inputs of any other rank fail the assertion.
    """
    if x is None:
        return None
    ndim = len(x.shape)
    assert ndim in [1, 2]
    if ndim == 1:  # [T]
        keep = x != padding_idx
    else:  # [T, H]
        keep = np.abs(x).sum(-1) != padding_idx
    return x[keep]
220
221
222
class Timer:
    """Context manager that accumulates elapsed time per name.

    Totals live in the class-level ``Timer.timer_map`` dict (shared by all
    instances) keyed by ``name``; each ``with`` block adds its duration in
    seconds. Exceptions raised inside the block are not suppressed, but the
    elapsed time is still recorded.
    """
    timer_map = {}

    def __init__(self, name, print_time=False):
        Timer.timer_map.setdefault(name, 0)
        self.name = name
        self.print_time = print_time

    def __enter__(self):
        # perf_counter is monotonic; time.time() can jump with clock changes.
        self.t = time.perf_counter()
        # Return self so `with Timer(...) as t:` gives a usable object.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        Timer.timer_map[self.name] += time.perf_counter() - self.t
        if self.print_time:
            print(self.name, Timer.timer_map[self.name])
238
239
240
def print_arch(model, model_name='model'):
    """Print the trainable-parameter summary of ``model`` via ``num_params``.

    Only the parameter count is printed; the module structure itself is not.
    """
    num_params(model, model_name=model_name)
243
244
245
def num_params(model, print_out=True, model_name="model"):
    """Count the trainable parameters of ``model``, in millions.

    Only parameters with ``requires_grad`` set are counted.

    Args:
        model: module exposing ``parameters()``.
        print_out: if True, print the count with three decimals.
        model_name: label used in the printed line.

    Returns:
        Total element count of the trainable parameters divided by 1e6.
    """
    trainable = (p for p in model.parameters() if p.requires_grad)
    total_millions = sum(np.prod(p.size()) for p in trainable) / 1_000_000
    if print_out:
        print(f'| {model_name} Trainable Parameters: %.3fM' % total_millions)
    return total_millions
251
252