import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SqueezeEmbedding(nn.Module):
    """
    Squeeze sequence embedding length to the longest one in the batch.

    The padded input is packed and immediately unpacked, which trims the time
    dimension to the largest value in ``x_len`` and zero-fills every position
    beyond each sequence's own length.
    """

    def __init__(self, batch_first=True):
        """
        :param batch_first: True if inputs are (batch, seq, ...), False if
            they are (seq, batch, ...)
        """
        super(SqueezeEmbedding, self).__init__()
        self.batch_first = batch_first

    def forward(self, x, x_len):
        """
        sequence -> sort -> pad and pack -> unpack -> unsort

        :param x: padded sequence embedding vectors
        :param x_len: per-sequence lengths (tensor, numpy array or list)
        :return: ``x`` in its original batch order, trimmed to ``max(x_len)``
            along the time dimension
        """
        # The batch axis depends on the layout; sorting must permute batch
        # entries, never time steps.
        batch_dim = 0 if self.batch_first else 1

        # pack_padded_sequence wants the lengths on the CPU.
        lengths = torch.as_tensor(x_len).long().cpu()

        # sort by decreasing length, remembering how to undo the permutation
        sorted_len, sort_idx = torch.sort(lengths, descending=True)
        unsort_idx = torch.sort(sort_idx)[1]
        x_sorted = x.index_select(batch_dim, sort_idx.to(x.device))

        # pack, then unpack: trims to the longest sequence and zero-pads
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            x_sorted, sorted_len, batch_first=self.batch_first)
        out = torch.nn.utils.rnn.pad_packed_sequence(
            packed, batch_first=self.batch_first)[0]

        # restore the caller's batch order
        return out.index_select(batch_dim, unsort_idx.to(out.device))
class Attention(nn.Module):
    """
    (Multi-head) attention in which the projected keys also act as values.

    Supported score functions:
        * ``'dot_product'``        - q . k
        * ``'scaled_dot_product'`` - q . k / sqrt(hidden_dim)
        * ``'mlp'``                - tanh([k; q] . w)  (concat)
        * ``'bi_linear'``          - q W k             (general dot)
    """

    def __init__(self, embed_dim, hidden_dim=None, out_dim=None, n_head=1, score_function='dot_product', dropout=0):
        """ Attention Mechanism
        :param embed_dim: size of the last dimension of both keys and queries
        :param hidden_dim: per-head projection size; defaults to embed_dim // n_head
        :param out_dim: size of the output projection; defaults to embed_dim
        :param n_head: num of head (Multi-Head Attention)
        :param score_function: dot_product / scaled_dot_product / mlp (concat) / bi_linear (general dot)
        :param dropout: dropout probability applied to the projected output
        :return (?, q_len, out_dim,)
        """
        super(Attention, self).__init__()
        if hidden_dim is None:
            hidden_dim = embed_dim // n_head
        if out_dim is None:
            out_dim = embed_dim
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.score_function = score_function
        self.w_k = nn.Linear(embed_dim, n_head * hidden_dim)
        self.w_q = nn.Linear(embed_dim, n_head * hidden_dim)
        self.proj = nn.Linear(n_head * hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        # Only the 'mlp' and 'bi_linear' scores carry an extra learned weight;
        # it is allocated uninitialised here and filled in reset_parameters().
        if score_function == 'mlp':
            self.weight = nn.Parameter(torch.empty(hidden_dim * 2))
        elif score_function == 'bi_linear':
            self.weight = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        else:
            self.register_parameter('weight', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the score weight (if any) uniformly in +-1/sqrt(hidden_dim)."""
        if self.weight is not None:
            stdv = 1. / math.sqrt(self.hidden_dim)
            nn.init.uniform_(self.weight, -stdv, stdv)

    def forward(self, k, q):
        """
        :param k: keys, (batch, k_len, embed_dim) or (batch, embed_dim)
        :param q: queries, (batch, q_len, embed_dim) or (batch, embed_dim)
        :return: tuple ``(output, score)`` where output is
            (batch, q_len, out_dim) and score is the softmax-normalised
            attention of shape (n_head * batch, q_len, k_len)
        :raises RuntimeError: if ``score_function`` is not a supported name
        """
        # 2-D inputs are treated as sequences of length one
        if len(q.shape) == 2:
            q = torch.unsqueeze(q, dim=1)
        if len(k.shape) == 2:
            k = torch.unsqueeze(k, dim=1)
        mb_size = k.shape[0]
        k_len = k.shape[1]
        q_len = q.shape[1]

        # project, then fold the heads into the batch dimension:
        # (batch, len, n_head * hidden) -> (n_head * batch, len, hidden)
        kx = self.w_k(k).view(mb_size, k_len, self.n_head, self.hidden_dim)
        kx = kx.permute(2, 0, 1, 3).contiguous().view(-1, k_len, self.hidden_dim)
        qx = self.w_q(q).view(mb_size, q_len, self.n_head, self.hidden_dim)
        qx = qx.permute(2, 0, 1, 3).contiguous().view(-1, q_len, self.hidden_dim)

        if self.score_function == 'dot_product':
            score = torch.bmm(qx, kx.permute(0, 2, 1))
        elif self.score_function == 'scaled_dot_product':
            qkt = torch.bmm(qx, kx.permute(0, 2, 1))
            score = torch.div(qkt, math.sqrt(self.hidden_dim))
        elif self.score_function == 'mlp':
            # pair every query position with every key position
            kxx = torch.unsqueeze(kx, dim=1).expand(-1, q_len, -1, -1)
            qxx = torch.unsqueeze(qx, dim=2).expand(-1, -1, k_len, -1)
            kq = torch.cat((kxx, qxx), dim=-1)
            score = torch.tanh(torch.matmul(kq, self.weight))
        elif self.score_function == 'bi_linear':
            qw = torch.matmul(qx, self.weight)
            score = torch.bmm(qw, kx.permute(0, 2, 1))
        else:
            raise RuntimeError('invalid score_function')

        score = F.softmax(score, dim=-1)
        output = torch.bmm(score, kx)
        # unfold the heads: (n_head * batch, q_len, hidden) -> (batch, q_len, n_head * hidden)
        output = torch.cat(torch.split(output, mb_size, dim=0), dim=-1)
        output = self.proj(output)
        output = self.dropout(output)
        return output, score
class NoQueryAttention(Attention):
    """Attention whose query is a learned parameter instead of an input."""

    def __init__(self, embed_dim, hidden_dim=None, out_dim=None, n_head=1, score_function='dot_product', q_len=1,
                 dropout=0):
        """
        :param q_len: number of learned query vectors
        All other arguments are forwarded unchanged to ``Attention``.
        """
        super(NoQueryAttention, self).__init__(
            embed_dim,
            hidden_dim=hidden_dim,
            out_dim=out_dim,
            n_head=n_head,
            score_function=score_function,
            dropout=dropout,
        )
        self.q_len = q_len
        # learned query of shape (q_len, embed_dim); filled by reset_q()
        self.q = nn.Parameter(torch.Tensor(q_len, embed_dim))
        self.reset_q()

    def reset_q(self):
        """Draw the learned query uniformly from +-1/sqrt(embed_dim)."""
        bound = 1. / math.sqrt(self.embed_dim)
        self.q.data.uniform_(-bound, bound)

    def forward(self, k, **kwargs):
        """
        Attend over ``k`` using the learned query.

        :param k: keys; the first dimension is taken as the batch size
        :param kwargs: accepted and ignored
        :return: whatever ``Attention.forward`` returns for ``(k, q)``
        """
        batch = k.shape[0]
        # share the same learned query across every batch entry
        learned_q = self.q.expand(batch, -1, -1)
        return super(NoQueryAttention, self).forward(k, learned_q)
class DynamicLSTM(nn.Module):
    """
    Recurrent layer (LSTM / GRU / RNN) for batches of variable-length sequences.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0,
                 bidirectional=False, only_use_last_hidden_state=False, rnn_type='LSTM'):
        """
        Let an LSTM handle a batch that contains sentences of different
        lengths; implemented on top of pack_padded_sequence.
        Reference: https://zhuanlan.zhihu.com/p/34418001

        :param only_use_last_hidden_state: if True, forward() returns only the
            final hidden state
        :param rnn_type: one of 'LSTM', 'GRU', 'RNN'
        :raises ValueError: if ``rnn_type`` is not one of the supported names
        The remaining arguments are passed straight to the torch.nn recurrent
        layer.
        """
        super(DynamicLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.only_use_last_hidden_state = only_use_last_hidden_state
        self.rnn_type = rnn_type

        rnn_classes = {'LSTM': nn.LSTM, 'GRU': nn.GRU, 'RNN': nn.RNN}
        if rnn_type not in rnn_classes:
            # Fail at construction instead of with an AttributeError on the
            # first forward() call.
            raise ValueError(
                "rnn_type must be one of 'LSTM', 'GRU', 'RNN', got %r" % (rnn_type,))
        self.RNN = rnn_classes[rnn_type](
            input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
            bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)

    def forward(self, x, x_len):
        """
        sequence -> sort -> pad and pack -> process using RNN -> unpack -> unsort

        :param x: padded sequence embedding vectors
        :param x_len: per-sequence lengths (tensor, numpy array or list)
        :return: ``ht`` if ``only_use_last_hidden_state`` is set, otherwise
            ``(out, (ht, ct))``; ``ct`` is None unless rnn_type is 'LSTM'.
            Everything is returned in the caller's original batch order.
        """
        # The batch axis of x / out depends on the layout; the hidden and cell
        # states always keep the batch on dim 1.
        batch_dim = 0 if self.batch_first else 1

        # pack_padded_sequence wants the lengths on the CPU.
        lengths = torch.as_tensor(x_len).long().cpu()

        # sort by decreasing length, remembering how to undo the permutation
        sorted_len, sort_idx = torch.sort(lengths, descending=True)
        unsort_idx = torch.sort(sort_idx)[1].to(x.device)
        x_sorted = x.index_select(batch_dim, sort_idx.to(x.device))

        # pack and run the recurrent layer
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            x_sorted, sorted_len, batch_first=self.batch_first)
        if self.rnn_type == 'LSTM':
            out_pack, (ht, ct) = self.RNN(packed, None)
        else:
            out_pack, ht = self.RNN(packed, None)
            ct = None

        # unsort: h
        ht = ht.index_select(1, unsort_idx)
        if self.only_use_last_hidden_state:
            return ht

        # unpack and unsort: out
        out = torch.nn.utils.rnn.pad_packed_sequence(
            out_pack, batch_first=self.batch_first)[0]
        out = out.index_select(batch_dim, unsort_idx)

        # unsort: c (LSTM only)
        if ct is not None:
            ct = ct.index_select(1, unsort_idx)
        return out, (ht, ct)