Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
HJHGJGHHG
GitHub Repository: HJHGJGHHG/CCF-BDCI-AQYI
Path: blob/main/models/Model.py
153 views
1
import torch
2
import torch.nn as nn
3
from transformers import BertForSequenceClassification
4
5
6
class BertForMultilabelSequenceClassification(BertForSequenceClassification):
    """BERT for multi-label sequence classification.

    Extends ``BertForSequenceClassification`` by pooling the [CLS]
    representations of the last two encoder layers together with the
    standard pooler output, projecting the concatenation back to hidden
    size, and training with ``BCEWithLogitsLoss`` so every label is
    scored independently (multi-label, not softmax multi-class).
    """

    def __init__(self, config):
        super().__init__(config)
        # Force the backbone to emit all hidden states so forward() can
        # read the [CLS] vectors of the last two layers.
        self.bert.config.output_hidden_states = True

        # Project the three concatenated pooled vectors back to hidden
        # size.  Sizes come from config instead of the hard-coded 1024 of
        # the original, so bert-base (hidden_size=768) checkpoints work
        # too; behavior is identical for the large (1024) models.
        self.linear = nn.Sequential(
            nn.Linear(3 * config.hidden_size, config.hidden_size),
            nn.Tanh())

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        """Run the backbone, pool, classify; return ``(loss, logits)``.

        ``loss`` is ``None`` when ``labels`` is not given.  ``labels`` is
        expected to be a (batch, num_labels) multi-hot tensor (it is cast
        to float for the BCE loss).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=return_dict)

        # outputs[1] is the pooler output, outputs[2] the per-layer
        # hidden states (available because __init__ forces
        # output_hidden_states=True on the backbone config).
        hidden_states = outputs[2]

        # [CLS] token (position 0) of the last and second-to-last layers.
        # Plain slicing keeps the tensors on whatever device the model is
        # on (the original hard-coded .to('cuda'), which broke CPU runs)
        # and preserves the batch dimension even when batch_size == 1
        # (torch.squeeze without dim= would have dropped it).
        cls_last = hidden_states[-1][:, 0, :]
        cls_second_last = hidden_states[-2][:, 0, :]

        concatenate_pooling = torch.cat(
            (cls_last, cls_second_last, outputs[1]), dim=-1)

        logits = self.linear(concatenate_pooling)
        logits = self.dropout(logits)
        logits = self.classifier(logits)

        loss = None
        if labels is not None:
            # Independent sigmoid per label: multi-label objective.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.float().view(-1, self.num_labels))

        return loss, logits
62
63