GitHub Repository: KoboldAI/KoboldAI-Client
Path: blob/main/breakmodel.py
'''
This is a MODIFIED version of arrmansa's low VRAM patch.
https://github.com/arrmansa/Basic-UI-for-GPT-J-6B-with-low-vram/blob/main/GPT-J-6B-Low-Vram-UI.ipynb
The ORIGINAL version of the patch is released under the Apache License 2.0
Copyright 2021 arrmansa
Copyright 2021 finetuneanon
Copyright 2018, 2022 The Hugging Face team


Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''


import torch
from torch import nn
import torch.cuda.comm
import copy
import gc
import os
import sys
import itertools
import bisect
import random
import utils
from typing import Dict, List, Optional, Union

from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions

from transformers.utils import logging
logger = logging.get_logger(__name__)


# Whether the low-VRAM layer-splitting ("breakmodel") code path is active.
breakmodel = True
# Number of transformer layers assigned to each CUDA device, in device order.
gpu_blocks = []
# Number of transformer layers to offload to disk.
disk_blocks = 0
# Device that holds the embeddings, the final layer norm and the working copies of
# any CPU-resident layers while they are being executed.
primary_device = 0 if torch.cuda.device_count() > 0 else "cpu"


if utils.HAS_ACCELERATE:
    from accelerate.hooks import attach_align_device_hook_on_blocks
    from accelerate.utils import OffloadedWeightsLoader, check_device_map, extract_submodules_state_dict, offload_state_dict
    from accelerate import dispatch_model

    def dispatch_model_ex(
        model: nn.Module,
        device_map: Dict[str, Union[str, int, torch.device]],
        main_device: Optional[torch.device] = None,
        state_dict: Optional[Dict[str, torch.Tensor]] = None,
        offload_dir: Union[str, os.PathLike] = None,
        offload_buffers: bool = False,
        **kwargs,
    ):
        """
        This is a modified version of
        https://github.com/huggingface/accelerate/blob/eeaba598f455fbd2c48661d7e816d3ff25ab050b/src/accelerate/big_modeling.py#L130
        that still works when the main device is the CPU.

        Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on
        the CPU or even the disk.

        Args:
            model (`torch.nn.Module`):
                The model to dispatch.
            device_map (`Dict[str, Union[str, int, torch.device]]`):
                A dictionary mapping module names in the model's `state_dict` to the device they should go to. Note that
                `"disk"` is accepted even if it's not a proper value for `torch.device`.
            main_device (`str`, `int` or `torch.device`, *optional*):
                The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or
                `"disk"`.
            state_dict (`Dict[str, torch.Tensor]`, *optional*):
                The state dict of the part of the model that will be kept on CPU.
            offload_dir (`str` or `os.PathLike`):
                The folder in which to offload the model weights (or where the model weights are already offloaded).
            offload_buffers (`bool`, *optional*, defaults to `False`):
                Whether or not to offload the buffers with the model parameters.
            preload_module_classes (`List[str]`, *optional*):
                A list of classes whose instances should load all their weights (even in the submodules) at the beginning
                of the forward. This should only be used for classes that have submodules which are registered but not
                called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
                `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
        """
        if main_device != "cpu":
            return dispatch_model(model, device_map, main_device, state_dict, offload_dir=offload_dir, offload_buffers=offload_buffers, **kwargs)

        # Error early if the device map is incomplete.
        check_device_map(model, device_map)

        offload_devices = ["cpu", "disk"] if main_device != "cpu" else ["disk"]

        if main_device is None:
            main_device = [d for d in device_map.values() if d not in offload_devices][0]

        cpu_modules = [name for name, device in device_map.items() if device == "cpu"] if main_device != "cpu" else []
        if state_dict is None and len(cpu_modules) > 0:
            state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)

        disk_modules = [name for name, device in device_map.items() if device == "disk"]
        if offload_dir is None and len(disk_modules) > 0:
            raise ValueError(
                "We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
                f"need to be offloaded: {', '.join(disk_modules)}."
            )
        if len(disk_modules) > 0 and (
            not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json"))
        ):
            disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)
            offload_state_dict(offload_dir, disk_state_dict)

        execution_device = {
            name: main_device if device in offload_devices else device for name, device in device_map.items()
        }
        offload = {name: device in offload_devices for name, device in device_map.items()}
        save_folder = offload_dir if len(disk_modules) > 0 else None
        if state_dict is not None or save_folder is not None:
            weights_map = OffloadedWeightsLoader(state_dict=state_dict, save_folder=save_folder)
        else:
            weights_map = None

        attach_align_device_hook_on_blocks(
            model,
            execution_device=execution_device,
            offload=offload,
            offload_buffers=offload_buffers,
            weights_map=weights_map,
            **kwargs,
        )
        model.hf_device_map = device_map
        return model
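
# Hypothetical usage sketch (not part of the original file): dispatch_model_ex only
# differs from accelerate's dispatch_model when the main execution device is "cpu",
# e.g. when every layer lives in RAM or on disk.  The module names below are
# illustrative and depend on the actual model being dispatched.
#
#     device_map = {
#         "transformer.wte": "cpu",
#         "transformer.h.0": "cpu",
#         "transformer.h.1": "disk",
#         "transformer.ln_f": "cpu",
#     }
#     model = dispatch_model_ex(
#         model,
#         device_map,
#         main_device="cpu",
#         offload_dir="./offload",  # required because the map contains "disk" entries
#         offload_buffers=True,
#     )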


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
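
# Illustrative example (shapes assumed for demonstration): a padding mask of shape
# [2, 5] with ones marking real tokens becomes a [2, 1, 5, 5] additive bias whose
# masked positions hold the most negative value representable in the target dtype:
#
#     mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
#     bias = _expand_mask(mask, torch.float16)  # bias.shape == (2, 1, 5, 5)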


def move_hidden_layers(transformer, h=None):
    # Distribute the transformer's hidden layers according to gpu_blocks: the first
    # ram_blocks layers keep pinned CPU copies in transformer.extrastorage and are
    # streamed to primary_device during the forward pass, while the remaining layers
    # are moved permanently onto their assigned GPUs.
    if h is None:
        h = transformer.h

    assert len(gpu_blocks) <= torch.cuda.device_count()
    assert sum(gpu_blocks) <= len(h)
    ram_blocks = len(h) - sum(gpu_blocks)

    transformer.extrastorage = {}
    torch.cuda.empty_cache()

    able_to_pin_layers = True
    for i in range(ram_blocks):
        h[i].to("cpu")
        transformer.extrastorage[i] = copy.deepcopy(h[i])
        smalltensor = torch.tensor(0).to(primary_device)
        for param1 in h[i].parameters():
            param1.data = smalltensor
        h[i].to(primary_device)
        for param in transformer.extrastorage[i].parameters():
            param.requires_grad = False
            param.data = param.data.detach()
            if able_to_pin_layers:
                try:
                    param.data = param.data.pin_memory()
                except:
                    able_to_pin_layers = False
                    print(f"WARNING: You only have enough shared GPU memory for {i} out of {ram_blocks} CPU layers. Expect suboptimal speed.", file=sys.stderr)
        gc.collect()
        torch.cuda.empty_cache()

    if ram_blocks:
        for param1,param2 in zip(h[0].parameters(),transformer.extrastorage[0].parameters()):
            param1.data = param2.data.to(primary_device, non_blocking=False).detach()

        for param1,param2 in zip(h[ram_blocks-1].parameters(),transformer.extrastorage[ram_blocks-1].parameters()):
            param1.data = param2.data.to(primary_device, non_blocking=False).detach()

    i = ram_blocks
    for j in range(len(gpu_blocks)):
        for _ in range(gpu_blocks[j]):
            h[i].to(j)
            i += 1
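
# Hypothetical usage sketch (not part of the original file): the host application is
# expected to fill in the module-level globals and then call move_hidden_layers() on
# the model's transformer before generating, e.g. for a model with 28 hidden layers:
#
#     import breakmodel
#     breakmodel.gpu_blocks = [18, 6]  # 18 layers on cuda:0, 6 on cuda:1, 4 kept in RAM
#     breakmodel.primary_device = 0
#     breakmodel.move_hidden_layers(model.transformer)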


def new_forward_neo(
    self,
    input_ids=None,
    past_key_values=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    embs=None,
):
    assert len(gpu_blocks) <= torch.cuda.device_count()
    assert sum(gpu_blocks) <= len(self.h)
    ram_blocks = len(self.h) - sum(gpu_blocks)
    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))


    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    device = input_ids.device if input_ids is not None else inputs_embeds.device

    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(-1, input_shape[-1])
    if position_ids is not None:
        position_ids = position_ids.view(-1, input_shape[-1])

    if past_key_values is None:
        past_length = 0
        past_key_values = tuple([None] * len(self.h))
    else:
        past_length = past_key_values[0][0].size(-2)

    device = primary_device if breakmodel else input_ids.device if input_ids is not None else inputs_embeds.device
    if position_ids is None:
        position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

    # Attention mask.
    if attention_mask is not None:
        assert batch_size > 0, "batch_size has to be defined and > 0"
        attention_mask = attention_mask.view(batch_size, -1)
        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        attention_mask = attention_mask[:, None, None, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
        attention_mask = (1.0 - attention_mask) * -10000.0

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x num_heads x N x N
    # head_mask has shape n_layer x batch x num_heads x N x N
    head_mask = self.get_head_mask(head_mask, getattr(self.config, "num_layers", None) or self.config.n_layer)

    if inputs_embeds is None:
        if breakmodel:
            input_ids = input_ids.to(primary_device)
        inputs_embeds = self.wte(input_ids)

    if embs is not None and not (use_cache is not None and use_cache and past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None):
        offset = 0
        for pos, emb in embs:
            pos += offset
            if len(emb.shape) == 2:
                emb = emb.repeat(input_shape[0], 1, 1)
            inputs_embeds[:, pos:pos+emb.shape[1]] = emb
            offset += emb.shape[1]

    if getattr(self, "wpe", None) is None:
        hidden_states = inputs_embeds
    else:
        if breakmodel:
            position_ids = position_ids.to(primary_device)
        position_embeds = self.wpe(position_ids)
        if breakmodel:
            position_embeds = position_embeds.to(primary_device)
        hidden_states = inputs_embeds + position_embeds

    if token_type_ids is not None:
        token_type_embeds = self.wte(token_type_ids)
        hidden_states = hidden_states + token_type_embeds

    hidden_states = self.drop(hidden_states)

    output_shape = input_shape + (hidden_states.size(-1),)

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None

    if breakmodel and ram_blocks:
        copystream = torch.cuda.Stream(device=primary_device, priority=-1)

    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

        if breakmodel:
            if i in range(ram_blocks):
                index1 = (i+1)%ram_blocks
                for param1,param2 in zip(self.h[index1].parameters(),self.h[(i-1)%ram_blocks].parameters()):
                    param1.data = param2.data
                for param1,param2 in zip(self.h[index1].parameters(),self.extrastorage[index1].parameters()):
                    with torch.cuda.stream(copystream):
                        torch.cuda.comm.broadcast(param2.data,out = [param1.data])

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states.cpu(),)

        if getattr(self.config, "gradient_checkpointing", False) and self.training:

            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, use_cache, output_attentions)

                return custom_forward

            outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                None,
                attention_mask,
                head_mask[i],
            )
        else:
            if breakmodel:
                device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
            outputs = block(
                hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
                layer_past=tuple(v.to(device) for v in layer_past if v is not None) if breakmodel and layer_past is not None and i >= ram_blocks and len(layer_past) and layer_past[0].device.index != device else layer_past,
                attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
                head_mask=head_mask[i].to(device) if breakmodel and head_mask[i] is not None else head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        hidden_states = outputs[0]
        if use_cache is True:
            presents = presents + (outputs[1],)

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)


        if breakmodel:
            if i in range(ram_blocks):
                torch.cuda.synchronize()
                torch.cuda.empty_cache()

    if breakmodel:
        if ram_blocks:
            del copystream
        torch.cuda.empty_cache()
        hidden_states = hidden_states.to(primary_device)
    hidden_states = self.ln_f(hidden_states)
    if breakmodel:
        hidden_states = hidden_states.to(primary_device)

    hidden_states = hidden_states.view(*output_shape)
    # Add last hidden state
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )
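
# Worked example of the layer-to-device mapping used above (illustrative numbers):
# with gpu_blocks = [3, 2] and a 9-layer model, ram_blocks = 9 - 5 = 4, so layers
# 0-3 stay on primary_device and are streamed in from their pinned CPU copies, while
# cumulative_gpu_blocks = (3, 5) and bisect.bisect_right((3, 5), i - ram_blocks)
# sends layers 4-6 to GPU 0 and layers 7-8 to GPU 1.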


def new_forward_xglm(
    self,
    input_ids=None,
    attention_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    head_mask=None,
    cross_attn_head_mask=None,
    past_key_values=None,
    inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    assert len(gpu_blocks) <= torch.cuda.device_count()
    assert sum(gpu_blocks) <= len(self.layers)
    ram_blocks = len(self.layers) - sum(gpu_blocks)
    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))


    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    # past_key_values_length
    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

    if inputs_embeds is None:
        if breakmodel:
            input_ids = input_ids.to(primary_device)
        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_key_values_length
    )

    # expand encoder attention mask
    if encoder_hidden_states is not None and encoder_attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

    # embed positions
    if breakmodel:
        inputs_embeds = inputs_embeds.to(primary_device)
    positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
    if breakmodel:
        positions = positions.to(primary_device)

    hidden_states = inputs_embeds + positions

    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
    next_decoder_cache = () if use_cache else None

    if breakmodel and ram_blocks:
        copystream = torch.cuda.Stream(device=primary_device, priority=-1)

    # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
    for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
        if attn_mask is not None:
            assert attn_mask.size()[0] == (
                len(self.layers)
            ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
    for idx, decoder_layer in enumerate(self.layers):
        i = idx
        if breakmodel:
            if i in range(ram_blocks):
                index1 = (i+1)%ram_blocks
                for param1,param2 in zip(self.layers[index1].parameters(),self.layers[(i-1)%ram_blocks].parameters()):
                    param1.data = param2.data
                for param1,param2 in zip(self.layers[index1].parameters(),self.extrastorage[index1].parameters()):
                    with torch.cuda.stream(copystream):
                        torch.cuda.comm.broadcast(param2.data,out = [param1.data])

        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        dropout_probability = random.uniform(0, 1)
        if self.training and (dropout_probability < self.layerdrop):
            continue

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            if use_cache:
                logger.warning(
                    "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
                )
                use_cache = False

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, use_cache)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                head_mask[idx] if head_mask is not None else None,
                cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                None,
            )
        else:
            if breakmodel:
                device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
            layer_outputs = decoder_layer(
                hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
                attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
                encoder_hidden_states=encoder_hidden_states.to(device) if breakmodel and encoder_hidden_states is not None else encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask.to(device) if breakmodel and encoder_attention_mask is not None else encoder_attention_mask,
                layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None),
                cross_attn_layer_head_mask=(
                    (cross_attn_head_mask[idx].to(device) if breakmodel and cross_attn_head_mask[idx] is not None else cross_attn_head_mask[idx]) if cross_attn_head_mask is not None else None
                ),
                past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

            if encoder_hidden_states is not None:
                all_cross_attentions += (layer_outputs[2],)

        if breakmodel:
            if i in range(ram_blocks):
                torch.cuda.synchronize()
                torch.cuda.empty_cache()

    if breakmodel:
        if ram_blocks:
            del copystream
        torch.cuda.empty_cache()
        hidden_states = hidden_states.to(primary_device)
    hidden_states = self.layer_norm(hidden_states)
    if breakmodel:
        hidden_states = hidden_states.to(primary_device)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
        cross_attentions=all_cross_attentions,
    )


def new_forward_opt(
    self,
    input_ids=None,
    attention_mask=None,
    head_mask=None,
    past_key_values=None,
    inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    assert len(gpu_blocks) <= torch.cuda.device_count()
    assert sum(gpu_blocks) <= len(self.layers)
    ram_blocks = len(self.layers) - sum(gpu_blocks)
    cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))


    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

    if inputs_embeds is None:
        if breakmodel:
            input_ids = input_ids.to(primary_device)
        inputs_embeds = self.embed_tokens(input_ids)

    # embed positions
    if breakmodel:
        inputs_embeds = inputs_embeds.to(primary_device)
    if attention_mask is None:
        attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)

    positions = self.embed_positions(attention_mask)[:, past_key_values_length:, :]
    if breakmodel:
        positions = positions.to(primary_device)

    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_key_values_length
    )

    if self.project_in is not None:
        inputs_embeds = self.project_in(inputs_embeds)

    hidden_states = inputs_embeds + positions

    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    if breakmodel and ram_blocks:
        copystream = torch.cuda.Stream(device=primary_device, priority=-1)

    # check if head_mask has a correct number of layers specified if desired
    for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
        if attn_mask is not None:
            if attn_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

    for idx, decoder_layer in enumerate(self.layers):
        i = idx
        if breakmodel:
            if i in range(ram_blocks):
                index1 = (i+1)%ram_blocks
                for param1,param2 in zip(self.layers[index1].parameters(),self.layers[(i-1)%ram_blocks].parameters()):
                    param1.data = param2.data
                for param1,param2 in zip(self.layers[index1].parameters(),self.extrastorage[index1].parameters()):
                    with torch.cuda.stream(copystream):
                        torch.cuda.comm.broadcast(param2.data,out = [param1.data])

        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        dropout_probability = random.uniform(0, 1)
        if self.training and (dropout_probability < self.layerdrop):
            continue

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, None)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                head_mask[idx] if head_mask is not None else None,
                None,
            )
        else:
            if breakmodel:
                device = primary_device if i < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, i - ram_blocks)
            layer_outputs = decoder_layer(
                hidden_states.to(device) if breakmodel and hidden_states is not None else hidden_states,
                attention_mask=attention_mask.to(device) if breakmodel and attention_mask is not None else attention_mask,
                layer_head_mask=((head_mask[idx].to(device) if breakmodel and head_mask[idx] is not None else head_mask[idx]) if head_mask is not None else None),
                past_key_value=tuple(v.to(device) for v in past_key_value if v is not None) if breakmodel and past_key_value is not None and i >= ram_blocks and len(past_key_value) and past_key_value[0].device.index != device else past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

        if breakmodel:
            if i in range(ram_blocks):
                torch.cuda.synchronize()
                torch.cuda.empty_cache()

    if breakmodel:
        if ram_blocks:
            del copystream
        torch.cuda.empty_cache()
        hidden_states = hidden_states.to(primary_device)
    if self.project_out is not None:
        hidden_states = self.project_out(hidden_states)
    if breakmodel:
        hidden_states = hidden_states.to(primary_device)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
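
# Hypothetical integration sketch (not part of the original file): the host
# application is expected to bind one of the new_forward_* functions to the model's
# transformer, move the non-layer modules to primary_device and distribute the
# hidden layers, roughly like this for a GPT-Neo-style model (attribute names vary
# by architecture):
#
#     import types
#     import breakmodel
#
#     breakmodel.gpu_blocks = [24]  # layers per CUDA device; the rest stay in RAM
#     model.transformer.wte.to(breakmodel.primary_device)
#     model.transformer.wpe.to(breakmodel.primary_device)
#     model.transformer.ln_f.to(breakmodel.primary_device)
#     model.lm_head.to(breakmodel.primary_device)
#     breakmodel.move_hidden_layers(model.transformer)
#     model.transformer.forward = types.MethodType(breakmodel.new_forward_neo, model.transformer)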