Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
KoboldAI
GitHub Repository: KoboldAI/KoboldAI-Client
Path: blob/main/torch_lazy_loader.py
471 views
1
'''
This file is AGPL-licensed.

Some of the code in this file is copied from PyTorch.

The license for PyTorch is shown below:

Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)
Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU                      (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
   and IDIAP Research Institute nor the names of its contributors may be
   used to endorse or promote products derived from this software without
   specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
'''
45
46
47
import contextlib
48
from functools import reduce
49
import itertools
50
import zipfile
51
import pickle
52
import torch
53
import numpy as np
54
import collections
55
import _codecs
56
import utils
57
import os
58
from torch.nn import Module
59
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
60
61
62
# Suffix appended to a module's prefix to form the state-dict key under which
# that module's extra state is looked up.
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'


# Storage class used to wrap the raw bytes of a tensor, keyed by tensor dtype.
# Built from (dtype, storage class) pairs; insertion order is preserved.
STORAGE_TYPE_MAP = dict((
    (torch.float64, torch.DoubleStorage),
    (torch.float32, torch.FloatStorage),
    (torch.float16, torch.HalfStorage),
    (torch.int64, torch.LongStorage),
    (torch.int32, torch.IntStorage),
    (torch.int16, torch.ShortStorage),
    (torch.int8, torch.CharStorage),
    (torch.uint8, torch.ByteStorage),
    (torch.bool, torch.BoolStorage),
    (torch.bfloat16, torch.BFloat16Storage),
))
77
78
79
class LazyTensor:
    """Description of a tensor whose bytes have not been read from the checkpoint yet.

    An instance only records where the data lives and how to interpret it;
    :meth:`materialize` performs the actual read and builds the tensor.

    Attributes:
        storage_type: Storage class/descriptor taken from the pickle's persistent id.
        key: Name of the storage record under the archive's ``data/`` directory.
        location: Location tag handed to the restore-location function.
        dtype: Element dtype; must be set before calling :meth:`materialize`.
        seek_offset: Number of bytes to skip at the start of the storage record.
        shape: Tensor shape; must be set before calling :meth:`materialize`.
        stride: Tensor stride passed to ``Tensor.set_``.
        requires_grad: Gradient flag applied when ``no_grad`` is false.
        backward_hooks: Value assigned to the tensor's ``_backward_hooks``.
    """

    def __init__(self, storage_type, key: str, location: str, dtype: Optional[torch.dtype] = None, seek_offset: Optional[int] = None, shape: Optional[Tuple[int, ...]] = None, stride: Optional[Tuple[int, ...]] = None, requires_grad=False, backward_hooks: Any = None):
        self.storage_type = storage_type
        self.key = key
        self.location = location
        self.dtype = dtype
        self.seek_offset = seek_offset
        self.shape = shape
        self.stride = stride
        self.requires_grad = requires_grad
        self.backward_hooks = backward_hooks

    def __view(self, f: Callable):
        """Render every field through ``f`` (e.g. ``repr``) in constructor-call form."""
        return f"{type(self).__name__}(storage_type={f(self.storage_type)}, key={f(self.key)}, location={f(self.location)}, dtype={f(self.dtype)}, seek_offset={f(self.seek_offset)}, shape={f(self.shape)}, stride={f(self.stride)}, requires_grad={f(self.requires_grad)}, backward_hooks={f(self.backward_hooks)})"

    def __repr__(self):
        return self.__view(repr)

    def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], map_location=None, no_grad=True, filename="pytorch_model.bin") -> torch.Tensor:
        """Read this tensor's bytes from ``checkpoint`` and return a real tensor.

        Args:
            checkpoint: Either an open ``ZipFile`` or a file-like object.
                For a ``ZipFile`` the storage record is opened here, skipped
                forward by ``seek_offset`` bytes and closed again. Any other
                object is read from its current position and left open.
            map_location: Forwarded to ``torch.serialization._get_restore_location``.
            no_grad: When true the result never requires grad; otherwise it
                inherits ``self.requires_grad``.
            filename: Checkpoint file name. Its base name up to the first
                ``.`` is the fallback archive root when ``archive/`` is absent.

        Returns:
            The tensor built over the freshly read storage.

        Raises:
            KeyError: If the record is found under neither archive root.
        """
        filename = os.path.basename(os.path.normpath(filename)).split('.')[0]
        size = reduce(lambda x, y: x * y, self.shape, 1)
        dtype = self.dtype
        # torch.bool is special-cased as one byte per element; every other
        # dtype takes its width from finfo/iinfo (bits >> 3 == bytes).
        nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)

        opened_here = isinstance(checkpoint, zipfile.ZipFile)
        if opened_here:
            try:
                f = checkpoint.open(f"archive/data/{self.key}", "r")
            except KeyError:
                # ZipFile.open raises KeyError for a missing member; retry
                # with the archive root named after the checkpoint file.
                f = checkpoint.open(f"{filename}/data/{self.key}", "r")
        else:
            f = checkpoint

        try:
            if opened_here:
                # Skip forward by reading and discarding the leading bytes.
                # Done inside the try so a failed read still closes the member.
                f.read(self.seek_offset)
            storage = STORAGE_TYPE_MAP[dtype].from_buffer(f.read(nbytes), "little")
        finally:
            if opened_here:
                f.close()

        storage = torch.serialization._get_restore_location(map_location)(storage, self.location)
        tensor = torch.tensor([], dtype=storage.dtype, device=storage.device)
        tensor.set_(storage, 0, self.shape, self.stride)
        tensor.requires_grad = not no_grad and self.requires_grad
        tensor._backward_hooks = self.backward_hooks
        return tensor
121
122
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves a fixed allow-list of globals.

    Any global outside the allow-list raises ``pickle.UnpicklingError``, and
    persistent ids are only accepted when their first element is
    ``"storage"``.
    """

    # Names under the ``torch`` module that find_class will resolve.
    _ALLOWED_TORCH_STORAGES = frozenset((
        "DoubleStorage",
        "FloatStorage",
        "HalfStorage",
        "LongStorage",
        "IntStorage",
        "ShortStorage",
        "CharStorage",
        "ByteStorage",
        "BoolStorage",
        "BFloat16Storage",
    ))

    def original_persistent_load(self, saved_id):
        """Delegate to the base class; shadowed per instance once load() runs."""
        return super().persistent_load(saved_id)

    def forced_persistent_load(self, saved_id):
        """Reject any persistent id that is not a storage, then delegate.

        Raises:
            pickle.UnpicklingError: If ``saved_id[0]`` is not ``"storage"``.
        """
        if saved_id[0] != "storage":
            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
        return self.original_persistent_load(saved_id)

    def find_class(self, module, name):
        """Resolve an allow-listed global or refuse.

        Raises:
            pickle.UnpicklingError: For every global not on the allow-list.
        """
        if module == "collections" and name == "OrderedDict":
            return collections.OrderedDict
        if module == "torch._utils" and name == "_rebuild_tensor_v2":
            return torch._utils._rebuild_tensor_v2
        if module == "torch" and name in self._ALLOWED_TORCH_STORAGES:
            return getattr(torch, name)
        if module == "numpy.core.multiarray" and name == "scalar":
            return np.core.multiarray.scalar
        if module == "numpy" and name == "dtype":
            return np.dtype
        if module == "_codecs" and name == "encode":
            return _codecs.encode

        # Forbid everything else.
        # Builtins arrive as "__builtin__" from protocol <= 2 pickles and as
        # "builtins" from protocol >= 3; show both without a module prefix.
        qualified_name = name if module in ("__builtin__", "builtins") else f"{module}.{name}"
        raise pickle.UnpicklingError(f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code")

    def load(self, *args, **kwargs):
        """Install the checking persistent_load wrapper and run the unpickler."""
        # Remember whatever persistent_load the caller installed (falling back
        # to the class attribute) so forced_persistent_load can delegate to it
        # after its own check.
        self.original_persistent_load = getattr(self, "persistent_load", pickle.Unpickler.persistent_load)
        self.persistent_load = self.forced_persistent_load
        return super().load(*args, **kwargs)
164
165
class _LazyUnpickler(RestrictedUnpickler):
    """Restricted unpickler that yields ``LazyTensor`` placeholders for storages.

    Instead of delegating persistent ids to a storage loader, each storage id
    is turned into a ``LazyTensor`` carrying only its type, key and location.
    """

    lazy_loaded_storages: Dict[str, LazyTensor]

    def __init__(self, *args, **kwargs):
        self.lazy_loaded_storages = {}
        super().__init__(*args, **kwargs)

    def forced_persistent_load(self, saved_id):
        """Build a ``LazyTensor`` from a ``("storage", type, key, location, _)`` id.

        Raises:
            pickle.UnpicklingError: If ``saved_id`` is not a tuple or its
                first element is not ``"storage"``. Explicit raises are used
                instead of ``assert`` so the check survives ``python -O``.
        """
        if not isinstance(saved_id, tuple):
            raise pickle.UnpicklingError(f"Persistent id must be a tuple but got '{type(saved_id).__name__}'")
        typename = saved_id[0]
        if typename != "storage":
            raise pickle.UnpicklingError(f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'")
        # The fourth remaining element of the id is not needed here.
        storage_type, key, location, _ = saved_id[1:]
        return LazyTensor(storage_type, key, location)

    def load(self, *args, **kwargs):
        """Run the restricted load, then reset ``lazy_loaded_storages``."""
        retval = super().load(*args, **kwargs)
        self.lazy_loaded_storages = {}
        return retval
183
184
185
def _rebuild_tensor(lazy_storage: LazyTensor, storage_offset, shape, stride):
    """Fill in the layout fields of ``lazy_storage`` and return it.

    Installed in place of ``torch._utils._rebuild_tensor`` while lazy loading
    is active, so no tensor is built here; only metadata is recorded.

    Args:
        lazy_storage: Placeholder to update in place.
        storage_offset: Offset into the storage, counted in elements.
        shape: Shape to record on the placeholder.
        stride: Stride to record on the placeholder.

    Returns:
        The same ``lazy_storage`` object, with ``shape``, ``stride``,
        ``dtype`` and ``seek_offset`` (a byte offset) set.
    """
    lazy_storage.shape = shape
    lazy_storage.stride = stride

    element_dtype = lazy_storage.storage_type.dtype
    if not isinstance(element_dtype, torch.dtype):
        # The class-level attribute was not a dtype; take it from an
        # empty instance of the storage type instead.
        element_dtype = lazy_storage.storage_type(0).dtype
    lazy_storage.dtype = element_dtype

    if element_dtype is torch.bool:
        # torch.bool is treated as one byte per element.
        byte_offset = storage_offset
    else:
        type_info = torch.finfo if element_dtype.is_floating_point else torch.iinfo
        byte_offset = storage_offset * (type_info(element_dtype).bits >> 3)
    lazy_storage.seek_offset = byte_offset
    return lazy_storage
194
195
196
# Modified version of https://github.com/pytorch/pytorch/blob/v1.11.0-rc4/torch/nn/modules/module.py#L1346-L1438
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    """Load this module's own parameters and persistent buffers from ``state_dict``.

    Differs from the upstream implementation in one way: rather than copying
    checkpoint values into the existing parameters, each entry is replaced by
    a new ``torch.nn.Parameter`` wrapping the checkpoint value.

    Args:
        state_dict: Mapping from fully-qualified names to checkpoint values.
        prefix: Prefix prepended to local names to form ``state_dict`` keys.
        local_metadata: Passed through to the pre-hooks only.
        strict: When true, record missing and unexpected keys.
        missing_keys: List extended in place with keys absent from ``state_dict``.
        unexpected_keys: List extended in place with keys nothing here accepts.
        error_msgs: List extended in place with human-readable load errors.
    """
    for pre_hook in self._load_state_dict_pre_hooks.values():
        pre_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    persistent_buffers = {
        buf_name: buf
        for buf_name, buf in self._buffers.items()
        if buf_name not in self._non_persistent_buffers_set
    }
    local_state = {
        entry_name: entry
        for entry_name, entry in itertools.chain(self._parameters.items(), persistent_buffers.items())
        if entry is not None
    }

    for name, param in local_state.items():
        key = prefix + name
        if key not in state_dict:
            if strict:
                missing_keys.append(key)
            continue

        input_param = state_dict[key]
        if not torch.overrides.is_tensor_like(input_param):
            error_msgs.append('While copying the parameter named "{}", '
                              'expected torch.Tensor or Tensor-like object from checkpoint but '
                              'received {}'
                              .format(key, type(input_param)))
            continue

        # This is used to avoid copying uninitialized parameters into
        # non-lazy modules, since they dont have the hook to do the checks
        # in such case, it will error when accessing the .shape attribute.
        is_param_lazy = torch.nn.parameter.is_lazy(param)
        if not is_param_lazy:
            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]

            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                  'the shape in current model is {}.'
                                  .format(key, input_param.shape, param.shape))
                continue

        try:
            with torch.no_grad():
                # Upstream does param.copy_(input_param); here the entry is
                # rebound to a new Parameter wrapping the checkpoint value.
                new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad)
                if name in self._parameters:
                    self._parameters[name] = new_param
                if name in persistent_buffers:
                    self._buffers[name] = new_param
        except Exception as ex:
            error_msgs.append('While copying the parameter named "{}", '
                              'whose dimensions in the model are {} and '
                              'whose dimensions in the checkpoint are {}, '
                              'an exception occurred : {}.'
                              .format(key, param.size(), input_param.size(), ex.args))

    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    # hasattr guard: only consult set_extra_state when Module defines it at all.
    overrides_extra_state = (
        hasattr(Module, "set_extra_state")
        and getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state
    )
    if overrides_extra_state:
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
            missing_keys.append(extra_state_key)
    elif strict and (extra_state_key in state_dict):
        unexpected_keys.append(extra_state_key)

    if not strict:
        return

    for key in state_dict.keys():
        if not key.startswith(prefix) or key == extra_state_key:
            continue
        # get the name of param/buffer/child
        input_name = key[len(prefix):].partition('.')[0]
        if input_name not in self._modules and input_name not in local_state:
            unexpected_keys.append(key)
263
264
265
@contextlib.contextmanager
def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
    """Temporarily replace ``pickle.Unpickler`` and ``pickle.load``.

    Inside the ``with`` body, ``pickle.Unpickler`` is ``unpickler`` and
    ``pickle.load`` builds a ``pickle.Unpickler`` from its arguments and calls
    ``load()`` on it. Both module attributes are restored on exit, including
    when the body raises.

    Args:
        unpickler: Unpickler class to install for the duration of the block.
    """
    def patched_load(*args, **kwargs):
        # Looks up pickle.Unpickler at call time, so it picks up whichever
        # class is installed at that moment.
        return pickle.Unpickler(*args, **kwargs).load()

    saved_unpickler = pickle.Unpickler
    saved_load = pickle.load
    try:
        pickle.Unpickler = unpickler
        pickle.load = patched_load
        yield
    finally:
        pickle.Unpickler = saved_unpickler
        pickle.load = saved_load
283
284
@contextlib.contextmanager
def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
    """Temporarily patch torch so that ``torch.load`` yields lazy placeholders.

    Yields ``False`` when ``enable`` is false (only the restricted unpickler
    is installed) and ``True`` otherwise. Every patch is undone on exit.

    Each patch registers its own undo on an ``ExitStack`` as it is applied.
    If setup fails partway, exactly the patches already applied are reverted
    and the original exception propagates, instead of the teardown failing on
    names that were never bound.

    Args:
        enable: When false, skip lazy loading and only restrict unpickling.
        callback: Optional callable invoked after each ``torch.load`` as
            ``callback(retval, f=..., map_location=..., pickle_module=..., **kwargs)``.
        dematerialized_modules: When true, also make new modules allocate on
            the ``"meta"`` device and swap in the rebinding
            ``_load_from_state_dict``.
        use_accelerate_init_empty_weights: With ``dematerialized_modules``,
            use ``accelerate.init_empty_weights()`` when
            ``utils.HAS_ACCELERATE`` is true, instead of patching the
            ``Linear``/``Embedding``/``LayerNorm`` constructors.
    """
    if not enable:
        with use_custom_unpickler(RestrictedUnpickler):
            yield False
        return

    def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
        # old_torch_load is bound below, before this wrapper can be called.
        retval = old_torch_load(f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
        if callback is not None:
            callback(retval, f=f, map_location=map_location, pickle_module=pickle_module, **pickle_load_args)
        return retval

    def force_meta_device(original_init):
        # Wrap a module constructor so any caller-supplied device is ignored
        # and the module is always created on the "meta" device.
        def init_on_meta(self, *args, device=None, **kwargs):
            return original_init(self, *args, device="meta", **kwargs)
        return init_on_meta

    with contextlib.ExitStack() as undo:
        def patch(owner, attr, replacement):
            # Swap owner.attr and schedule restoration of the previous value.
            original = getattr(owner, attr)
            setattr(owner, attr, replacement)
            undo.callback(setattr, owner, attr, original)
            return original

        patch(torch._utils, "_rebuild_tensor", _rebuild_tensor)
        old_torch_load = patch(torch, "load", torch_load)

        if dematerialized_modules:
            if use_accelerate_init_empty_weights and utils.HAS_ACCELERATE:
                import accelerate
                undo.enter_context(accelerate.init_empty_weights())
            else:
                for layer_cls in (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm):
                    patch(layer_cls, "__init__", force_meta_device(layer_cls.__init__))
            patch(torch.nn.Module, "_load_from_state_dict", _load_from_state_dict)

        with use_custom_unpickler(_LazyUnpickler):
            yield True
344
345