GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a5/src/attention.py
"""
Originally forked from Andrej Karpathy's minGPT.

CS224N 2022-23: Homework 5

John Hewitt <[email protected]>
Ansh Khurana <[email protected]>
"""

import math
import logging

import torch
import torch.nn as nn
from torch.nn import functional as F

logger = logging.getLogger(__name__)

class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    I believe I could have just used torch.nn.MultiheadAttention, but their documentation
    is all but absent and the code is ugly, so I don't trust it; rolling my own here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)    # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, -1e10)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y

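
# --- Usage sketch (not part of the original assignment code) ---
# A minimal, hedged example of driving CausalSelfAttention on its own. The real
# config object is built elsewhere in the repo; here a SimpleNamespace stands in
# for it, so the concrete sizes below are assumptions made only for this sketch.
# The field names (n_embd, n_head, attn_pdrop, resid_pdrop, block_size) are the
# ones read in __init__ above.
def _self_attention_usage_sketch():
    from types import SimpleNamespace
    config = SimpleNamespace(n_embd=64, n_head=4, attn_pdrop=0.1,
                             resid_pdrop=0.1, block_size=16)
    attn = CausalSelfAttention(config)
    x = torch.randn(2, 8, config.n_embd)  # (batch, seq_len <= block_size, n_embd)
    y = attn(x)                           # output keeps the input shape: (2, 8, 64)
    assert y.shape == x.shape
    return y
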
class CausalCrossAttention(nn.Module):
    """
    Modifies the self-attention layer to handle two inputs and perform
    cross-attention between them.
    This follows the implementation of the self-attention module, with
    auto-regressive masking applied over the key positions.
    The two inputs may have different batch sizes; the smaller batch is
    broadcast up to the larger one.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x_kv, x_q):
        Bk, Tk, Ck = x_kv.size()
        Bq, Tq, Cq = x_q.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim

        # keys from x_kv
        k = self.key(x_kv).view(Bk, Tk, self.n_head, Ck // self.n_head).transpose(1, 2)   # (Bk, nh, Tk, hs)

        # queries from x_q
        q = self.query(x_q).view(Bq, Tq, self.n_head, Cq // self.n_head).transpose(1, 2)  # (Bq, nh, Tq, hs)

        # values from x_kv
        v = self.value(x_kv).view(Bk, Tk, self.n_head, Ck // self.n_head).transpose(1, 2) # (Bk, nh, Tk, hs)

        # causal cross-attention; (B, nh, Tq, hs) x (B, nh, hs, Tk) -> (B, nh, Tq, Tk),
        # with the batch dimensions Bk and Bq broadcast against each other
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # the output batch size is the larger of the two input batch sizes
        B = max(Bk, Bq)

        att = att.masked_fill(self.mask[:, :, :Tq, :Tk] == 0, -1e10)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, Tq, Tk) x (B, nh, Tk, hs) -> (B, nh, Tq, hs)
        y = y.transpose(1, 2).contiguous().view(B, Tq, Cq)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y
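
# --- Usage sketch (not part of the original assignment code) ---
# A hedged illustration of the batch-size broadcasting described in the class
# docstring: keys/values come from a single encoded sequence (batch size 1)
# while the queries carry a larger decoder batch. The SimpleNamespace config
# and the concrete sizes are assumptions made only for this example.
def _cross_attention_usage_sketch():
    from types import SimpleNamespace
    config = SimpleNamespace(n_embd=64, n_head=4, attn_pdrop=0.1,
                             resid_pdrop=0.1, block_size=16)
    xattn = CausalCrossAttention(config)
    x_kv = torch.randn(1, 10, config.n_embd)  # keys/values: batch 1, Tk = 10
    x_q = torch.randn(4, 6, config.n_embd)    # queries:     batch 4, Tq = 6
    y = xattn(x_kv, x_q)                      # broadcast to (max(1, 4), Tq, n_embd)
    assert y.shape == (4, 6, config.n_embd)
    return y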