Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
HJHGJGHHG
GitHub Repository: HJHGJGHHG/CCF-BDCI-AQYI
Path: blob/main/models/layers.py
153 views
1
import math
2
import torch
3
import torch.nn as nn
4
import torch.nn.functional as F
5
6
7
class SqueezeEmbedding(nn.Module):
    """Trim padded sequence embeddings down to the longest real length in the batch."""

    def __init__(self, batch_first=True):
        super(SqueezeEmbedding, self).__init__()
        self.batch_first = batch_first

    def forward(self, x, x_len):
        """
        sequence -> sort -> pack -> unpack -> unsort

        :param x: padded sequence embedding vectors, (batch, seq, dim) when batch_first
        :param x_len: 1-D tensor of true lengths, one per batch element
        :return: x with the time axis cut down to max(x_len)
        """
        # Sort descending by length (pack_padded_sequence requires it) and
        # remember the inverse permutation to restore batch order afterwards.
        sort_idx = torch.argsort(-x_len).long()
        unsort_idx = torch.argsort(sort_idx).long()
        sorted_len = x_len[sort_idx]
        sorted_x = x[sort_idx]
        # Packing then immediately unpacking drops all padding beyond the
        # longest sequence actually present in the batch.
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            sorted_x, sorted_len.cpu(), batch_first=self.batch_first)
        unpacked, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed, batch_first=self.batch_first)
        # NOTE(review): indexing dim 0 here assumes batch_first=True layout;
        # with batch_first=False this would permute the time axis — confirm callers.
        return unpacked[unsort_idx]
36
37
38
class Attention(nn.Module):
    def __init__(self, embed_dim, hidden_dim=None, out_dim=None, n_head=1, score_function='dot_product', dropout=0):
        """ Attention Mechanism
        :param embed_dim: dimensionality of incoming key/query vectors
        :param hidden_dim: per-head projection size; defaults to embed_dim // n_head
        :param out_dim: output feature size; defaults to embed_dim
        :param n_head: num of head (Multi-Head Attention)
        :param score_function: dot_product / scaled_dot_product / mlp (concat) / bi_linear (general dot)
        :param dropout: dropout probability applied to the projected output
        :return (?, q_len, out_dim,)
        """
        super(Attention, self).__init__()
        if hidden_dim is None:
            hidden_dim = embed_dim // n_head
        if out_dim is None:
            out_dim = embed_dim
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.score_function = score_function
        self.w_k = nn.Linear(embed_dim, n_head * hidden_dim)
        self.w_q = nn.Linear(embed_dim, n_head * hidden_dim)
        self.proj = nn.Linear(n_head * hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        # Extra score parameter only for the parametric score functions.
        # (Consistency fix: use the local name in both branches instead of
        # mixing `score_function` and `self.score_function`.)
        if score_function == 'mlp':
            self.weight = nn.Parameter(torch.Tensor(hidden_dim * 2))
        elif score_function == 'bi_linear':
            self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        else:  # dot_product / scaled_dot_product
            self.register_parameter('weight', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Fan-in style uniform init for the extra score parameter.
        stdv = 1. / math.sqrt(self.hidden_dim)
        if self.weight is not None:
            self.weight.data.uniform_(-stdv, stdv)

    def forward(self, k, q):
        """Attend q over k with n_head heads.

        :param k: keys/values, (?, k_len, embed_dim) or (?, embed_dim)
        :param q: queries, (?, q_len, embed_dim) or (?, embed_dim)
        :return: (output, score) — output is (?, q_len, out_dim);
                 score is the post-softmax weights, (n_head*?, q_len, k_len)
        :raises RuntimeError: if score_function is not a recognized kind
        """
        if len(q.shape) == 2:  # q_len missing
            q = torch.unsqueeze(q, dim=1)
        if len(k.shape) == 2:  # k_len missing
            k = torch.unsqueeze(k, dim=1)
        mb_size = k.shape[0]  # batch size
        k_len = k.shape[1]
        q_len = q.shape[1]
        # k: (?, k_len, embed_dim,)
        # q: (?, q_len, embed_dim,)
        # kx: (n_head*?, k_len, hidden_dim)
        # qx: (n_head*?, q_len, hidden_dim)
        # score: (n_head*?, q_len, k_len,)
        # output: (?, q_len, out_dim,)
        kx = self.w_k(k).view(mb_size, k_len, self.n_head, self.hidden_dim)
        kx = kx.permute(2, 0, 1, 3).contiguous().view(-1, k_len, self.hidden_dim)
        qx = self.w_q(q).view(mb_size, q_len, self.n_head, self.hidden_dim)
        qx = qx.permute(2, 0, 1, 3).contiguous().view(-1, q_len, self.hidden_dim)
        if self.score_function == 'dot_product':
            kt = kx.permute(0, 2, 1)
            score = torch.bmm(qx, kt)
        elif self.score_function == 'scaled_dot_product':
            kt = kx.permute(0, 2, 1)
            qkt = torch.bmm(qx, kt)
            score = torch.div(qkt, math.sqrt(self.hidden_dim))
        elif self.score_function == 'mlp':
            kxx = torch.unsqueeze(kx, dim=1).expand(-1, q_len, -1, -1)
            qxx = torch.unsqueeze(qx, dim=2).expand(-1, -1, k_len, -1)
            kq = torch.cat((kxx, qxx), dim=-1)  # (n_head*?, q_len, k_len, hidden_dim*2)
            # FIX: F.tanh is deprecated in PyTorch; torch.tanh is the
            # numerically identical supported replacement.
            score = torch.tanh(torch.matmul(kq, self.weight))
        elif self.score_function == 'bi_linear':
            qw = torch.matmul(qx, self.weight)
            kt = kx.permute(0, 2, 1)
            score = torch.bmm(qw, kt)
        else:
            raise RuntimeError('invalid score_function')
        score = F.softmax(score, dim=-1)
        output = torch.bmm(score, kx)  # (n_head*?, q_len, hidden_dim)
        output = torch.cat(torch.split(output, mb_size, dim=0), dim=-1)  # (?, q_len, n_head*hidden_dim)
        output = self.proj(output)  # (?, q_len, out_dim)
        output = self.dropout(output)
        return output, score
117
118
119
class NoQueryAttention(Attention):
    """Attention whose query is a learned parameter rather than an input."""

    def __init__(self, embed_dim, hidden_dim=None, out_dim=None, n_head=1, score_function='dot_product', q_len=1,
                 dropout=0):
        super(NoQueryAttention, self).__init__(embed_dim, hidden_dim, out_dim, n_head, score_function, dropout)
        self.q_len = q_len
        # Trainable query shared across the batch: (q_len, embed_dim).
        self.q = nn.Parameter(torch.Tensor(q_len, embed_dim))
        self.reset_q()

    def reset_q(self):
        # Fan-in style uniform init, matching the base class' weight init.
        bound = 1. / math.sqrt(self.embed_dim)
        self.q.data.uniform_(-bound, bound)

    def forward(self, k, **kwargs):
        # Broadcast the learned query over the batch, then defer to Attention.
        batch = k.shape[0]
        expanded_q = self.q.expand(batch, -1, -1)
        return super(NoQueryAttention, self).forward(k, expanded_q)
137
138
139
class DynamicLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0,
                 bidirectional=False, only_use_last_hidden_state=False, rnn_type='LSTM'):
        """
        RNN wrapper that supports sentences of different lengths in the same
        batch, implemented via pack_padded_sequence / pad_packed_sequence.
        Reference: https://zhuanlan.zhihu.com/p/34418001

        :param only_use_last_hidden_state: if True, forward returns only ht
        :param rnn_type: 'LSTM', 'GRU' or 'RNN'
        :raises ValueError: if rnn_type is not one of the supported kinds
        """
        super(DynamicLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.only_use_last_hidden_state = only_use_last_hidden_state
        self.rnn_type = rnn_type

        if self.rnn_type == 'LSTM':
            self.RNN = nn.LSTM(
                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)
        elif self.rnn_type == 'GRU':
            self.RNN = nn.GRU(
                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)
        elif self.rnn_type == 'RNN':
            self.RNN = nn.RNN(
                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)
        else:
            # FIX: previously an unknown rnn_type silently left self.RNN
            # undefined, producing a confusing AttributeError at forward time.
            raise ValueError("rnn_type must be 'LSTM', 'GRU' or 'RNN', got %r" % (rnn_type,))

    def forward(self, x, x_len):
        """
        sequence -> sort -> pad and pack -> process using RNN -> unpack -> unsort

        :param x: padded sequence embedding vectors
        :param x_len: 1-D tensor of true lengths, one per batch element
        :return: ht if only_use_last_hidden_state, else (out, (ht, ct));
                 ct is None for non-LSTM rnn_type
        """
        """sort"""
        # Descending length order is required by pack_padded_sequence;
        # x_unsort_idx is the inverse permutation to restore batch order.
        x_sort_idx = torch.sort(-x_len)[1].long()
        x_unsort_idx = torch.sort(x_sort_idx)[1].long()
        x_len = x_len[x_sort_idx].to('cpu')  # lengths must live on CPU
        x = x[x_sort_idx]
        """pack"""
        x_emb_p = torch.nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=self.batch_first)

        # process using the selected RNN (LSTM also returns a cell state)
        if self.rnn_type == 'LSTM':
            out_pack, (ht, ct) = self.RNN(x_emb_p, None)
        else:
            out_pack, ht = self.RNN(x_emb_p, None)
            ct = None
        """unsort: h"""
        ht = torch.transpose(ht, 0, 1)[
            x_unsort_idx]  # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
        ht = torch.transpose(ht, 0, 1)

        if self.only_use_last_hidden_state:
            return ht
        else:
            """unpack: out"""
            out = torch.nn.utils.rnn.pad_packed_sequence(out_pack, batch_first=self.batch_first)  # (sequence, lengths)
            out = out[0]
            out = out[x_unsort_idx]
            """unsort: out c"""
            if self.rnn_type == 'LSTM':
                ct = torch.transpose(ct, 0, 1)[
                    x_unsort_idx]  # (num_layers * num_directions, batch, hidden_size) -> (batch, ...)
                ct = torch.transpose(ct, 0, 1)

            return out, (ht, ct)
211
212