"""
Originally forked from Andrej Karpathy's minGPT.
CS224N 2022-23: Homework 5
John Hewitt <[email protected]>
Ansh Khurana <[email protected]>
"""
import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
class CausalSelfAttention(nn.Module):
"""
A vanilla multi-head masked self-attention layer with a projection at the end.
I believe I could have just used torch.nn.MultiheadAttention but their documentation
is all but absent and code ugly so I don't trust it, rolling my own here.
"""
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)    # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, -1e10)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y
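

# A minimal shape-check sketch (not part of the original assignment code). It
# assumes only that the config object exposes the attributes read above
# (n_embd, n_head, block_size, attn_pdrop, resid_pdrop); the SimpleNamespace
# below is an illustrative stand-in, not the course-provided config class.
def _self_attention_shape_check():
    from types import SimpleNamespace
    cfg = SimpleNamespace(n_embd=64, n_head=4, block_size=16,
                          attn_pdrop=0.1, resid_pdrop=0.1)
    attn = CausalSelfAttention(cfg)
    x = torch.randn(2, 16, 64)           # (B, T, C)
    assert attn(x).shape == (2, 16, 64)  # output keeps the input shape
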
class CausalCrossAttention(nn.Module):
"""
Modifications over the self-attention layer to handle two inputs and perform
cross-attention between them.
This follows the implementation of the self attention module with
auto-regressive masking on (key).
Manipulation of batch-size to allow for different batch size between the
two inputs, with broadcasting over to the higher batch size value.
"""
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x_kv, x_q):
        Bk, Tk, Ck = x_kv.size()  # keys and values come from x_kv
        Bq, Tq, Cq = x_q.size()   # queries come from x_q

        # project and split into heads; keys/values keep x_kv's batch size, queries keep x_q's
        k = self.key(x_kv).view(Bk, Tk, self.n_head, Ck // self.n_head).transpose(1, 2)   # (Bk, nh, Tk, hs)
        q = self.query(x_q).view(Bq, Tq, self.n_head, Cq // self.n_head).transpose(1, 2)  # (Bq, nh, Tq, hs)
        v = self.value(x_kv).view(Bk, Tk, self.n_head, Ck // self.n_head).transpose(1, 2) # (Bk, nh, Tk, hs)

        # cross-attention: (Bq, nh, Tq, hs) x (Bk, nh, hs, Tk) -> (B, nh, Tq, Tk),
        # where the batch dimension broadcasts to B = max(Bk, Bq)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        B = max(Bk, Bq)
        att = att.masked_fill(self.mask[:, :, :Tq, :Tk] == 0, -1e10)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, Tq, Tk) x (Bk, nh, Tk, hs) -> (B, nh, Tq, hs)
        y = y.transpose(1, 2).contiguous().view(B, Tq, Cq)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y
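

# A minimal shape-check sketch (not part of the original assignment code),
# illustrating the batch broadcasting the docstring describes: a single shared
# key/value sequence (batch 1) attends against a larger query batch. The
# SimpleNamespace config is an illustrative stand-in, not the course config class.
def _cross_attention_shape_check():
    from types import SimpleNamespace
    cfg = SimpleNamespace(n_embd=64, n_head=4, block_size=16,
                          attn_pdrop=0.1, resid_pdrop=0.1)
    xattn = CausalCrossAttention(cfg)
    x_kv = torch.randn(1, 12, 64)  # (Bk=1, Tk, C): shared keys/values
    x_q = torch.randn(4, 10, 64)   # (Bq=4, Tq, C): per-example queries
    assert xattn(x_kv, x_q).shape == (4, 10, 64)  # batch broadcasts to Bq, length Tq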