GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/commons/common_layers.py
import math
import torch
from torch import nn
from torch.nn import Parameter
import torch.onnx.operators
import torch.nn.functional as F
import utils


class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


class Permute(nn.Module):
    def __init__(self, *args):
        super(Permute, self).__init__()
        self.args = args

    def forward(self, x):
        return x.permute(self.args)

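# Illustrative usage sketch (not part of the original module): Reshape and Permute wrap
# Tensor.view / Tensor.permute so they can be dropped into an nn.Sequential pipeline.
# The function name and all sizes below are arbitrary example values.
def _example_reshape_permute():
    net = nn.Sequential(
        Permute(0, 2, 1),                             # [B, T, C] -> [B, C, T]
        nn.Conv1d(80, 32, kernel_size=3, padding=1),  # convolve over the time axis
        Permute(0, 2, 1),                             # back to [B, T, C]
    )
    y = net(torch.randn(4, 100, 80))
    flat = Reshape(-1, 32)(y.contiguous())            # Reshape uses view, so memory must be contiguous
    return y.shape, flat.shape                        # [4, 100, 32], [400, 32]
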
class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert (kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal

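# Illustrative usage sketch (not part of the original module): LinearNorm / ConvNorm are
# plain Linear / Conv1d layers whose weights are Xavier-initialized with a gain matched
# to the given activation. The function name and sizes below are arbitrary examples.
def _example_linear_conv_norm():
    proj = LinearNorm(80, 256, w_init_gain='tanh')   # gain = calculate_gain('tanh')
    conv = ConvNorm(80, 256, kernel_size=5)          # symmetric padding derived from kernel size
    x = torch.randn(4, 80, 100)                      # [B, C, T] as expected by Conv1d
    h = conv(x)                                      # [4, 256, 100]
    z = proj(x.transpose(1, 2))                      # [4, 100, 256]
    return h.shape, z.shape
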
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    if not export and torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm
            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.)
    return m

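# Illustrative usage sketch (not part of the original module): the helper constructors
# above bundle the initialization conventions used throughout this file (normal-init
# embeddings with a zeroed padding row, Xavier-init linear layers, and a LayerNorm that
# prefers apex's FusedLayerNorm when it is installed). Sizes are arbitrary examples.
def _example_helper_constructors():
    emb = Embedding(num_embeddings=100, embedding_dim=192, padding_idx=0)
    assert emb.weight[0].abs().sum() == 0       # padding row is zeroed
    lin = Linear(192, 192)
    norm = LayerNorm(192)
    tokens = torch.randint(1, 100, (4, 25))     # [B, T] token ids
    return norm(lin(emb(tokens))).shape         # [4, 25, 192]
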
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = utils.make_positions(input, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number

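# Illustrative usage sketch (not part of the original module): get_embedding builds a
# [num_positions, embedding_dim] sine/cosine table (with the padding row zeroed), and
# forward looks up positions for a [B, T] token tensor via utils.make_positions from
# this repository. The function name and sizes are arbitrary examples.
def _example_sinusoidal_positions():
    table = SinusoidalPositionalEmbedding.get_embedding(10, 8, padding_idx=0)
    assert table.shape == (10, 8)
    assert table[0].abs().sum() == 0           # padding row is zeroed
    pos_emb = SinusoidalPositionalEmbedding(embedding_dim=192, padding_idx=0)
    tokens = torch.randint(1, 50, (4, 30))     # [B, T], no padding tokens here
    return pos_emb(tokens).shape               # [4, 30, 192]
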
class ConvTBC(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(ConvTBC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding

        # NOTE: weight and bias are allocated uninitialized here;
        # initialization is left to the caller.
        self.weight = torch.nn.Parameter(torch.Tensor(
            self.kernel_size, in_channels, out_channels))
        self.bias = torch.nn.Parameter(torch.Tensor(out_channels))

    def forward(self, input):
        # torch.conv_tbc expects input in (time, batch, channel) layout
        return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)

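# Illustrative usage sketch (not part of the original module): ConvTBC convolves over a
# (time, batch, channel) tensor. Since its parameters are allocated uninitialized, the
# sketch initializes them explicitly before use. Sizes are arbitrary examples.
def _example_conv_tbc():
    conv = ConvTBC(in_channels=64, out_channels=64, kernel_size=3, padding=1)
    nn.init.xavier_uniform_(conv.weight)
    nn.init.zeros_(conv.bias)
    x = torch.randn(50, 2, 64)      # [T, B, C]
    return conv(x).shape            # [50, 2, 64]
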
class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                             'value to be of the same size'

        if self.qkv_same_dim:
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        # Use the fused F.multi_head_attention_forward path when it is available.
        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
        self.last_attn_probs = None

    def reset_parameters(self):
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
            self,
            query, key, value,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      self.in_proj_weight,
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask)
            else:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      torch.empty([0]),
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask, use_separate_proj_weight=True,
                                                      q_proj_weight=self.q_proj_weight,
                                                      k_proj_weight=self.k_proj_weight,
                                                      v_proj_weight=self.v_proj_weight)

        if incremental_state is not None:
            raise NotImplementedError('incremental decoding is not supported in this module')
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)

        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            raise NotImplementedError('cached key/value states are not supported in this module')

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights + attn_mask

        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
                -1e9,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                -1e9,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)

        if reset_attn_weight is not None:
            if reset_attn_weight:
                self.last_attn_probs = attn_probs.detach()
            else:
                assert self.last_attn_probs is not None
                attn_probs = self.last_attn_probs
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, (attn_weights, attn_logits)

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        return attn_weights

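# Illustrative usage sketch (not part of the original module): self-attention over a
# (time, batch, channel) tensor, mirroring how EncSALayer below calls this class.
# The function name, shapes, and hyper-parameters are arbitrary example values.
def _example_multihead_attention():
    torch.manual_seed(0)
    attn = MultiheadAttention(embed_dim=256, num_heads=4, self_attention=True, bias=False)
    x = torch.randn(50, 2, 256)                      # [T, B, C]
    pad_mask = torch.zeros(2, 50, dtype=torch.bool)  # True marks padded positions
    pad_mask[1, 40:] = True                          # second sample is shorter
    out, attn_weights = attn(query=x, key=x, value=x, key_padding_mask=pad_mask)
    return out.shape, attn_weights                   # out: [50, 2, 256]
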
class Swish(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class CustomSwish(nn.Module):
    def forward(self, input_tensor):
        return Swish.apply(input_tensor)


class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))

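# Illustrative check (not part of the original module): Swish and Mish above are
# numerically equivalent to torch.nn.functional.silu and torch.nn.functional.mish,
# assuming a PyTorch build recent enough to provide both functions.
def _example_activations():
    x = torch.randn(4, 8, requires_grad=True)
    assert torch.allclose(CustomSwish()(x), F.silu(x), atol=1e-6)
    assert torch.allclose(Mish()(x), F.mish(x), atol=1e-6)
    # the hand-written backward of Swish is exercised through autograd
    CustomSwish()(x).sum().backward()
    return x.grad
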
class TransformerFFNLayer(nn.Module):
    def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.act = act
        if padding == 'SAME':
            self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
        elif padding == 'LEFT':
            self.ffn_1 = nn.Sequential(
                nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
                nn.Conv1d(hidden_size, filter_size, kernel_size)
            )
        self.ffn_2 = Linear(filter_size, hidden_size)
        if self.act == 'swish':
            self.swish_fn = CustomSwish()

    def forward(self, x, incremental_state=None):
        # x: T x B x C
        if incremental_state is not None:
            raise NotImplementedError('Nar-generation does not allow this.')

        x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
        x = x * self.kernel_size ** -0.5

        if incremental_state is not None:
            x = x[-1:]
        if self.act == 'gelu':
            x = F.gelu(x)
        if self.act == 'relu':
            x = F.relu(x)
        if self.act == 'swish':
            x = self.swish_fn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.ffn_2(x)
        return x

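# Illustrative sketch (not part of the original module): both padding modes keep the time
# dimension of a [T, B, C] input unchanged; 'LEFT' pads causally so position t only sees
# inputs at or before t. The function name and sizes are arbitrary example values.
def _example_ffn_padding():
    x = torch.randn(100, 2, 192)  # [T, B, C]
    same = TransformerFFNLayer(192, 768, padding='SAME', kernel_size=9)
    left = TransformerFFNLayer(192, 768, padding='LEFT', kernel_size=9)
    assert same(x).shape == x.shape
    assert left(x).shape == x.shape
    return same(x).shape
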
class BatchNorm1dTBC(nn.Module):
    def __init__(self, c):
        super(BatchNorm1dTBC, self).__init__()
        self.bn = nn.BatchNorm1d(c)

    def forward(self, x):
        """
        :param x: [T, B, C]
        :return: [T, B, C]
        """
        x = x.permute(1, 2, 0)  # [B, C, T]
        x = self.bn(x)  # [B, C, T]
        x = x.permute(2, 0, 1)  # [T, B, C]
        return x

class EncSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
                 relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.num_heads = num_heads
        if num_heads > 0:
            if norm == 'ln':
                self.layer_norm1 = LayerNorm(c)
            elif norm == 'bn':
                self.layer_norm1 = BatchNorm1dTBC(c)
            self.self_attn = MultiheadAttention(
                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False,
            )
        if norm == 'ln':
            self.layer_norm2 = LayerNorm(c)
        elif norm == 'bn':
            self.layer_norm2 = BatchNorm1dTBC(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
        if self.num_heads > 0:
            residual = x
            x = self.layer_norm1(x)
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask
            )
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x
            x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]

        residual = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
        return x

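# Illustrative usage sketch (not part of the original module): one encoder block over a
# [T, B, C] sequence with a boolean padding mask (True marks padded frames). The function
# name and sizes are arbitrary example values.
def _example_enc_sa_layer():
    layer = EncSALayer(c=192, num_heads=2, dropout=0.1, kernel_size=9, norm='ln', act='gelu')
    x = torch.randn(200, 4, 192)                 # [T, B, C]
    pad = torch.zeros(4, 200, dtype=torch.bool)  # [B, T]
    pad[2, 150:] = True                          # third sample padded after frame 150
    y = layer(x, encoder_padding_mask=pad)
    return y.shape                               # [200, 4, 192]
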
class DecSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.layer_norm1 = LayerNorm(c)
        self.self_attn = MultiheadAttention(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
        )
        self.layer_norm2 = LayerNorm(c)
        self.encoder_attn = MultiheadAttention(
            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
        )
        self.layer_norm3 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)

    def forward(
            self,
            x,
            encoder_out=None,
            encoder_padding_mask=None,
            incremental_state=None,
            self_attn_mask=None,
            self_attn_padding_mask=None,
            attn_out=None,
            reset_attn_weight=None,
            **kwargs,
    ):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
            self.layer_norm3.training = layer_norm_training
        residual = x
        x = self.layer_norm1(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask
        )
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x

        residual = x
        x = self.layer_norm2(x)
        if encoder_out is not None:
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                enc_dec_attn_constraint_mask=None,  # utils.get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask'),
                reset_attn_weight=reset_attn_weight
            )
            attn_logits = attn[1]
        else:
            assert attn_out is not None
            x = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1))
            attn_logits = None
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x

        residual = x
        x = self.layer_norm3(x)
        x = self.ffn(x, incremental_state=incremental_state)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        # if len(attn_logits.size()) > 3:
        #     indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1)
        #     attn_logits = attn_logits.gather(1,
        #         indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1)
        return x, attn_logits

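# Illustrative usage sketch (not part of the original module): a decoder block attending
# over encoder states. Because static_kv=True routes the cross-attention through the
# non-fused path, the layer also returns per-head attention logits. The function name and
# sizes are arbitrary example values.
def _example_dec_sa_layer():
    layer = DecSALayer(c=192, num_heads=2, dropout=0.1, kernel_size=9)
    dec_in = torch.randn(80, 4, 192)    # decoder input, [T_dec, B, C]
    enc_out = torch.randn(120, 4, 192)  # encoder output, [T_enc, B, C]
    enc_pad = torch.zeros(4, 120, dtype=torch.bool)
    y, attn_logits = layer(dec_in, encoder_out=enc_out, encoder_padding_mask=enc_pad)
    return y.shape, attn_logits.shape   # [80, 4, 192], [B, heads, T_dec, T_enc]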