Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
HJHGJGHHG
GitHub Repository: HJHGJGHHG/CCF-BDCI-AQYI
Path: blob/main/models/bert_tdlstm.py
153 views
1
import torch
2
import torch.nn as nn
3
from transformers import BertForSequenceClassification
4
from models.layers import DynamicLSTM
5
6
7
def get_x_l(character_in_text, all_hidden_states,
            batch_size=8, max_len=64, hidden_size=1024, device='cuda'):
    """Extract zero-padded left-context BERT representations for a batch.

    For each sample i, the target character's 0-based position is read from
    ``character_in_text[i]`` (a 1-element tensor); the left context is the
    ``len_l = position + 1`` hidden states at positions 1..len_l of the last
    BERT layer (position 0, the [CLS] token, is skipped), zero-padded to
    ``max_len`` rows.

    Args:
        character_in_text: sequence of 1-element tensors; element 0 of each
            is the target character's position in the text.
            # assumes position + 1 <= max_len -- TODO confirm upstream
        all_hidden_states: sequence of per-layer hidden states; only the last
            layer, shape (batch, seq_len, hidden_size), is used.
        batch_size: number of samples to process (was hard-coded to 8).
        max_len: padded sequence length (was hard-coded to 64).
        hidden_size: hidden dimension (was hard-coded to 1024).
        device: device for created tensors (was hard-coded to 'cuda').

    Returns:
        Tuple ``(x_l_sum, x_l_len)``: padded left contexts of shape
        (batch_size, max_len, hidden_size) and a LongTensor of true lengths.
    """
    last_layer = all_hidden_states[-1]
    padded = []
    lengths = []
    for i_sample in range(batch_size):
        # 0-based char position -> number of left-context tokens (inclusive).
        len_l = character_in_text[i_sample].tolist()[0] + 1
        lengths.append(len_l)
        # Select tokens 1..len_l, skipping [CLS] at position 0.
        idx = torch.arange(1, len_l + 1, device=device)
        x_l = torch.index_select(last_layer[i_sample], 0, idx)
        # Zero-pad each sample out to max_len rows.
        pad = torch.zeros(max_len - len_l, hidden_size, device=device)
        padded.append(torch.cat((x_l, pad), dim=0))
    # torch.stack replaces the original cat-then-index_select dance and,
    # crucially, the selection used `args.batch_size` where `args` is
    # undefined in this module (NameError) -- now the parameter is used.
    x_l_sum = torch.stack(padded, dim=0)
    x_l_len = torch.tensor(lengths, dtype=torch.long, device=device)
    return x_l_sum, x_l_len
21
22
23
class Bert_TD_LSTM(BertForSequenceClassification):
24
def __init__(self, config):
25
super().__init__(config)
26
self.bert.config.output_hidden_states = True
27
28
self.lstm_l = DynamicLSTM(1024, 1024, num_layers=1, batch_first=True)
29
self.lstm_r = DynamicLSTM(1024, 1024, num_layers=1, batch_first=True)
30
31
self.linear = nn.Sequential(
32
nn.Linear(3 * 1024, 1024),
33
nn.Tanh())
34
35
def forward(self,
36
input_ids=None,
37
attention_mask=None,
38
token_type_ids=None,
39
position_ids=None,
40
head_mask=None,
41
inputs_embeds=None,
42
labels=None,
43
output_attentions=None,
44
output_hidden_states=None,
45
return_dict=None,
46
character_in_text=None):
47
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
48
49
outputs = self.bert(input_ids,
50
attention_mask=attention_mask,
51
token_type_ids=token_type_ids,
52
position_ids=position_ids,
53
head_mask=head_mask,
54
inputs_embeds=inputs_embeds,
55
output_attentions=output_attentions,
56
output_hidden_states=output_hidden_states,
57
return_dict=return_dict)
58
59
all_hidden_states = outputs[2]
60
61