Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
yiming-wange
GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/minBERT/sanity_check.py
984 views
1
# Sanity check for the student BertModel implementation: run the pretrained
# 'bert-base-uncased' weights on two fixed token-id sequences and compare the
# outputs against the reference tensors stored in ./sanity_check.data.
import torch
from bert import BertModel

# Reference outputs (a mapping with 'last_hidden_state' and 'pooler_output').
# Load onto CPU because the model below runs on CPU; without map_location a
# file saved from a GPU tensor would fail to load on a CPU-only machine.
sanity_data = torch.load("./sanity_check.data", map_location="cpu")

# Pre-tokenized ids, presumably for
#   text_batch = ["hello world", "hello neural network for NLP"]
# (tokenizer not invoked here). The first row is right-padded with id 0.
sent_ids = torch.tensor([[101, 7592, 2088, 102, 0, 0, 0, 0],
                         [101, 7592, 15756, 2897, 2005, 17953, 2361, 102]])
# 1 = real token, 0 = padding position (aligned with the 0 ids above).
att_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0],
                         [1, 1, 1, 1, 1, 1, 1, 1]])

# Load our implementation with the pretrained weights.
bert = BertModel.from_pretrained('bert-base-uncased')
# Disable dropout so the comparison is deterministic; this is a no-op if
# from_pretrained already returns the model in eval mode.
bert.eval()
# No gradients are needed for a forward-only check.
with torch.no_grad():
    outputs = bert(sent_ids, att_mask)

# Zero the hidden states at padding positions in both tensors so that only
# real-token positions take part in the comparison. The mask gets a trailing
# singleton dim to broadcast over the hidden dimension.
hidden_mask = att_mask.unsqueeze(-1)
outputs['last_hidden_state'] = outputs['last_hidden_state'] * hidden_mask
sanity_data['last_hidden_state'] = sanity_data['last_hidden_state'] * hidden_mask

# Raise explicitly rather than using `assert`, which `python -O` strips and
# which would otherwise make this script report success unconditionally.
for k in ['last_hidden_state', 'pooler_output']:
    if not torch.allclose(outputs[k], sanity_data[k], atol=1e-5, rtol=1e-3):
        raise AssertionError(
            f"BERT output {k!r} does not match the reference in "
            f"./sanity_check.data (atol=1e-5, rtol=1e-3)"
        )
print("Your BERT implementation is correct!")
20
21
22