Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
KoboldAI
GitHub Repository: KoboldAI/KoboldAI-Client
Path: blob/main/tpu_mtj_backend.py
471 views
1
'''
2
This file is AGPL-licensed.
3
4
Some of the code in this file is from Clover Edition:
5
https://github.com/cloveranon/Clover-Edition/blob/master/aidungeon/gpt2generator.py
6
7
The license for Clover Edition is shown below:
8
9
Copyright (c) 2019 Nick Walton
10
11
Permission is hereby granted, free of charge, to any person obtaining a copy
12
of this software and associated documentation files (the "Software"), to deal
13
in the Software without restriction, including without limitation the rights
14
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
copies of the Software, and to permit persons to whom the Software is
16
furnished to do so, subject to the following conditions:
17
18
The above copyright notice and this permission notice shall be included in all
19
copies or substantial portions of the Software.
20
21
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
SOFTWARE.
28
'''
29
30
import utils
31
32
import multiprocessing
33
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
34
import progressbar
35
import time
36
import os
37
import sys
38
import json
39
import zipfile
40
import requests
41
import random
42
import jax
43
import jax.dlpack
44
from jax.config import config
45
from jax.experimental import maps
46
import jax.numpy as jnp
47
import numpy as np
48
import haiku as hk
49
from transformers import AutoTokenizer, GPT2Tokenizer, AutoModelForCausalLM, GPTNeoForCausalLM
50
from tokenizers import Tokenizer
51
from mesh_transformer.checkpoint import read_ckpt_lowmem
52
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
53
from mesh_transformer.util import to_bf16
54
55
56
# Parameter store shared by this module; apply_repetition_penalty_dynamic
# reads params["n_vocab"] from it. Starts out empty.
params: Dict[str, Any] = dict()

# Module-level Python RNG, seeded at import time with a seed drawn uniformly
# from [0, 2**64). The seed is remembered so get_rng_seed() can report it.
__seed = random.randrange(2**64)
rng = random.Random(__seed)
60
61
62
def get_rng_seed():
    """Return the most recently recorded RNG seed.

    This is either the seed chosen at import time or the last value passed to
    set_rng_seed. set_rng_state replaces ``rng`` without updating it.
    """
    current_seed = __seed
    return current_seed
64
65
def set_rng_seed(seed: int):
    """Reseed the module RNG.

    Replaces the module-level ``rng`` with ``random.Random(seed)``, records
    *seed* as the current seed, and returns *seed* unchanged.
    """
    global __seed, rng
    fresh_generator = random.Random(seed)
    # Tuple assignment binds left to right: rng first, then __seed.
    rng, __seed = fresh_generator, seed
    return seed
70
71
def randomize_rng_seed():
    """Reseed the module RNG with a fresh seed drawn from [0, 2**64).

    Returns the new seed, as reported by set_rng_seed.
    """
    new_seed = random.randrange(2**64)
    return set_rng_seed(new_seed)
73
74
def get_rng_state():
    """Return the module-level generator object ``rng`` itself (not a copy)."""
    generator = rng
    return generator
76
77
def set_rng_state(state):
    """Install *state* as the module-level generator ``rng``.

    The object is used as-is. The recorded seed (see get_rng_seed) is left
    untouched.
    """
    global rng
    rng = state
    return None
80
81
def new_rng_state(seed: int):
    """Create an independent ``random.Random`` seeded with *seed*.

    The module-level RNG is not modified.
    """
    generator = random.Random(seed)
    return generator
83
84
def warper_callback(logits) -> np.ndarray:
    """Placeholder logits-warping hook; it must be replaced before use.

    Always raises NotImplementedError. The return annotation is
    ``np.ndarray``: ``np.array`` is a factory function, not a type, so it
    cannot describe the value a replacement hook returns.
    """
    raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
86
87
def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
    """Placeholder stopping hook; it must be replaced before use.

    Always raises NotImplementedError. Going by the annotation, a replacement
    returns a ``(List[set], bool, bool)`` tuple. The meaning of each part is
    defined by whoever installs the hook, not here.
    """
    message = "`tpu_mtj_backend.stopping_callback()` needs to be defined"
    raise NotImplementedError(message)
89
90
def settings_callback() -> dict:
    """Return the default sampler settings as a freshly built dict.

    Contents:
    * ``sampler_order``: a copy of ``utils.default_sampler_order``.
    * Filter settings: top_p, temp, top_k, tfs, typical, top_a.
    * Repetition-penalty settings: repetition_penalty, rpslope, rprange.

    A new dict and a new sampler_order list are created on every call, so the
    result may be mutated freely.
    """
    return dict(
        sampler_order=utils.default_sampler_order.copy(),
        top_p=0.9,
        temp=0.5,
        top_k=0,
        tfs=1.0,
        typical=1.0,
        top_a=0.0,
        repetition_penalty=1.0,
        rpslope=0.0,
        rprange=0,
    )
103
104
def started_compiling_callback() -> None:
    """Default hook: does nothing and returns None.

    Presumably reassigned by the host application to be told when a
    compilation starts. Confirm against the callers.
    """
106
107
def stopped_compiling_callback() -> None:
    """Default hook: does nothing and returns None.

    Presumably reassigned by the host application to be told when a
    compilation finishes. Confirm against the callers.
    """
109
110
def compiling_callback() -> None:
    """Default hook: does nothing and returns None.

    Invoked at the top of the generate_static and generate_initial functions
    that PenalizingCausalTransformer wraps in xmap. It therefore presumably
    runs while those functions are being traced.
    """
112
113
114
def show_spinner():
    """Draw an indeterminate (bouncing) console progress bar with a timer, forever.

    This function never returns. It redraws every 0.1 s in an endless loop,
    so it is presumably meant to run in its own process or thread and be
    terminated by the caller.
    """
    bar = progressbar.ProgressBar(
        max_value=progressbar.UnknownLength,
        widgets=[
            progressbar.Timer(),
            ' ',
            # U+2588 FULL BLOCK. It is written as an escape so the marker can't
            # be mangled by encoding round-trips (the previous literal had been
            # mis-decoded into the three characters 'â–ˆ').
            progressbar.BouncingBar(left='[', right=']', marker='\u2588'),
        ],
    )
    i = 0
    while True:
        bar.update(i)
        time.sleep(0.1)
        i += 1
121
122
123
# Type variables for the xmap helpers below. __move_xmap is annotated as
# returning the same callable type (__F) it receives; __T is an arbitrary
# value type passed through unchanged by the shard/batch helpers.
__F, __T = TypeVar("__F", bound=Callable), TypeVar("__T")
125
126
def __move_xmap(f: __F, out_axis: str) -> __F:
    """Wrap the two-argument function *f* in ``maps.xmap``.

    Argument layout:
    * The first argument is mapped over the named "shard" axis.
    * The second argument is mapped over the named "batch" axis.
    * The output is laid out along *out_axis*.

    "shard" is assigned to the 'mp' mesh axis and "batch" to 'dp'.
    """
    input_axes = (["shard", ...], ["batch", ...])
    resources = {'shard': 'mp', 'batch': 'dp'}
    return maps.xmap(
        f,
        in_axes=input_axes,
        out_axes=[out_axis, ...],
        axis_resources=resources,
    )
133
134
def __shard_xmap(batch_dim=1):
    """Return a function whose output is its argument mapped onto the "shard" axis.

    A dummy ``np.empty(batch_dim)`` array fills the "batch" input of the
    underlying xmap. Its contents are never read.
    """
    def _select_shard(shard_value, _batch_value):
        return shard_value

    mapped = __move_xmap(_select_shard, "shard")

    def inner(x: __T) -> __T:
        dummy_batch = np.empty(batch_dim)
        return mapped(x, dummy_batch)

    return inner
139
140
def __batch_xmap(shard_dim=1):
    """Return a function whose output is its argument mapped onto the "batch" axis.

    A dummy ``np.empty(shard_dim)`` array fills the "shard" input of the
    underlying xmap. Its contents are never read.
    """
    def _select_batch(_shard_value, batch_value):
        return batch_value

    mapped = __move_xmap(_select_batch, "batch")

    def inner(x: __T) -> __T:
        dummy_shard = np.empty(shard_dim)
        return mapped(dummy_shard, x)

    return inner
145
146
147
class _EmptyState(NamedTuple):
    """Field-less NamedTuple returned as the state by ``_DummyOptimizer.init``."""
149
150
class _DummyOptimizer:
    """Stand-in optimizer object.

    Its ``init`` ignores whatever it is given (including an instance, when
    called as a bound method) and returns an empty ``_EmptyState``.
    """

    @staticmethod
    def init(*args, **kwargs):
        return _EmptyState()
153
154
155
def apply_repetition_penalty_dynamic(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
    '''
    Apply repetition penalty to the 1D array ``logits`` (NumPy version, used
    by kobold_sample_dynamic).

    Which positions are penalized:
    Let n = len(tokens) and r = rprange if rprange > 0 else n. The penalized
    positions of ``tokens`` are the r positions immediately before
    ``generated_index``, wrapping around cyclically.

    How a penalized token's logit changes:
    A positive logit is divided by the penalty; a non-positive logit is
    multiplied by it. Either way the token becomes less likely.

    Slope:
    When rpslope != 0 and rprange > 1, the penalty varies across the range.
    It goes from 1 for the oldest position in the range up to
    ``repetition_penalty`` for the position just before ``generated_index``.
    ``rpslope`` controls the curve's shape and may be fractional.

    Parameters
    ----------
    logits : 1D float NumPy array; modified in place and also returned.
    tokens : 1D int array of token ids. Ids are clamped to
        ``params["n_vocab"] - 1``.
    repetition_penalty : penalty factor; 1.0 means no change.
    generated_index : int; the range ends just before this position.
    gen_length : unused; kept so the rpargs tuple matches the static version.
    rpslope : slope of the penalty curve; 0 disables the slope.
    rprange : size of the penalty window; 0 or less means all of ``tokens``.

    Returns
    -------
    The penalized ``logits`` array (the same object that was passed in).
    '''
    tokens = np.minimum(tokens, params["n_vocab"]-1)  # https://github.com/google/jax/issues/3774
    # The slope is a real-valued setting. An integer cast would truncate
    # fractional slopes (e.g. 0.7 -> 0) and silently disable them.
    rpslope = np.float32(rpslope)
    rprange = np.int32(rprange)
    clipped_rprange = rprange if rprange > 0 else tokens.shape[-1]
    # Per-position weights: 0..clipped_rprange-1 for the positions inside the
    # window (the newest position gets the largest value), negative outside.
    # The roll aligns the window to end just before generated_index.
    penalty_arange = np.roll(np.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
    # Gather the current logit of every listed token; e.g. if logits is
    # [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1], then
    # penalty_logits is [77, 5, 3, 98, 3, 98, 5]
    penalty_logits = np.take(logits, tokens)
    # Repetition penalty slope. This requires rprange > 1 because the curve
    # divides by (rprange - 1); with rprange == 1 it would yield NaN logits.
    if rpslope != 0.0 and rprange > 1:
        _penalty = (penalty_arange/(rprange - 1)) * 2 - 1
        _penalty = (rpslope * _penalty) / (1 + np.abs(_penalty) * (rpslope - 1))
        _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
        repetition_penalty = _penalty
    # Divide positive values by repetition_penalty and multiply negative
    # values by it. Only dividing (as in the original paper) would make
    # tokens with negative logits *more* likely.
    penalty_logits = np.where(
        penalty_arange >= 0,
        np.where(
            penalty_logits > 0,
            penalty_logits/repetition_penalty,
            penalty_logits*repetition_penalty,
        ),
        penalty_logits,
    )
    # Scatter the penalized values back into their positions in logits
    logits[tokens] = penalty_logits
    return logits
194
195
def kobold_sample_dynamic(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
    '''
    Sample one token id from the 1D array ``logits`` on the host.

    This is the NumPy counterpart of kobold_sample_static and is called by
    sample_func.

    Stages:
    Up to seven stages are applied in the order listed by ``sampler_order``.
    Each stage id maps to:

        0 top-k               (skipped unless top_k > 0)
        1 top-a               (skipped unless top_a > 0.0)
        2 top-p               (skipped unless top_p < 1.0)
        3 tail-free sampling  (skipped unless tfs < 1.0)
        4 typical sampling    (skipped unless typical < 1.0)
        5 temperature         (skipped when temp == 1.0)
        6 repetition penalty  (skipped when rpargs[1] == 1.0)

    Stages 0-4 remove tokens by setting their logits to -inf. Ids outside
    0-6 are ignored, and an id listed twice runs its stage twice.

    ``rpargs`` is the tuple (tokens, repetition_penalty, generated_index,
    gen_length, rpslope, rprange), forwarded to
    apply_repetition_penalty_dynamic.

    NOTE(review): ``sampler_order`` defaults to None but is iterated
    unconditionally, so callers must always pass it.

    Returns the index drawn by jax.random.categorical over the final logits,
    cast to uint32.
    '''
    # Top-k: keep only the k tokens with the highest logits; the rest are
    # set to -inf
    def top_k_filter(logits):
        # Mask over the descending-sorted logits: False for the first top_k
        # positions, True for every position after them
        sorted_indices_to_remove = np.arange(len(logits)) >= top_k
        # Unsort: sorting the argsort permutation back into ascending order
        # carries the mask along, putting it back in original token order
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-logits),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, logits)
    # Top-a: remove all tokens whose softmax probability is less than a*m^2,
    # where m is the maximum softmax probability
    def top_a_filter(logits):
        # Softmax probabilities, copied into a NumPy array
        probabilities = np.array(jax.nn.softmax(logits), copy=True)
        # Find the largest probability
        probs_max = probabilities.max()
        # Remove tokens
        return np.where(probabilities < probs_max * probs_max * top_a, -np.inf, logits)
    # Top-p: after sorting by descending logit, remove tokens whose
    # cumulative softmax probability exceeds p
    def top_p_filter(logits):
        # Softmax of the logits sorted in descending order (a writable NumPy
        # copy, because the mask built from it is edited in place below)
        sorted_logits = -np.sort(-logits)
        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
        # Prefix sums of the sorted probabilities
        cumulative_probabilities = np.cumsum(probabilities, axis=-1)
        # Remove tokens with cumulative probability higher than top_p. There
        # is no shift, so the token that first crosses top_p is removed too.
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Never remove the token with the highest logit
        sorted_indices_to_remove[0] = False
        # Unsort (see top_k_filter) and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-logits),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, logits)
    # Tail free sampling: like top-p, but thresholding the cumulative
    # normalized absolute second finite differences of the sorted softmax
    # probabilities instead of the probabilities themselves
    def tail_free_filter(logits):
        # Sort in descending order
        sorted_logits = -np.sort(-logits)
        # Softmax again
        probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
        # Second finite differences (the difference array of the difference
        # array); this is two elements shorter than the input
        d2 = np.diff(np.diff(probabilities))
        # Absolute values of those second finite differences
        d2 = np.abs(d2)
        # Normalize so the elements sum to 1
        d2 = d2 / d2.sum(axis=-1, keepdims=True)
        # Prefix sums
        cumulative_d2 = np.cumsum(d2, axis=-1)
        # Remove tokens whose cumulative normalized absolute second finite
        # difference is larger than the TFS value
        sorted_indices_to_remove = cumulative_d2 > tfs
        # Don't remove the token with the highest logit
        sorted_indices_to_remove[0] = False
        # d2 has two fewer elements than logits, so pad the mask with two
        # extra Trues (the two lowest-logit tokens are always removed)
        sorted_indices_to_remove = np.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )
        # Unsort (see top_k_filter) and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            np.argsort(-logits),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -np.inf, logits)
    # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
    def typical_filter(logits):
        # Softmax probabilities and their natural logs. Tokens already
        # removed (-inf logit) have probability 0, hence log 0; the
        # divide-by-zero warning is suppressed.
        probs = jax.nn.softmax(logits)
        with np.errstate(divide="ignore"):
            log_probs = np.log(probs)
        # Negative entropy: sum of p*ln(p). nansum skips the 0 * -inf = nan
        # terms of removed tokens.
        neg_entropy = np.nansum(probs * log_probs, axis=-1, keepdims=True)
        # Absolute difference between the negative entropy and each log
        # probability
        entropy_deviation = np.abs(neg_entropy - log_probs)
        # Sort the tokens by ascending entropy_deviation, then keep the
        # smallest prefix whose probabilities sum to at least `typical`.
        # NOTE: despite its name, sorted_logits holds the *probabilities*
        # sorted by entropy_deviation.
        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
        sorted_indices_to_remove = np.cumsum(sorted_logits, axis=-1) >= typical
        # Shift right by one so the token that first reaches the threshold
        # is kept, and always keep the first one
        sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove[0] = False
        # Unsort (keys are the argsort permutation of entropy_deviation) and
        # remove. jnp is used here even though this is the NumPy path; the
        # result is converted by np.where.
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(entropy_deviation),
            sorted_indices_to_remove,
        )
        return np.where(indices_to_remove, -jnp.inf, logits)
    # Temperature: divide the logits by the temperature
    def temp_filter(logits):
        return logits / temp
    # Run the enabled stages in the caller-specified order
    for k in sampler_order:
        if k == 0 and top_k > 0: logits = top_k_filter(logits)
        if k == 1 and top_a > 0.0: logits = top_a_filter(logits)
        if k == 2 and top_p < 1.0: logits = top_p_filter(logits)
        if k == 3 and tfs < 1.0: logits = tail_free_filter(logits)
        if k == 4 and typical < 1.0: logits = typical_filter(logits)
        if k == 5 and temp != 1.0: logits = temp_filter(logits)
        if k == 6 and rpargs[1] != 1.0: logits = apply_repetition_penalty_dynamic(logits, *rpargs)
    # Finally draw one index from the categorical distribution defined by
    # softmax(logits)
    return jax.random.categorical(key, logits, -1).astype(np.uint32)
336
337
def apply_repetition_penalty_static(logits, tokens, repetition_penalty, generated_index, gen_length, rpslope, rprange):
    '''
    Apply repetition penalty to the 1D array ``logits`` (JAX version, used by
    kobold_sample_static).

    Relation to the dynamic version:
    The arithmetic matches apply_repetition_penalty_dynamic. Data-dependent
    branches use jax.lax.cond so this can run with traced arguments. A new
    array is returned instead of updating ``logits`` in place, and token ids
    are not clamped here.

    Which positions are penalized:
    Let n = len(tokens) and r = rprange if rprange > 0 else n. The penalized
    positions of ``tokens`` are the r positions immediately before
    ``generated_index``, wrapping around cyclically.

    How a penalized token's logit changes:
    A positive logit is divided by the penalty; a non-positive logit is
    multiplied by it.

    Slope:
    When rpslope != 0 and rprange > 1, the penalty ramps from 1 (oldest
    position in the range) up to ``repetition_penalty`` (the position just
    before ``generated_index``). ``rpslope`` shapes the curve and may be
    fractional.

    Parameters
    ----------
    logits : 1D float array.
    tokens : 1D int array of token ids.
    repetition_penalty : penalty factor; 1.0 means no change.
    generated_index : int; the range ends just before this position.
    gen_length : unused; kept so the rpargs tuple matches the dynamic version.
    rpslope : slope of the penalty curve; 0 disables the slope.
    rprange : size of the penalty window; 0 or less means all of ``tokens``.

    Returns
    -------
    A new array equal to ``logits`` with the penalized values scattered in.
    '''
    # The slope is real-valued. An integer cast would truncate fractional
    # slopes (e.g. 0.7 -> 0) and silently disable them.
    rpslope = jnp.float32(rpslope)
    rprange = jnp.int32(rprange)
    clipped_rprange = jax.lax.cond(rprange > 0, lambda x: x, lambda x: tokens.shape[-1], rprange)
    # Per-position weights: 0..clipped_rprange-1 inside the window (the newest
    # position gets the largest value), negative outside it
    penalty_arange = jnp.roll(jnp.arange(tokens.shape[-1]) + (clipped_rprange - tokens.shape[-1]), generated_index, axis=-1)
    # Gather the current logit of every listed token; e.g. if logits is
    # [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1], then
    # penalty_logits is [77, 5, 3, 98, 3, 98, 5]
    penalty_logits = jnp.take(logits, tokens)
    # Repetition penalty slope
    def apply_slope(carry):
        repetition_penalty, rprange = carry
        _penalty = (penalty_arange/(rprange - 1)) * 2 - 1
        _penalty = (rpslope * _penalty) / (1 + jnp.abs(_penalty) * (rpslope - 1))
        _penalty = 1 + ((_penalty + 1) / 2) * (repetition_penalty - 1)
        return _penalty
    # The slope needs rprange > 1 because apply_slope divides by
    # (rprange - 1); with rprange == 1 it would produce NaN logits.
    repetition_penalty = jax.lax.cond(
        (rpslope != 0.0) & (rprange > 1),  # Not a typo; do not use `and` here, it makes JAX crash
        apply_slope,
        lambda carry: jnp.full(tokens.shape, carry[0]),
        (repetition_penalty, rprange),
    )
    # Divide positive values by repetition_penalty and multiply negative
    # values by it. Only dividing (as in the original paper) would make
    # tokens with negative logits *more* likely.
    penalty_logits = jnp.where(
        penalty_arange >= 0,
        jnp.where(
            penalty_logits > 0,
            penalty_logits/repetition_penalty,
            penalty_logits*repetition_penalty,
        ),
        penalty_logits,
    )
    # Scatter the penalized values back into their original positions
    return logits.at[tokens].set(penalty_logits)
381
382
def kobold_sample_static(key, logits, rpargs, sampler_order: Optional[np.ndarray] = None, top_p=0.9, temp=0.5, top_k=0, tfs=1.0, typical=1.0, top_a=0.0):
    '''
    Sample one token id from the 1D array ``logits`` on device.

    This is the traceable JAX counterpart of kobold_sample_dynamic, called
    from the generation loop in PenalizingCausalTransformer.generate_static.

    Stages:
    Up to seven stages are applied in the order listed by ``sampler_order``.
    Each stage id maps to:

        0 top-k               (applied only if top_k > 0)
        1 top-a               (applied only if top_a > 0.0)
        2 top-p               (applied only if top_p < 1.0)
        3 tail-free sampling  (applied only if tfs < 1.0)
        4 typical sampling    (applied only if typical < 1.0)
        5 temperature         (applied only if temp != 1.0)
        6 repetition penalty  (applied only if rpargs[1] != 1.0)

    Stages 0-4 remove tokens by setting their logits to -inf.

    Tracing:
    Every stage is wrapped in jax.lax.cond with an identity false branch,
    because the stage id and the settings may be traced values here. The
    Python ``for`` over ``sampler_order`` means its length must be static.

    ``rpargs`` is the tuple (tokens, repetition_penalty, generated_index,
    gen_length, rpslope, rprange), forwarded to
    apply_repetition_penalty_static.

    NOTE(review): ``sampler_order`` defaults to None but is iterated
    unconditionally, so callers must always pass it.

    Returns the index drawn by jax.random.categorical, cast to uint32.
    '''
    # Top-k: keep only the k tokens with the highest logits; the rest are
    # set to -inf
    def top_k_filter(logits):
        # Mask over the descending-sorted logits: False for the first top_k
        # positions, True for every position after them
        sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k
        # Unsort: sorting the argsort permutation back into ascending order
        # carries the mask along, putting it back in original token order
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-logits),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    # Top-a: remove all tokens whose softmax probability is less than a*m^2,
    # where m is the maximum softmax probability
    def top_a_filter(logits):
        # Softmax probabilities
        probabilities = jax.nn.softmax(logits)
        # Find the largest probability
        probs_max = probabilities.max()
        # Remove tokens
        return jnp.where(probabilities < probs_max * probs_max * top_a, -jnp.inf, logits)
    # Top-p: after sorting by descending logit, remove tokens whose
    # cumulative softmax probability exceeds p
    def top_p_filter(logits):
        # Softmax of the logits sorted in descending order
        sorted_logits = -jnp.sort(-logits)
        probabilities = jax.nn.softmax(sorted_logits)
        # Prefix sums of the sorted probabilities
        cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
        # Remove tokens with cumulative probability higher than top_p. There
        # is no shift, so the token that first crosses top_p is removed too.
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Never remove the token with the highest logit (functional update,
        # since JAX arrays are immutable)
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        # Unsort (see top_k_filter) and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-logits),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    # Tail free sampling: like top-p, but thresholding the cumulative
    # normalized absolute second finite differences of the sorted softmax
    # probabilities instead of the probabilities themselves
    def tail_free_filter(logits):
        # Sort in descending order
        sorted_logits = -jnp.sort(-logits)
        # Softmax again
        probabilities = jax.nn.softmax(sorted_logits)
        # Second finite differences (two elements shorter than the input)
        d2 = jnp.diff(jnp.diff(probabilities))
        # Absolute values of those second finite differences
        d2 = jnp.abs(d2)
        # Normalize so the elements sum to 1
        d2 = d2 / d2.sum(axis=-1, keepdims=True)
        # Prefix sums
        cumulative_d2 = jnp.cumsum(d2, axis=-1)
        # Remove tokens whose cumulative normalized absolute second finite
        # difference is larger than the TFS value
        sorted_indices_to_remove = cumulative_d2 > tfs
        # Don't remove the token with the highest logit
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        # d2 has two fewer elements than logits, so pad the mask with two
        # extra Trues (the two lowest-logit tokens are always removed)
        sorted_indices_to_remove = jnp.pad(
            sorted_indices_to_remove,
            (0, 2),
            constant_values=True,
        )
        # Unsort (see top_k_filter) and remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(-logits),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    # Typical sampling (https://arxiv.org/pdf/2202.00666.pdf)
    def typical_filter(logits):
        # Softmax probabilities and their natural logs (log 0 = -inf for
        # tokens already removed)
        probs = jax.nn.softmax(logits)
        log_probs = jnp.log(probs)
        # Negative entropy: sum of p*ln(p). nansum skips the 0 * -inf = nan
        # terms of removed tokens.
        neg_entropy = jnp.nansum(probs * log_probs, axis=-1, keepdims=True)
        # Absolute difference between the negative entropy and each log
        # probability
        entropy_deviation = jnp.abs(neg_entropy - log_probs)
        # Sort the tokens by ascending entropy_deviation, then keep the
        # smallest prefix whose probabilities sum to at least `typical`.
        # NOTE: despite its name, sorted_logits holds the *probabilities*
        # sorted by entropy_deviation.
        _, sorted_logits = jax.lax.sort_key_val(entropy_deviation, probs)
        sorted_indices_to_remove = jnp.cumsum(sorted_logits, axis=-1) >= typical
        # Shift right by one so the token that first reaches the threshold
        # is kept, and always keep the first one
        sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
        # Unsort (keys are the argsort permutation of entropy_deviation) and
        # remove
        _, indices_to_remove = jax.lax.sort_key_val(
            jnp.argsort(entropy_deviation),
            sorted_indices_to_remove,
        )
        return jnp.where(indices_to_remove, -jnp.inf, logits)
    # Temperature: divide the logits by the temperature
    def temp_filter(logits):
        return logits / temp
    # For each listed stage id, all seven conds are evaluated; at most the
    # one whose id matches (and whose setting is active) changes logits
    for k in sampler_order:
        logits = jax.lax.cond(jnp.logical_and(k == 0, top_k > 0), top_k_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 1, top_a > 0.0), top_a_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 2, top_p < 1.0), top_p_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 3, tfs < 1.0), tail_free_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 4, typical < 1.0), typical_filter, lambda x: x, logits)
        logits = jax.lax.cond(jnp.logical_and(k == 5, temp != 1.0), temp_filter, lambda x: x, logits)
        # Both branches take the same (logits, *rpargs) operand tuple; the
        # false branch just returns the logits unchanged
        logits = jax.lax.cond(jnp.logical_and(k == 6, rpargs[1] != 1.0), lambda x: apply_repetition_penalty_static(*x), lambda x: x[0], (logits, *rpargs))
    # Finally draw one index from the categorical distribution defined by
    # softmax(logits)
    return jax.random.categorical(key, logits, -1).astype(jnp.uint32)
522
523
# Filler token id used to pad the "generated" buffer in generate_static
# (presumably GPT-2's <|endoftext|> id; confirm for other tokenizers).
pad_token_id: int = 50256
524
525
def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, generated_index, gen_length, rpslope, rprange, sampler_options):
    '''
    Sample one new token for each sequence state in ``data`` (host-side,
    NumPy sampling path).

    State layout:
    ``data`` is a list with one state per sequence. Each state is a list
    ``[generated, generated_index, logits, next_token]``.

    What each step does to a state:
    1. Set the logits of every id in ``badwords`` to -inf (the stored logits
       array is modified in place).
    2. Draw a token with kobold_sample_dynamic.
    3. Write the token to ``generated[generated_index]`` and increment the
       index.
    4. Rotate the list left by one.

    Termination:
    The loop runs while the front state's index still equals the first
    state's index at entry. When all states start at the same index, each
    one is therefore sampled exactly once.

    Parameters
    ----------
    data : list of per-sequence state lists (see above); mutated in place.
    key : JAX PRNG key; split once per sampled token.
    numseqs_aux : unused; kept for call compatibility.
    badwords : indices whose logits are forced to -inf before sampling.
    repetition_penalty, gen_length, rpslope, rprange : forwarded to
        kobold_sample_dynamic inside its repetition-penalty args tuple.
    generated_index : unused. Each state's own index is used instead.
    sampler_options : extra keyword arguments for kobold_sample_dynamic.

    Returns
    -------
    A ``(data, key)`` tuple: the same (rotated, updated) list and the last
    unused PRNG key.
    '''
    gi = data[0][1]

    def sample_loop_fn(carry):
        # Unpack the front state. This local generated_index shadows the
        # function parameter of the same name.
        generated, generated_index, logits, _ = carry[0][0]
        sample_key = carry[1]
        # Split off the key used by kobold_sample_dynamic to draw a token
        sample_key, new_key = jax.random.split(sample_key, num=2)
        # Forbid the tokens in badwords by giving them -inf logits
        logits[badwords] = -np.inf
        # Draw one token (a 0-d uint32 array)
        next_token = kobold_sample_dynamic(
            sample_key,
            logits,
            (
                generated,
                repetition_penalty,
                generated_index,
                gen_length,
                rpslope,
                rprange,
            ),
            **sampler_options,
        )
        # Remember what token was picked
        generated[generated_index] = next_token
        generated_index += 1
        # Store the updated state and rotate it to the back of the list
        carry[0][0] = [generated, generated_index, logits, next_token]
        carry[0].append(carry[0].pop(0))
        return carry[0], new_key

    # A plain Python loop is used because the body mutates NumPy arrays and
    # a Python list in place; jax.lax.while_loop requires a pure body.
    carry = (data, key)
    while carry[0][0][1] == gi:
        carry = sample_loop_fn(carry)
    return carry
572
573
class PenalizingCausalTransformer(CausalTransformer):
574
def __init__(self, config, **kwargs):
    """Build the xmapped generation entry points on top of CausalTransformer.

    After running the parent constructor, this defines three nested functions
    and stores xmapped versions of them on the instance:

    * ``self.generate_static_xmap``: the whole generation loop on device.
      It samples with kobold_sample_static until every sequence has
      ``gen_length`` new tokens.
    * ``self.generate_initial_xmap``: feeds the context(s) to the model and
      returns per-sequence decode states plus a sampling key.
    * ``self.generate_once_xmap``: runs one model step per sequence and
      stores the resulting logits in each state. It does not sample;
      presumably sampling happens host-side (e.g. in sample_func).

    Arguments mapped over ["shard", ...] are split across the 'mp' mesh axis
    and ["batch", ...] arguments across 'dp' (see ``axis_resources``).
    """
    # Initialize
    super().__init__(config, **kwargs)
    def generate_static(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
        compiling_callback()
        # numseqs_aux only carries the number of sequences in its shape
        numseqs = numseqs_aux.shape[0]
        # These are the tokens that we don't want the AI to ever write.
        # NOTE(review): `vars` is not defined in the visible part of this
        # file. Unless it is rebound at module level elsewhere, this is the
        # builtin `vars`, which has no `badwordsids` attribute. Confirm.
        badwords = jnp.array(vars.badwordsids).squeeze()
        @hk.transform
        def generate_sample(context, ctx_length):
            # Give the initial context to the transformer
            transformer = CausalTransformerShard(config)
            # Scan body producing one initial state per sequence (the scan
            # has no xs, so every sequence starts from the same context)
            def generate_initial_scan_fn(sequence_index, _):
                _, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
                # The "generated" array holds the context tokens followed by
                # config["seq"] pad_token_id slots, which the sampler fills
                # in. It is what the repetition penalty looks at.
                generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
                generated_index = config["seq"]
                # Add that information to generate_loop_fn's starting state
                initial_state = (generated, generated_index, sequence_index) + initial_state
                return sequence_index+1, initial_state
            _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
            # The last element of each scan output is used as the sampling key
            # (first sequence's copy). Presumably this is the RNG key returned
            # by transformer.generate_initial; confirm.
            sample_key = initial_states[-1][0]
            # Split the stacked scan outputs into a Python list of
            # per-sequence tuples (generated, generated_index,
            # sequence_index, next_token, decode_state)
            initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
            # Get repetition penalty from the arguments. These keys are
            # removed from sampler_options so the rest can go to the sampler
            # as keyword arguments. A missing key yields None, which
            # presumably callers never rely on.
            repetition_penalty = sampler_options.pop('repetition_penalty', None)
            rpslope = sampler_options.pop('rpslope', None)
            rprange = sampler_options.pop('rprange', None)
            # This is the main generation loop body: it advances the front
            # sequence by one token, then rotates it to the back
            def generate_loop_fn(carry):
                # Unpack current generate_loop_fn state
                generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
                sample_key = carry[1]
                # Split off the key that kobold_sample_static uses to draw a
                # token
                sample_key, new_key = jax.random.split(sample_key)
                # Run one model step on the previous token. The assert below
                # fixes logits at shape (1, config["n_vocab"]).
                # NOTE(review): an earlier comment described this as 50400
                # columns for 50257 real tokens. That is presumably specific
                # to GPT-J-6B; confirm for other models.
                logits, new_state = transformer.generate_once(next_token, decode_state, soft_embeddings=soft_embeddings)
                # Verify that logits does indeed have that many rows and
                # columns (if you get an error here, pray for mercy)
                assert logits.shape == (1, config["n_vocab"])
                # Flatten it into a 1D array to make it easier to use
                logits = logits[0]
                # Forbid the tokens in badwords by giving them -inf logits
                logits = logits.at[badwords].set(-jnp.inf)
                # Draw one token (0-d uint32 array) with kobold_sample_static
                next_token = kobold_sample_static(
                    sample_key,
                    logits,
                    (
                        generated,
                        repetition_penalty,
                        generated_index,
                        gen_length,
                        rpslope,
                        rprange,
                    ),
                    **sampler_options,
                )
                # Remember what token was picked
                generated = generated.at[generated_index].set(next_token)
                generated_index += 1
                # Store the updated state (next_token gets a length-1 axis
                # for the next generate_once call) and rotate it to the back
                carry[0][0] = (generated, generated_index, sequence_index, next_token[jnp.newaxis], new_state)
                carry[0].append(carry[0].pop(0))
                return carry[0], new_key
            # Keep going while the front sequence has fewer than gen_length
            # new tokens; with the rotation, every sequence gets gen_length
            return jax.lax.while_loop(
                lambda carry: carry[0][0][1] - config["seq"] < gen_length,
                generate_loop_fn,
                (initial_states, sample_key),
            )
        return generate_sample.apply(state["params"], key, ctx, ctx_length)
    # Args: state (shard); key, ctx, ctx_length, gen_length, numseqs_aux,
    # sampler_options (batch); soft_embeddings (shard)
    self.generate_static_xmap = jax.experimental.maps.xmap(
        fun=generate_static,
        in_axes=(
            ["shard", ...],
            ["batch", ...],
            ["batch", ...],
            ["batch", ...],
            ["batch", ...],
            ["batch", ...],
            ["batch", ...],
            ["shard", ...],
        ),
        out_axes=["shard", "batch", ...],
        axis_resources={'shard': 'mp', 'batch': 'dp'},
    )
    def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None):
        compiling_callback()
        numseqs = numseqs_aux.shape[0]
        @hk.transform
        def generate_initial_inner(context, ctx_length):
            # Give the initial context to the transformer
            transformer = CausalTransformerShard(config)
            # Scan body: c is one slice of `context` along its leading axis.
            # Presumably that is one context per sequence; confirm the shape
            # against the caller.
            def generate_initial_scan_fn(sequence_index, c):
                _, initial_state = transformer.generate_initial(c, ctx_length, soft_embeddings=soft_embeddings)
                generated_index = config["seq"]
                # Add that information to generate_loop_fn's starting state.
                # The first slot is an uninitialised placeholder for the
                # logits that generate_once writes there.
                initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state
                return sequence_index+1, initial_state
            _, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, context, numseqs)
            # The last element of each scan output (first sequence's copy)
            # becomes the sampling key; see the note in generate_static
            sample_key = initial_states[-1][0]
            # Per-sequence state lists (lists, so generate_once can assign
            # into them)
            initial_states = list(list(jax.tree_map(lambda x: x[i], initial_states[:-1])) for i in range(numseqs))
            return initial_states, sample_key
        return generate_initial_inner.apply(state["params"], key, ctx, ctx_length)
    # Args: state (shard); key, ctx, ctx_length, numseqs_aux (batch);
    # soft_embeddings (shard)
    self.generate_initial_xmap = jax.experimental.maps.xmap(
        fun=generate_initial,
        in_axes=(
            ["shard", ...],
            ["batch", ...],
            ["batch", ...],
            ["batch", ...],
            ["batch", ...],
            ["shard", ...],
        ),
        out_axes=["shard", "batch", ...],
        axis_resources={'shard': 'mp', 'batch': 'dp'},
    )
    def generate_once(data, state, numseqs_aux, soft_embeddings=None):
        # NOTE: numseqs is computed but not used below
        numseqs = numseqs_aux.shape[0]
        @hk.without_apply_rng
        @hk.transform
        def generate_once_inner():
            # Index of the front state on entry; the loop runs until a state
            # that has already been advanced comes back to the front
            gi = data[0][1]
            # Give the initial context to the transformer
            transformer = CausalTransformerShard(config)
            # Loop body: one model step for the front sequence, then rotate
            def generate_loop_fn(carry):
                # Unpack current generate_loop_fn state
                _, generated_index, sequence_index, next_token, decode_state = carry[0][0]
                # Run one model step on next_token. The asserts below fix
                # logits at shape (1, config["n_vocab"]), dtype float32.
                logits, new_state = transformer.generate_once(next_token, decode_state, soft_embeddings=soft_embeddings)
                # Verify that logits does indeed have that many rows and
                # columns (if you get an error here, pray for mercy)
                assert logits.shape == (1, config["n_vocab"])
                assert logits.dtype == jnp.float32
                # Flatten it into a 1D array to make it easier to use
                logits = logits[0]
                # Store the logits in slot 0. next_token is kept as-is here;
                # presumably the caller replaces it after sampling.
                generated_index += 1
                carry[0][0] = [logits, generated_index, sequence_index, next_token, new_state]
                carry[0].append(carry[0].pop(0))
                # Trailing comma: the carry is the 1-tuple (data,)
                return carry[0],
            return jax.lax.while_loop(
                lambda carry: carry[0][0][1] == gi,
                generate_loop_fn,
                (data,),
            )
        return generate_once_inner.apply(state["params"])
    # Args: data (shard x batch); state (shard); numseqs_aux (batch);
    # soft_embeddings (shard)
    self.generate_once_xmap = jax.experimental.maps.xmap(
        fun=generate_once,
        in_axes=(
            ["shard", "batch", ...],
            ["shard", ...],
            ["batch", ...],
            ["shard", ...],
        ),
        out_axes=["shard", "batch", ...],
        axis_resources={'shard': 'mp', 'batch': 'dp'},
    )
753
def generate_dynamic(self, ctx, ctx_length, gen_length, numseqs, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
    """Generate tokens one step at a time, consulting host-side callbacks between steps.

    The initial pass runs through ``self.generate_initial_xmap``. After that,
    each new token comes from one ``self.generate_once_xmap`` call and is
    sampled on the host with ``sample_func``. Between steps the warper,
    settings and stopping callbacks can inspect or modify the state.

    Args:
        ctx: Batched context tokens; ``ctx[0][i]`` is the context used for
            sequence ``i``.
        ctx_length: Context lengths, forwarded to the initial xmapped pass.
        gen_length: 1-D array of generation lengths, forwarded to ``sample_func``.
        numseqs: Number of sequences to generate.
        return_logits: Must be False.
        soft_embeddings: Soft-prompt embeddings; must not be None.
        excluded_world_info: State threaded through ``stopping_callback``;
            must not be None.
        use_callback: When True, run ``warper_callback`` and
            ``stopping_callback`` every step and stop only when the stopping
            callback requests regeneration or a halt. When False, stop after
            a single step.

    Returns:
        ``(sample_data, n_generated, regeneration_required, halt)``, where
        ``sample_data[i][0]`` is the token buffer of sequence ``i``.
    """
    assert excluded_world_info is not None
    assert not return_logits
    assert gen_length.ndim == 1
    assert soft_embeddings is not None
    key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
    batch_size = ctx.shape[0]
    self.batch_size = batch_size
    # Only the shape of this array (numseqs) is meaningful; its contents are uninitialised.
    _numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
    numseqs_aux = batch_xmap(_numseqs_aux)
    # Host-side state per sequence:
    #   [0] token buffer: the context followed by params["seq"] pad tokens
    #   [1] position just past the latest token (buffer[[1] - 1] is fed back to the model)
    #   [2] logits of the latest step
    #   [3] uint32 scalar handed to sample_func (its meaning is defined there)
    sample_data = [
        [
            np.pad(ctx[0][i], (0, params["seq"]), constant_values=pad_token_id),
            params["seq"],
            None,
            np.empty((), dtype=np.uint32),
        ]
        for i in range(numseqs)
    ]
    n_generated = 0
    regeneration_required = False
    halt = False
    started_compiling_callback()
    # try/finally keeps the started/stopped callbacks balanced even if generation raises.
    try:
        generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings)
        sample_key = np.asarray(sample_key[0, 0])
        while True:
            generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
            # Copy each sequence's logits (slot 0 of its device state) to the host.
            for i in range(numseqs):
                sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
            if use_callback:
                logits = np.float32(tuple(d[2] for d in sample_data))
                logits = warper_callback(logits)
                for i in range(numseqs):
                    sample_data[i][2] = logits[i]
            sampler_options = settings_callback()
            repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
            rpslope = sampler_options.pop("rpslope", 0.0)
            rprange = sampler_options.pop("rprange", 0)
            sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, params["seq"] + n_generated, gen_length, rpslope, rprange, sampler_options)
            n_generated += 1
            # Feed the newly sampled token back to every model-parallel core.
            for i in range(numseqs):
                generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
            if use_callback:
                generated = np.uint32(tuple(d[0] for d in sample_data))
                excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info)
                if regeneration_required or halt:
                    break
            else:
                break
    finally:
        stopped_compiling_callback()
    return sample_data, n_generated, regeneration_required, halt
804
def generate_static(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
    """Run a whole generation in a single call to ``self.generate_static_xmap``.

    Args:
        ctx: Batched context tokens (``ctx.shape[0]`` is the batch size).
        ctx_length: Context lengths; converted to a uint32 array.
        gen_length: Generation lengths; converted to a uint32 array.
        numseqs: Number of sequences to generate per batch entry.
        sampler_options: Sampler settings forwarded unchanged to the xmap.
        return_logits: Must be False.
        soft_embeddings: Soft-prompt embeddings forwarded to the xmap.

    Returns:
        Whatever ``self.generate_static_xmap`` returns.
    """
    assert not return_logits
    key = hk.PRNGSequence(rng.randint(0, 2 ** 60))
    batch_size = ctx.shape[0]
    self.batch_size = batch_size
    started_compiling_callback()
    # try/finally keeps the started/stopped callbacks balanced even if the xmap raises.
    try:
        return self.generate_static_xmap(
            self.state,
            jnp.array(key.take(batch_size)),
            ctx,
            np.array(ctx_length, dtype=np.uint32),
            np.array(gen_length, dtype=np.uint32),
            # Uninitialised; presumably only its shape (numseqs) is read by the xmap.
            np.empty((batch_size, numseqs), dtype=np.uint8),
            sampler_options,
            soft_embeddings,
        )
    finally:
        stopped_compiling_callback()
822
823
824
def infer_dynamic(
    context: np.array,
    numseqs=1,
    gen_len=80,
    soft_embeddings: Optional[np.array] = None,
    soft_tokens: Optional[np.array] = None,
    excluded_world_info = None,
    use_callback=True,
) -> Tuple[List[np.array], int, bool, bool]:
    """Generate with per-token host callbacks via ``network.generate_dynamic``.

    Args:
        context: 2-D array of context tokens (one row per sequence).
        numseqs: Number of sequences to generate.
        gen_len: Number of generated tokens to return per sequence.
        soft_embeddings: Soft-prompt embeddings forwarded to the network.
        soft_tokens: Soft-prompt token ids prepended to every context row.
        excluded_world_info: State threaded through the stopping callback;
            must not be None.
        use_callback: Forwarded to ``network.generate_dynamic``.

    Returns:
        ``(samples, n_generated, regeneration_required, halt)``. ``samples``
        holds, per sequence, the ``gen_len`` tokens that follow the context.

    Raises:
        ValueError: if the context (including soft tokens) is longer than ``seq``.
    """
    assert excluded_world_info is not None
    maps.thread_resources.env = thread_resources_env
    total_batch = 1
    tokens = context
    if soft_tokens is not None:
        # Prepend the soft-prompt tokens to every row of the context.
        tokens = np.uint32(np.concatenate((np.tile(soft_tokens, (tokens.shape[0], 1)), tokens), axis=-1))
    provided_ctx = tokens.shape[-1]
    pad_amount = seq - provided_ctx
    if pad_amount < 0:
        raise ValueError(f"Context of {provided_ctx} tokens (including soft prompt tokens) exceeds the maximum of {seq} tokens")
    # Left-pad so that every context ends exactly at position seq - 1.
    padded_tokens = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=pad_token_id)
    batched_tokens = np.array([padded_tokens] * total_batch)
    output = network.generate_dynamic(
        batched_tokens,
        np.ones(total_batch, dtype=np.uint32) * provided_ctx,
        np.ones(total_batch, dtype=np.uint32) * gen_len,
        numseqs,
        soft_embeddings=soft_embeddings,
        excluded_world_info=excluded_world_info,
        use_callback=use_callback,
    )
    # Generated tokens start right after the (padded) context, at params["seq"].
    samples = [out[0][params["seq"] : params["seq"] + gen_len] for out in output[0]]
    return (samples,) + output[1:]
856
857
def infer_static(
    context: np.array,
    top_p=0.9,
    temp=0.5,
    top_k=0,
    tfs=1.0,
    typical=1.0,
    top_a=0.0,
    repetition_penalty=1.0,
    rpslope=0.0,
    rprange=0,
    numseqs=1,
    gen_len=80,
    soft_embeddings: Optional[np.array] = None,
    soft_tokens: Optional[np.array] = None,
    sampler_order: Optional[List[int]] = None,
) -> List[np.array]:
    """Generate a full completion in one call via ``network.generate_static``.

    Args:
        context: 1-D array of context tokens.
        top_p, temp, top_k, tfs, typical, top_a, repetition_penalty, rpslope,
        rprange: Sampler settings. Each is broadcast to a per-batch array
            and passed to the network.
        numseqs: Number of sequences to generate.
        gen_len: Number of generated tokens to return per sequence.
        soft_embeddings: Soft-prompt embeddings forwarded to the network.
        soft_tokens: Soft-prompt token ids prepended to the context.
        sampler_order: Order in which the samplers are applied. Defaults to
            ``utils.default_sampler_order``. The caller's sequence is never
            modified.

    Returns:
        One array per sequence, holding the ``gen_len`` tokens that follow
        the context.

    Raises:
        ValueError: if the context (including soft tokens) is longer than ``seq``.
    """
    maps.thread_resources.env = thread_resources_env
    if sampler_order is None:
        sampler_order = utils.default_sampler_order
    # Work on a plain-list copy so neither the caller's order nor the shared default is mutated.
    sampler_order = list(sampler_order)
    # Add repetition penalty (sampler 6) at the beginning if it's not present.
    if 6 not in sampler_order:
        sampler_order = [6] + sampler_order
    sampler_order = np.uint32(sampler_order)
    total_batch = 1
    tokens = context
    if soft_tokens is not None:
        tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
    provided_ctx = tokens.shape[0]
    pad_amount = seq - provided_ctx
    if pad_amount < 0:
        raise ValueError(f"Context of {provided_ctx} tokens (including soft prompt tokens) exceeds the maximum of {seq} tokens")
    # Left-pad so that the context ends exactly at position seq - 1.
    padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
    batched_tokens = np.array([padded_tokens] * total_batch)
    batched_generator_params = {
        "sampler_order": np.repeat(sampler_order[np.newaxis], total_batch, axis=0),
        "temp": temp * np.ones(total_batch),
        "top_p": top_p * np.ones(total_batch),
        "tfs": tfs * np.ones(total_batch),
        "typical": typical * np.ones(total_batch),
        "top_a": top_a * np.ones(total_batch),
        "repetition_penalty": repetition_penalty * np.ones(total_batch),
        "rpslope": rpslope * np.ones(total_batch),
        "rprange": np.full(total_batch, rprange, dtype=np.uint32),
        "top_k": np.full(total_batch, top_k, dtype=np.uint32)
    }
    output = network.generate_static(
        batched_tokens,
        np.ones(total_batch, dtype=np.uint32) * provided_ctx,
        np.ones(total_batch, dtype=np.uint32) * gen_len,
        numseqs,
        batched_generator_params,
        soft_embeddings=soft_embeddings,
    )[0]
    # Generated tokens start right after the (padded) context, at params["seq"].
    return [o[0][0, 0, params["seq"] : params["seq"] + gen_len] for o in output]
913
914
915
def reshard_reverse(x, total_shards, old_shape):
    """Split an unsharded tensor into ``total_shards`` per-core pieces shaped ``old_shape``.

    ``x`` must be a ``torch.Tensor``, because ``Tensor.tile`` and
    ``Tensor.permute`` are used. Callers in this module add its leading
    axis with ``unsqueeze(0)``.

    * 2-D ``x``:
      - If ``old_shape[1] == x.shape[1]``, the first row is replicated
        ``total_shards`` times.
      - Otherwise ``x`` is reshaped to ``old_shape``.
    * 3-D ``x``:
      - If ``x.shape[0] * x.shape[2] == old_shape[2]``, ``x`` is reshaped
        directly to ``old_shape``. With a leading axis of 1, this splits
        axis 1 into contiguous chunks.
      - Else if ``x.shape[0] * x.shape[1] == old_shape[1]``, the last axis
        is split into ``old_shape[0]`` contiguous chunks, stacked along a
        new leading axis.

    Raises:
        ValueError: if ``x`` is not 2-D or 3-D, or if a 3-D ``x`` matches
            neither layout. Previously these were ``assert`` statements,
            which vanish under ``python -O``.
    """
    ndim = len(x.shape)
    if ndim == 2:
        if old_shape[1] == x.shape[1]:
            # Replicated parameter: every shard receives an identical copy.
            return x[0:1].tile((total_shards, 1))
        return x.reshape(old_shape)
    if ndim == 3:
        if x.shape[0] * x.shape[2] == old_shape[2]:
            return x.reshape(old_shape)
        if x.shape[0] * x.shape[1] == old_shape[1]:
            # Chunk the last axis, then move the chunk index to the front.
            return x.reshape((old_shape[1], old_shape[0], old_shape[2])).permute((1, 0, 2))
        raise ValueError(f"Cannot reshard tensor of shape {tuple(x.shape)} into {tuple(old_shape)}")
    raise ValueError(f"Unsupported tensor rank {ndim} for resharding (expected 2 or 3)")
932
933
934
def get_old_shape(t, total_shards, dim=2):
    """Return the shape of one shard of ``t`` when it is split ``total_shards`` ways.

    * 1-D ``t``: split along its only axis; ``dim`` is ignored.
    * 2-D ``t``: ``dim=1`` splits the first axis and ``dim=2`` splits the
      second. Any other ``dim`` raises ``ValueError``.
    * Any other rank raises ``ValueError``.

    The split axis must divide evenly by ``total_shards``; this is
    checked with ``assert``.
    """
    shape = t.shape
    rank = len(shape)
    if rank == 1:
        assert shape[0] % total_shards == 0
        return (shape[0] // total_shards,)
    if rank != 2:
        raise ValueError(f"Unsupported shape {t.shape}")
    rows, cols = shape
    if dim == 1:
        assert rows % total_shards == 0
        return (rows // total_shards, cols)
    if dim == 2:
        assert cols % total_shards == 0
        return (rows, cols // total_shards)
    raise ValueError(f"Unsupported dim {dim}")
950
951
952
def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
    """Load a GPT-NeoX checkpoint into ``state["params"]`` in place.

    Weights are read from files named
    ``layer_{layer:02d}-model_{shard:02d}-model_states.pt`` under ``path``.
    Each file has ``checkpoint_shards`` shards. Every tensor is mapped to
    an MTJ module/parameter through ``static_mapping`` or
    ``layer_mapping``, and then processed as follows:

    1. Re-split each checkpoint shard into
       ``cores_per_replica // checkpoint_shards`` pieces.
    2. Concatenate the pieces.
    3. Check the result against the shape MTJ expects.
    4. Convert fp16/fp32 tensors to bfloat16.
    5. Move the tensor onto the TPU mesh with an xmap.

    Args:
        state: MTJ network state; its ``"params"`` entries are replaced.
        path: Directory containing the checkpoint files.
        config: Model config (``cores_per_replica`` and ``layers`` are read).
        checkpoint_shards: Number of model-parallel shards in the checkpoint.

    Raises:
        AssertionError: ``cores_per_replica`` is not a multiple of ``checkpoint_shards``.
        RuntimeError: a checkpoint key has no mapping, a tensor has the wrong
            shape, or an MTJ parameter is still a placeholder after loading.
    """
    assert config["cores_per_replica"] % checkpoint_shards == 0
    output_shards = config["cores_per_replica"] // checkpoint_shards

    import torch
    import torch.utils.dlpack
    import torch_lazy_loader
    from tqdm.auto import tqdm

    move_xmap = jax.experimental.maps.xmap(
        fun=lambda x, _: to_bf16(x),
        in_axes=(["shard", ...], ["batch", ...]),
        out_axes=["shard", ...],
        axis_resources={'shard': 'mp', 'batch': 'dp'}
    )

    path_template = os.path.join(path, "layer_{layer:02d}-model_{shard:02d}-model_states.pt")

    # Checkpoint key -> MTJ module/param. "axis" is the axis passed to get_old_shape
    # when splitting; None means the tensor is replicated on every core.
    static_mapping = {
        "word_embeddings.weight": {"module": "embedding_shard/~/linear", "param": "w", "axis": 1},
        "final_linear.weight": {"module": "projection_shard/~/linear", "param": "w", "axis": 2},
        "norm.weight": {"module": "projection_shard/~/replicated_layer_norm", "param": "scale", "axis": None},
        "norm.bias": {"module": "projection_shard/~/replicated_layer_norm", "param": "offset", "axis": None},
    }

    layer_mapping = {
        "attention.query_key_value.weight": {"module": "combined_qkv", "param": "w", "axis": 2},
        "attention.query_key_value.bias": {"module": "combined_qkv", "param": "b", "axis": 1},
        "attention.dense.weight": {"module": "linear_3", "param": "w", "axis": 1},
        "attention.dense.bias": {"module": "linear_3", "param": "b", "axis": None},
        "mlp.dense_h_to_4h.weight": {"module": "linear_4", "param": "w", "axis": 2},
        "mlp.dense_h_to_4h.bias": {"module": "linear_4", "param": "b", "axis": 1},
        "mlp.dense_4h_to_h.weight": {"module": "linear_5", "param": "w", "axis": 1},
        "mlp.dense_4h_to_h.bias": {"module": "linear_5", "param": "b", "axis": None},
        "input_layernorm.weight": {"module": "replicated_layer_norm", "param": "scale", "axis": None},
        "input_layernorm.bias": {"module": "replicated_layer_norm", "param": "offset", "axis": None},
        "post_attention_layernorm.weight": {"module": "replicated_layer_norm_1", "param": "scale", "axis": None},
        "post_attention_layernorm.bias": {"module": "replicated_layer_norm_1", "param": "offset", "axis": None},
    }

    tqdm_length = len(static_mapping) + config["layers"]*len(layer_mapping)
    # Context manager: the bar is closed even when loading fails part-way.
    with tqdm(total=tqdm_length, desc="Loading from NeoX checkpoint") as bar:
        for checkpoint_layer in range(config["layers"] + 5):
            # Files 1 and layers+2 are not loaded.
            # NOTE(review): presumably they hold no weights used here — confirm.
            if checkpoint_layer in (1, config["layers"] + 2):
                continue
            # Transformer layer files are offset by 2 from the MTJ layer index.
            layer = checkpoint_layer - 2
            shards = []
            with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
                for checkpoint_shard in range(checkpoint_shards):
                    shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
            for key in shards[0]:
                if key == "attention.rotary_emb.inv_freq":
                    continue
                elif key in static_mapping:
                    target_module = "causal_transformer_shard/~/" + static_mapping[key]["module"]
                    target_param = static_mapping[key]["param"]
                    target_axis = static_mapping[key]["axis"]
                elif key in layer_mapping:
                    target_module = f"causal_transformer_shard/~/layer_{layer}/~/" + layer_mapping[key]["module"]
                    target_param = layer_mapping[key]["param"]
                    target_axis = layer_mapping[key]["axis"]
                else:
                    error = f"{repr(key)} not found in mapping"
                    print("\n\nERROR: ", error, file=sys.stderr)
                    raise RuntimeError(error)
                for checkpoint_shard in range(checkpoint_shards):
                    # These biases are divided by the per-checkpoint-shard split factor.
                    # Presumably the layer outputs are summed across shards — confirm.
                    if key in ("attention.dense.bias", "mlp.dense_4h_to_h.bias"):
                        shards[checkpoint_shard][key] /= output_shards
                    # Transpose all 2-D weights except the embedding matrix.
                    if key != "word_embeddings.weight" and shards[checkpoint_shard][key].ndim == 2:
                        shards[checkpoint_shard][key] = shards[checkpoint_shard][key].T
                    tensor = shards[checkpoint_shard][key]
                    if target_axis is not None:
                        target_shape = (output_shards,) + get_old_shape(tensor, total_shards=output_shards, dim=target_axis)
                    else:
                        target_shape = (output_shards, tensor.shape[0])
                    shards[checkpoint_shard][key] = reshard_reverse(tensor.unsqueeze_(0), output_shards, target_shape)
                tensor = torch.cat([shards[s][key] for s in range(checkpoint_shards)], dim=0)
                target_shape = state["params"][target_module][target_param].shape
                if tensor.shape != target_shape:
                    error = f"Weight {repr(key)} has shape {tensor.shape} in checkpoint but shape {target_shape} was requested by MTJ for {target_module} {target_param}"
                    print("\n\nERROR: ", error, file=sys.stderr)
                    raise RuntimeError(error)
                if tensor.dtype is torch.float16 or tensor.dtype is torch.float32:
                    tensor = tensor.bfloat16()
                # Hand the tensor to JAX via DLPack, then distribute it across the mesh.
                state["params"][target_module][target_param] = move_xmap(
                    jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)).copy(),
                    np.zeros(config["cores_per_replica"]),
                )
                bar.update(1)
    # Any parameter still a PlaceholderTensor was never provided by the checkpoint.
    for mk, mv in state["params"].items():
        for pk, pv in mv.items():
            if isinstance(pv, PlaceholderTensor):
                error = f"{mk} {pk} could not be found in the model checkpoint"
                print("\n\nERROR: " + error, file=sys.stderr)
                raise RuntimeError(error)
1050
1051
1052
def _from_pretrained_with_fallbacks(model_path):
    """Load a tokenizer and a causal LM from ``model_path``, trying progressively more generic classes.

    Tokenizer attempts, in order:
      1. slow ``AutoTokenizer``
      2. default ``AutoTokenizer``
      3. ``GPT2Tokenizer`` from ``model_path``
      4. the stock ``"gpt2"`` tokenizer

    Model attempts, in order:
      1. ``AutoModelForCausalLM``
      2. ``GPTNeoForCausalLM``

    All loads use ``vars.revision`` and ``cache_dir="cache"``.

    Returns:
        ``(tokenizer, model)``.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, revision=vars.revision, cache_dir="cache", use_fast=False)
    except Exception:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path, revision=vars.revision, cache_dir="cache")
        except Exception:
            try:
                tokenizer = GPT2Tokenizer.from_pretrained(model_path, revision=vars.revision, cache_dir="cache")
            except Exception:
                tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
    try:
        model = AutoModelForCausalLM.from_pretrained(model_path, revision=vars.revision, cache_dir="cache")
    except Exception:
        model = GPTNeoForCausalLM.from_pretrained(model_path, revision=vars.revision, cache_dir="cache")
    return tokenizer, model


def load_model(path: str, driver_version="tpu_driver_20221109", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
    """Connect to the TPU, build the MTJ network and load the model weights into it.

    Sets these module globals, which the inference functions use:
    ``thread_resources_env``, ``seq``, ``tokenizer``, ``network``, ``params``,
    ``pad_token_id``, ``badwords``, ``shard_xmap`` and ``batch_xmap``.

    Args:
        path: Directory containing the checkpoint.
        driver_version: TPU driver version requested from the TPU host.
        hf_checkpoint: If True, load a Hugging Face checkpoint. It is converted
            using the spec in ``maps/<vars.model_type>.json``. Otherwise load
            an MTJ checkpoint, or a NeoX checkpoint when ``vars.model`` is
            ``"TPUMeshTransformerGPTNeoX"``.
        socketio_queue, initial_load, logger: Not used by this function.
        **kwargs: Model configuration. It becomes ``params``, and missing keys
            are filled from built-in defaults.

    Raises:
        NotImplementedError: the HF model type has no ``maps`` spec file.
        ValueError: ``tokenizer_class`` is not a valid tokenizer class name.
        RuntimeError: a required tensor is missing or duplicated in the
            checkpoint (or mis-shaped, for NeoX checkpoints).
    """
    global thread_resources_env, seq, tokenizer, network, params, pad_token_id, badwords, shard_xmap, batch_xmap

    if "pad_token_id" in kwargs:
        pad_token_id = kwargs["pad_token_id"]
    elif "eos_token_id" in kwargs:
        pad_token_id = kwargs["eos_token_id"]

    if not hasattr(vars, "sampler_order") or not vars.sampler_order:
        vars.sampler_order = utils.default_sampler_order.copy()

    if vars.model == "TPUMeshTransformerGPTNeoX":
        default_params = {
            "compat": "neox",
            "layers": 44,
            "d_model": 6144,
            "n_heads": 64,
            "n_vocab": 50432,
            "n_vocab_padding": 0,
            "norm": "doublelayernorm",
            "pe": "neox_rotary",
            "pe_rotary_dims": 24,
            "seq": 2048,
            "cores_per_replica": 8,
            "tokenizer_class": "GPT2Tokenizer",
            "tokenizer": "gpt2",
        }
    else:
        default_params = {
            "compat": "j",
            "layers": 28,
            "d_model": 4096,
            "n_heads": 16,
            "n_vocab": 50400,
            "n_vocab_padding": 0,
            "norm": "layernorm",
            "pe": "rotary",
            "pe_rotary_dims": 64,
            "seq": 2048,
            "cores_per_replica": 8,
            "tokenizer_class": "GPT2Tokenizer",
            "tokenizer": "gpt2",
        }
    # kwargs is this call's own dict, so it is used (and mutated) directly as the config.
    params = kwargs

    # Try to convert HF config.json to MTJ config
    if hf_checkpoint:
        spec_path = os.path.join("maps", vars.model_type + ".json")
        if not os.path.isfile(spec_path):
            raise NotImplementedError(f"Unsupported model type {repr(vars.model_type)}")
        with open(spec_path) as f:
            lazy_load_spec = json.load(f)

        if "mtj_compat" in lazy_load_spec:
            params["compat"] = lazy_load_spec["mtj_compat"]
        if "mtj_pe" in lazy_load_spec:
            params["pe"] = lazy_load_spec["mtj_pe"]
        for k, v in lazy_load_spec.get("mtj_config_map", {}).items():
            if not isinstance(v, list):
                params[k] = params[v]
                continue
            # A list names candidate config keys in priority order. Its last
            # element is used literally when none of the earlier keys exist.
            for i in range(len(v)):
                if i == len(v) - 1:
                    params[k] = v[i]
                elif v[i] in params:
                    params[k] = params[v[i]]
                    break

        params["n_vocab"] = params["vocab_size"]

        if "activation_function" in params:
            params["activation"] = params["activation_function"]

        # Both the number of attention heads in the model and the embedding
        # dimension of the model need to be divisible by the number of TPU
        # cores that we use. JAX also requires the number of TPU cores used
        # to be even if we use more than one core. So we pick the largest
        # even number of TPU cores such that the number of attention heads
        # and the embedding dimension are both divisible by it, and fall
        # back to one core if no even number works.
        for c in (8, 6, 4, 2, 1):
            if 0 == params["n_heads"] % c == params.get("d_embed", params["d_model"]) % c:
                params["cores_per_replica"] = c
                break

        # The vocabulary size of the model also has to be divisible by the
        # number of TPU cores, so we pad the vocabulary with the minimum
        # possible number of dummy tokens such that it's divisible.
        params["n_vocab_padding"] = -(params["n_vocab"] % -params["cores_per_replica"])

    if "compat" in params:
        default_params["compat"] = params["compat"]
    if default_params["compat"] == "fairseq_lm":
        default_params["tokenizer"] = "KoboldAI/fairseq-dense-125M"
    for param in default_params:
        if param not in params:
            params[param] = default_params[param]

    # Use an optimization that will allow us to avoid one extra transpose operation
    if hf_checkpoint:
        params["transposed_linear"] = True

    # Load tokenizer
    if vars.model == "TPUMeshTransformerGPTNeoX":
        tokenizer = Tokenizer.from_file(os.path.join(path, "20B_tokenizer.json"))
        def new_encode(old_encode):
            # Wrap encode() so it returns the list of token ids rather than an encoding object.
            def encode(s, *args, **kwargs):
                return old_encode(s).ids
            return encode
        tokenizer.encode = new_encode(tokenizer.encode)
        tokenizer._koboldai_header = []
    elif not hf_checkpoint:
        if not isinstance(params["tokenizer_class"], str) or not params["tokenizer_class"].endswith(("Tokenizer", "TokenizerFast")):
            raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
        tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
        tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])

    # Disable JAX warnings about these two functions having been renamed
    jax.host_count = jax.process_count
    jax.host_id = jax.process_index

    print("Connecting to your Colab instance's TPU", flush=True)
    spinner = multiprocessing.Process(target=show_spinner, args=())
    spinner.start()
    try:
        if os.environ.get('COLAB_TPU_ADDR', '') != '':
            tpu_address = os.environ['COLAB_TPU_ADDR']  # Colab
        else:
            tpu_address = os.environ['TPU_NAME']  # Kaggle
        tpu_address = tpu_address.replace("grpc://", "")
        tpu_address_without_port = tpu_address.split(':', 1)[0]
        url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
        requests.post(url)
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + tpu_address
    finally:
        # Stop the spinner process even if the address lookup or the request fails.
        spinner.terminate()
    print()

    cores_per_replica = params["cores_per_replica"]
    seq = params["seq"]
    params["optimizer"] = _DummyOptimizer()
    mesh_shape = (1, cores_per_replica)
    devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
    thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
    maps.thread_resources.env = thread_resources_env

    # Build these before any of the early returns below, so that every
    # checkpoint format (not only HF) leaves them matching this model's core
    # count instead of unset or stale from an earlier load.
    shard_xmap = __shard_xmap()
    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)

    # These are the tokens that we don't want the AI to ever write
    badwords = jnp.array(vars.badwordsids).squeeze()

    if not path.endswith("/"):
        path += "/"

    network = PenalizingCausalTransformer(params, dematerialized=True)

    if not hf_checkpoint and vars.model != "TPUMeshTransformerGPTNeoX":
        network.state = read_ckpt_lowmem(network.state, path, devices.shape[1])
        return

    if vars.model == "TPUMeshTransformerGPTNeoX":
        print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")
        read_neox_checkpoint(network.state, path, params)
        return

    # Convert from HF checkpoint

    move_xmap = jax.experimental.maps.xmap(
        fun=lambda x, _: to_bf16(x),
        in_axes=(["shard", ...], ["batch", ...]),
        out_axes=["shard", ...],
        axis_resources={'shard': 'mp', 'batch': 'dp'}
    )

    # HF tensor-name suffix -> MTJ target ({"module", "param", optional "transforms"}).
    model_spec = {}
    for key, spec in lazy_load_spec.get("static_weights", {}).items():
        if spec.get("mtj") is not None:
            model_spec[key] = spec["mtj"].copy()
            model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"]
    for _key, spec in lazy_load_spec.get("layer_weights", {}).items():
        for layer in range(params["layers"]):
            if spec.get("mtj") is not None:
                key = _key.format(layer=layer)
                model_spec[key] = spec["mtj"].copy()
                model_spec[key]["module"] = "causal_transformer_shard/~/" + model_spec[key]["module"].format(layer=layer)

    import torch_lazy_loader
    import torch
    from tqdm.auto import tqdm
    import functools

    def callback(model_dict, f, **_):
        """Stream the tensors of checkpoint archive ``f`` straight into ``network.state``.

        Tensors are read in storage/offset order from the zip archive. Each is
        transformed according to its ``model_spec`` entry, resharded and moved
        to the TPU. ``model_dict`` entries are then replaced by meta tensors so
        no CPU copy is kept.
        """
        # Calls made while a previous call is still running return immediately.
        if callback.nested:
            return
        callback.nested = True
        with zipfile.ZipFile(f, "r") as z:
            stream = None  # currently open member of the archive
            try:
                last_storage_key = None
                zipfolder = os.path.basename(os.path.normpath(f)).split('.')[0]
                current_offset = 0
                if utils.current_shard == 0:
                    print("\n\n\nThis model has ", f"{hk.data_structures.tree_size(network.state['params']):,d}".replace(",", " "), " parameters.\n")

                if utils.num_shards is None or utils.current_shard == 0:
                    if utils.num_shards is not None:
                        num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
                    else:
                        num_tensors = len(model_dict)
                    utils.bar = tqdm(total=num_tensors, desc="Loading model tensors")

                if utils.num_shards is not None:
                    utils.current_shard += 1
                for key in sorted(model_dict.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
                    model_spec_key = max((k for k in model_spec.keys() if key.endswith(k)), key=len, default=None)

                    # Some model weights are used by transformers but not by MTJ.
                    # We have to materialize these weights anyways because
                    # transformers will throw a tantrum otherwise. To attain
                    # the least possible memory usage, we create them as meta
                    # tensors, which don't take up any actual CPU or TPU memory.
                    if model_spec_key is None:
                        model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
                        utils.bar.update(1)
                        continue

                    storage_key = model_dict[key].key
                    # (Re)open the storage member when it changes or when we would need to seek backwards.
                    if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
                        last_storage_key = storage_key
                        if isinstance(stream, zipfile.ZipExtFile):
                            stream.close()
                        try:
                            stream = z.open(f"archive/data/{storage_key}")
                        except KeyError:
                            # ZipFile.open raises KeyError for a missing member; fall back
                            # to archives whose top-level folder is named after the file.
                            stream = z.open(f"{zipfolder}/data/{storage_key}")
                        current_offset = 0
                    if current_offset != model_dict[key].seek_offset:
                        # Skip forward to this tensor's data.
                        stream.read(model_dict[key].seek_offset - current_offset)
                        current_offset = model_dict[key].seek_offset
                    spec = model_spec[model_spec_key]
                    transforms = set(spec.get("transforms", ()))
                    if not isinstance(model_dict[key], torch_lazy_loader.LazyTensor):
                        error = f"Duplicate key {repr(key)}"
                        print("\n\nERROR: " + error, file=sys.stderr)
                        raise RuntimeError(error)
                    size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
                    dtype = model_dict[key].dtype
                    nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
                    tensor = model_dict[key].materialize(stream, map_location="cpu")
                    model_dict[key] = tensor.to("meta")
                    current_offset += nbytes

                    # MTJ requires certain mathematical operations to be performed
                    # on tensors in order for them to be in the correct format
                    if "remove_first_two_rows" in transforms:
                        tensor = tensor[2:]
                    if "divide_by_shards" in transforms:
                        tensor /= params["cores_per_replica"]
                    if "vocab_pad" in transforms:
                        tensor = torch.nn.functional.pad(tensor, (0,) * (tensor.ndim * 2 - 1) + (params["n_vocab_padding"],))
                    # No manual transpose of linear weights is needed: MTJ handles it
                    # because `transposed_linear` is set to True in the config.
                    tensor.unsqueeze_(0)

                    # Shard the tensor so that parts of the tensor can be used
                    # on different TPU cores
                    tensor = reshard_reverse(
                        tensor,
                        params["cores_per_replica"],
                        network.state["params"][spec["module"]][spec["param"]].shape,
                    )
                    # jnp arrays never carry a torch dtype, so the former
                    # `tensor.dtype is torch.float16/float32` check after this line
                    # could never fire. Dtype conversion is left to move_xmap (to_bf16).
                    tensor = jnp.array(tensor.detach())
                    network.state["params"][spec["module"]][spec["param"]] = move_xmap(
                        tensor,
                        np.empty(params["cores_per_replica"]),
                    )

                    utils.bar.update(1)

                if utils.num_shards is not None and utils.current_shard < utils.num_shards:
                    return

                # Check for tensors that MTJ needs that were not provided in the
                # HF model
                for mk, mv in network.state["params"].items():
                    for pk, pv in mv.items():
                        if isinstance(pv, PlaceholderTensor):
                            # The transformers GPT-J models apparently do not
                            # have embedding bias, whereas MTJ GPT-J models do,
                            # so we have to supplement an embedding bias tensor
                            # by creating a tensor with the necessary shape, filled
                            # with zeros.
                            if mk == "causal_transformer_shard/~/embedding_shard/~/linear" and pk == "b":
                                mv[pk] = move_xmap(jnp.zeros(mv[pk].shape, dtype=jnp.bfloat16), np.empty(params["cores_per_replica"]))
                            else:
                                error = f"{mk} {pk} could not be found in the model checkpoint"
                                print("\n\nERROR: " + error, file=sys.stderr)
                                raise RuntimeError(error)
            finally:
                if utils.num_shards is None or utils.current_shard >= utils.num_shards:
                    # The bar may not exist yet if loading failed before it was created.
                    if getattr(utils, "bar", None) is not None:
                        utils.bar.close()
                    utils.bar = None
                callback.nested = False
                if isinstance(stream, zipfile.ZipExtFile):
                    stream.close()
    callback.nested = False

    model_dir = vars.model.replace('/', '_')
    if os.path.isdir(model_dir):
        import shutil
        shutil.move(model_dir, "models/{}".format(model_dir))
    print("\n", flush=True)
    # Loading the HF model triggers `callback`, which streams its weights onto the TPU.
    with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
        if os.path.isdir(vars.custmodpth):
            tokenizer, _ = _from_pretrained_with_fallbacks(vars.custmodpth)
        elif os.path.isdir("models/{}".format(model_dir)):
            tokenizer, _ = _from_pretrained_with_fallbacks("models/{}".format(model_dir))
        else:
            tokenizer, _ = _from_pretrained_with_fallbacks(vars.model)
1420
1421