Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
HJHGJGHHG
GitHub Repository: HJHGJGHHG/CCF-BDCI-AQYI
Path: blob/main/Model.py
153 views
1
import torch
2
import torch.nn as nn
3
from transformers import BertForSequenceClassification
4
5
6
class BertForMultilabelSequenceClassification(BertForSequenceClassification):
    """BERT for multi-label sequence classification.

    Pools the [CLS] vector from the last three encoder layers, projects the
    concatenation back to ``hidden_size`` through a tanh layer, then reuses
    the inherited dropout + ``classifier`` head. Trained with
    ``BCEWithLogitsLoss`` (one independent sigmoid per label).
    """

    def __init__(self, config):
        super().__init__(config)
        # The pooling below reads per-layer hidden states, so force the
        # encoder to return all of them, not only the final layer.
        self.bert.config.output_hidden_states = True

        # Maps the concatenated [CLS] vectors of the last 3 layers back to
        # hidden_size so the inherited ``self.classifier``
        # (Linear(hidden_size, num_labels)) still fits. Generalized from the
        # hard-coded 3 * 1024 that only worked for bert-large.
        self.linear = nn.Sequential(
            nn.Linear(3 * config.hidden_size, config.hidden_size),
            nn.Tanh())

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        """Run the encoder and return ``(loss, logits)``.

        labels: multi-hot targets broadcastable to (batch, num_labels);
            cast to float for BCE. When None, the returned loss is None.
        Other arguments are forwarded to ``self.bert`` unchanged.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=return_dict)

        # Tuple of (batch, seq_len, hidden) tensors: embeddings + one per
        # layer. Access by attribute when a ModelOutput is returned;
        # positional outputs[2] is only reliable for tuple returns.
        hidden_states = outputs.hidden_states if return_dict else outputs[2]

        # [CLS] pooling over the last three layers. The [:, :1] slice keeps
        # the sequence dim -> (batch, 1, hidden), matching the shape the old
        # index_select produced, and — unlike the previous
        # torch.tensor([0]).to('cuda') — works on CPU and any device.
        concatenate_pooling = torch.cat(
            (hidden_states[-1][:, :1],
             hidden_states[-2][:, :1],
             hidden_states[-3][:, :1]),
            -1)
        logits = self.linear(concatenate_pooling)
        logits = self.dropout(logits)
        logits = self.classifier(logits)

        loss = None
        if labels is not None:
            # One sigmoid per label (multi-label, not softmax/multi-class).
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.float().view(-1, self.num_labels))

        return loss, logits
59
60