# GitHub repository: huggingface/notebooks
# Path: longform-qa/lfqa_utils.py
import functools
import math
import os  # noqa: F401
from random import choice, randint
from time import time

import numpy as np
import torch
import torch.utils.checkpoint as checkpoint
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm

import faiss  # noqa: F401
import nlp  # noqa: F401
import pandas as pd
from elasticsearch import Elasticsearch  # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk  # noqa: F401
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup


pd.set_option("display.max_colwidth", None)


###############
# Sparse index
###############
def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_kilt_snippets_100w"):
    index_config = {
        "settings": {
            "number_of_shards": 1,
            "analysis": {"analyzer": {"stop_standard": {"type": "standard", "stopwords": "_english_"}}},
        },
        "mappings": {
            "properties": {
                "article_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
                "section_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
                "passage_text": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
            }
        },
    }
    es_client.indices.create(index=index_name, body=index_config)
    number_of_docs = passages_dset.num_rows
    progress = tqdm(unit="docs", total=number_of_docs)
    successes = 0

    def passage_generator():
        for passage in passages_dset:
            yield passage

    # stream the passages into the ES index
    for ok, action in streaming_bulk(client=es_client, index=index_name, actions=passage_generator()):
        progress.update(1)
        successes += ok
    print("Indexed %d documents" % (successes,))


def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_100w", n_results=10, min_length=20):
    q = question.lower()
    banned = ["how", "why", "what", "where", "which", "do", "does", "is", "?", "eli5", "eli5:"]
    q = " ".join([w for w in q.split() if w not in banned])
    response = es_client.search(
        index=index_name,
        body={
            "query": {
                "multi_match": {
                    "query": q,
                    "fields": ["article_title", "section_title", "passage_text^2"],
                    "type": "cross_fields",
                }
            },
            "size": 2 * n_results,
        },
    )
    hits = response["hits"]["hits"]
    support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
    res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
    for r, hit in zip(res_list, hits):
        r["passage_id"] = hit["_id"]
        r["score"] = hit["_score"]
        r["passage_text"] = hit["_source"]["passage_text"]
    res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
    return support_doc, res_list
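

# Example usage of the sparse-index helpers above (a sketch, not part of the
# original module: it assumes a local Elasticsearch server and a `wiki_snippets`
# dataset with "article_title", "section_title", and "passage_text" columns):
#
#   es_client = Elasticsearch([{"host": "localhost", "port": "9200"}])
#   make_es_index_snippets(es_client, wiki_snippets, index_name="english_wiki_kilt_snippets_100w")
#   support_doc, hits = query_es_index("why is the sky blue?", es_client)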


###############
# ELI5 retriever training
###############
class ELI5DatasetQARetriever(Dataset):
    def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None):
        self.data = examples_array
        self.answer_thres = extra_answer_threshold
        self.min_length = min_answer_length
        self.training = training
        self.n_samples = self.data.num_rows if n_samples is None else n_samples

    def __len__(self):
        return self.n_samples

    def make_example(self, idx):
        example = self.data[idx]
        question = example["title"]
        if self.training:
            # keep the first answer plus any extra answers above the score threshold
            answers = [
                a
                for i, (a, sc) in enumerate(zip(example["answers"]["text"], example["answers"]["score"]))
                if i == 0 or sc >= self.answer_thres
            ]
            answer_tab = choice(answers).split(" ")
            start_idx = randint(0, max(0, len(answer_tab) - self.min_length))
            answer_span = " ".join(answer_tab[start_idx:])
        else:
            answer_span = example["answers"]["text"][0]
        return (question, answer_span)

    def __getitem__(self, idx):
        return self.make_example(idx % self.data.num_rows)


class RetrievalQAEmbedder(torch.nn.Module):
    def __init__(self, sent_encoder, dim):
        super(RetrievalQAEmbedder, self).__init__()
        self.sent_encoder = sent_encoder
        self.output_dim = 128
        self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
        self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
        self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")

    def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
        # reproduces BERT forward pass with checkpointing
        if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
            return self.sent_encoder(input_ids, attention_mask=attention_mask)[1]
        else:
            # prepare implicit variables
            device = input_ids.device
            input_shape = input_ids.size()
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            head_mask = [None] * self.sent_encoder.config.num_hidden_layers
            extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
                attention_mask, input_shape, device
            )

            # define function for checkpointing
            def partial_encode(*inputs):
                encoder_outputs = self.sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask)
                sequence_output = encoder_outputs[0]
                pooled_output = self.sent_encoder.pooler(sequence_output)
                return pooled_output

            # run embedding layer on everything at once
            embedding_output = self.sent_encoder.embeddings(
                input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
            )
            # run encoding and pooling on one mini-batch at a time
            pooled_output_list = []
            for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
                b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
                pooled_output_list.append(pooled_output)
            return torch.cat(pooled_output_list, dim=0)

    def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1):
        q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size)
        return self.project_q(q_reps)

    def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1):
        a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size)
        return self.project_a(a_reps)

    def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1):
        device = q_ids.device
        q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size)
        a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size)
        # in-batch negatives: each question is compared to every answer in the
        # batch, and the matching (question, answer) pairs lie on the diagonal
        compare_scores = torch.mm(q_reps, a_reps.t())
        loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
        loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
        loss = (loss_qa + loss_aq) / 2
        return loss


def make_qa_retriever_model(model_name="google/bert_uncased_L-8_H-512_A-8", from_file=None, device="cuda:0"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bert_model = AutoModel.from_pretrained(model_name).to(device)
    # run bert_model on a dummy batch to get output dimension
    d_ids = torch.LongTensor(
        [[bert_model.config.bos_token_id if bert_model.config.bos_token_id is not None else 1]]
    ).to(device)
    d_mask = torch.LongTensor([[1]]).to(device)
    sent_dim = bert_model(d_ids, attention_mask=d_mask)[1].shape[-1]
    qa_embedder = RetrievalQAEmbedder(bert_model, sent_dim).to(device)
    if from_file is not None:
        param_dict = torch.load(from_file)  # has model weights, optimizer, and scheduler states
        qa_embedder.load_state_dict(param_dict["model"])
    return tokenizer, qa_embedder
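

# For instance (a sketch; the checkpoint name is the function's default, and
# `qar_tokenizer` / `qar_model` are illustrative names):
#
#   qar_tokenizer, qar_model = make_qa_retriever_model(
#       model_name="google/bert_uncased_L-8_H-512_A-8", device="cuda:0"
#   )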


def make_qa_retriever_batch(qa_list, tokenizer, max_len=64, device="cuda:0"):
    q_ls = [q for q, a in qa_list]
    a_ls = [a for q, a in qa_list]
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    a_toks = tokenizer.batch_encode_plus(a_ls, max_length=max_len, pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    return (q_ids, q_mask, a_ids, a_mask)


def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0):
    model.train()
    # make iterator
    train_sampler = RandomSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch in enumerate(epoch_iterator):
        q_ids, q_mask, a_ids, a_mask = batch
        pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
        loss = pre_loss.sum()
        # optimizer
        loss.backward()
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0


def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, scheduler, args, e=0):
    model.train()
    model_collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    # make iterator that yields one batch per dataset at each step
    train_samplers = [RandomSampler(dataset) for dataset in dataset_list]
    data_loaders = [
        DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
        for dataset, train_sampler in zip(dataset_list, train_samplers)
    ]
    iterators = [iter(dloader) for dloader in data_loaders]
    joint_iter = zip(*iterators)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batches in enumerate(joint_iter):
        for batch in batches:
            q_ids, q_mask, a_ids, a_mask = batch
            loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
            # optimizer
            loss.backward()
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            # some printing within the epoch
            loc_loss += loss.item()
            loc_steps += 1
        if step % args.print_freq == 0:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0


def evaluate_qa_retriever(model, dataset, tokenizer, args):
    model.eval()
    # make iterator
    eval_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    tot_loss = 0.0
    with torch.no_grad():
        for step, batch in enumerate(epoch_iterator):
            q_ids, q_mask, a_ids, a_mask = batch
            loss = model(q_ids, q_mask, a_ids, a_mask)
            tot_loss += loss.item()
        return tot_loss / (step + 1)


def train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args):
    qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)
    qar_scheduler = get_linear_schedule_with_warmup(
        qar_optimizer,
        num_warmup_steps=100,
        num_training_steps=(qar_args.num_epochs + 1) * math.ceil(len(qar_train_dset) / qar_args.batch_size),
    )
    for e in range(qar_args.num_epochs):
        train_qa_retriever_epoch(qar_model, qar_train_dset, qar_tokenizer, qar_optimizer, qar_scheduler, qar_args, e)
        m_save_dict = {
            "model": qar_model.state_dict(),
            "optimizer": qar_optimizer.state_dict(),
            "scheduler": qar_scheduler.state_dict(),
        }
        print("Saving model {}".format(qar_args.model_save_name))
        torch.save(m_save_dict, "{}_{}.pth".format(qar_args.model_save_name, e))
        eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
        print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))


###############
# ELI5 seq2seq model training
###############
class ELI5DatasetS2S(Dataset):
    def __init__(
        self, examples_array, make_doc_fun=None, extra_answer_threshold=3, document_cache=None, training=True
    ):
        self.training = training
        self.data = examples_array
        self.make_doc_function = make_doc_fun
        self.document_cache = {} if document_cache is None else document_cache
        assert not (make_doc_fun is None and document_cache is None)
        # make index of specific question-answer pairs from multi-answers
        if self.training:
            self.qa_id_list = [
                (i, j)
                for i, qa in enumerate(self.data)
                for j, (a, sc) in enumerate(zip(qa["answers"]["text"], qa["answers"]["score"]))
                if j == 0 or sc >= extra_answer_threshold
            ]
        else:
            self.qa_id_list = [(i, 0) for i in range(self.data.num_rows)]

    def __len__(self):
        return len(self.qa_id_list)

    def make_example(self, idx):
        i, j = self.qa_id_list[idx]
        example = self.data[i]
        question = example["title"] + " " + example["selftext"]
        answer = example["answers"]["text"][j]
        q_id = example["q_id"]
        if self.make_doc_function is not None:
            # only compute the support document on a cache miss
            if q_id not in self.document_cache:
                self.document_cache[q_id] = self.make_doc_function(example["title"])
        document = self.document_cache[q_id]
        in_st = "question: {} context: {}".format(
            question.lower().replace(" --t--", "").strip(), document.lower().strip(),
        )
        out_st = answer
        return (in_st, out_st)

    def __getitem__(self, idx):
        return self.make_example(idx)


def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    if from_file is not None:
        param_dict = torch.load(from_file)  # has model weights, optimizer, and scheduler states
        model.load_state_dict(param_dict["model"])
    return tokenizer, model


def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
    q_ls = [q for q, a in qa_list]
    a_ls = [a for q, a in qa_list]
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    # teacher forcing: the decoder reads the answer shifted right by one token
    # and predicts it shifted left by one; padding positions are masked out of
    # the loss with the ignore index -100
    lm_labels = a_ids[:, 1:].contiguous().clone()
    lm_labels[a_mask[:, 1:].contiguous() == 0] = -100
    model_inputs = {
        "input_ids": q_ids,
        "attention_mask": q_mask,
        "decoder_input_ids": a_ids[:, :-1].contiguous(),
        "lm_labels": lm_labels,  # argument name used by older transformers versions (later renamed "labels")
    }
    return model_inputs
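

# The collate function can also be called on its own, e.g. to inspect a batch
# (a sketch; the strings and `s2s_tokenizer` are illustrative):
#
#   batch_inputs = make_qa_s2s_batch(
#       [("question: why is the sky blue? context: <P> ...", "because ...")],
#       s2s_tokenizer,
#       max_len=512,
#   )
#   # -> dict with "input_ids", "attention_mask", "decoder_input_ids", "lm_labels"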


def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
    model.train()
    # make iterator
    if curriculum:
        train_sampler = SequentialSampler(dataset)
    else:
        train_sampler = RandomSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch_inputs in enumerate(epoch_iterator):
        pre_loss = model(**batch_inputs)[0]
        loss = pre_loss.sum() / pre_loss.shape[0]
        loss.backward()
        # optimizer: step only every backward_freq batches (gradient accumulation)
        if step % args.backward_freq == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0


def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
    model.eval()
    # make iterator
    eval_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    with torch.no_grad():
        for step, batch_inputs in enumerate(epoch_iterator):
            pre_loss = model(**batch_inputs)[0]
            loss = pre_loss.sum() / pre_loss.shape[0]
            loc_loss += loss.item()
            loc_steps += 1
            if step % args.print_freq == 0:
                print(
                    "{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                        step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                    )
                )
    print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time))


def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
    s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
    s2s_scheduler = get_linear_schedule_with_warmup(
        s2s_optimizer,
        num_warmup_steps=400,
        num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),
    )
    for e in range(s2s_args.num_epochs):
        train_qa_s2s_epoch(
            qa_s2s_model,
            s2s_train_dset,
            qa_s2s_tokenizer,
            s2s_optimizer,
            s2s_scheduler,
            s2s_args,
            e,
            curriculum=(e == 0),
        )
        m_save_dict = {
            "model": qa_s2s_model.state_dict(),
            "optimizer": s2s_optimizer.state_dict(),
            "scheduler": s2s_scheduler.state_dict(),
        }
        print("Saving model {}".format(s2s_args.model_save_name))
        eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
        torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e))


# generate answer from input "question: ... context: <P> ..."
def qa_s2s_generate(
    question_doc,
    qa_s2s_model,
    qa_s2s_tokenizer,
    num_answers=1,
    num_beams=None,
    min_len=64,
    max_len=256,
    do_sample=False,
    temp=1.0,
    top_p=None,
    top_k=None,
    max_input_length=512,
    device="cuda:0",
):
    model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device)
    n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
    generated_ids = qa_s2s_model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        min_length=min_len,
        max_length=max_len,
        do_sample=do_sample,
        early_stopping=True,
        num_beams=1 if do_sample else n_beams,
        temperature=temp,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=qa_s2s_tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        num_return_sequences=num_answers,
        decoder_start_token_id=qa_s2s_tokenizer.bos_token_id,
    )
    return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]
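

# For example (a sketch; `question_doc` concatenates a question with a retrieved
# support document, as produced by the query_*_index functions below):
#
#   question_doc = "question: why is the sky blue? context: <P> ..."
#   answers = qa_s2s_generate(
#       question_doc, qa_s2s_model, qa_s2s_tokenizer,
#       num_answers=1, num_beams=8, min_len=64, max_len=256,
#   )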


###############
# ELI5-trained retrieval model usage
###############
def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=128, device="cuda:0"):
    a_toks = tokenizer.batch_encode_plus(passages, max_length=max_length, pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    with torch.no_grad():
        a_reps = qa_embedder.embed_answers(a_ids, a_mask).cpu().type(torch.float)
    return a_reps.numpy()


def embed_questions_for_retrieval(q_ls, tokenizer, qa_embedder, device="cuda:0"):
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=128, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    with torch.no_grad():
        q_reps = qa_embedder.embed_questions(q_ids, q_mask).cpu().type(torch.float)
    return q_reps.numpy()


def make_qa_dense_index(
    qa_embedder,
    tokenizer,
    passages_dset,
    batch_size=512,
    max_length=128,
    index_name="kilt_passages_reps.dat",
    dtype="float32",
    device="cuda:0",
):
    st_time = time()
    fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
    n_batches = math.ceil(passages_dset.num_rows / batch_size)
    for i in range(n_batches):
        passages = passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]
        reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
        fp[i * batch_size : (i + 1) * batch_size] = reps
        if i % 50 == 0:
            print(i, time() - st_time)
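

# Once the embeddings are on disk, a dense index can be built and queried with
# faiss (a sketch; the file and variable names are illustrative, and 128 is the
# output dimension of RetrievalQAEmbedder):
#
#   passage_reps = np.memmap(
#       "kilt_passages_reps.dat", dtype="float32", mode="r",
#       shape=(wiki_passages.num_rows, 128),
#   )
#   wiki_index = faiss.IndexFlatIP(128)  # exact inner-product search
#   wiki_index.add(passage_reps)
#   support_doc, res_list = query_qa_dense_index(
#       "why is the sky blue?", qa_embedder, tokenizer, wiki_passages, wiki_index
#   )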


def evaluate_retriever(qa_list, retriever_func, scoring_func, n_ret=10, verbose=False):
    total_retriever_time = 0.0
    total_retriever_score = 0.0
    st_time = time()
    for i, (question, answer) in enumerate(qa_list):
        r_time = time()
        retrieved_passages = retriever_func(question, n_ret)
        total_retriever_time += time() - r_time
        total_retriever_score += scoring_func(retrieved_passages, answer)
        if verbose and ((i + 1) % 500 == 0 or i <= 1):
            print(
                "{:03d}: S-{:.4f} T-{:.4f} | {:.2f}".format(
                    i + 1, total_retriever_score / (i + 1), total_retriever_time / (i + 1), time() - st_time
                )
            )
    return {"idf_recall": total_retriever_score / (i + 1), "retrieval_time": total_retriever_time / (i + 1)}


# build a support document for the question out of Wikipedia snippets
def query_qa_dense_index(
    question, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20, device="cuda:0"
):
    q_rep = embed_questions_for_retrieval([question], tokenizer, qa_embedder, device=device)
    # D holds the inner-product scores, I the indices of the nearest passages
    D, I = wiki_index.search(q_rep, 2 * n_results)
    res_passages = [wiki_passages[int(i)] for i in I[0]]
    support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
    # attach scores before filtering so each passage keeps its own score
    for r, sc in zip(res_list, D[0]):
        r["score"] = float(sc)
    res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
    return support_doc, res_list


def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
    q_rep = embed_questions_for_retrieval(questions, tokenizer, qa_embedder)
    D, I = wiki_index.search(q_rep, n_results)
    res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
    support_doc_lst = [
        "<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
    ]
    all_res_lists = []
    for res_passages, dl in zip(res_passages_lst, D):
        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
        for r, sc in zip(res_list, dl):
            r["score"] = float(sc)
        all_res_lists += [res_list[:]]
    return support_doc_lst, all_res_lists


# find nearest neighbors of an answer or declarative text in Wikipedia snippets
def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20):
    a_rep = embed_passages_for_retrieval([passage], tokenizer, qa_embedder)
    D, I = wiki_index.search(a_rep, 2 * n_results)
    res_passages = [wiki_passages[int(i)] for i in I[0]]
    support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
    # attach ids and scores before filtering so each passage keeps its own values
    for r, sc, i in zip(res_list, D[0], I[0]):
        r["passage_id"] = int(i)
        r["score"] = float(sc)
    res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
    return support_doc, res_list


def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
    a_reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder)
    D, I = wiki_index.search(a_reps, n_results)
    res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
    support_doc_lst = [
        "<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
    ]
    all_res_lists = []
    for res_passages, dl, il in zip(res_passages_lst, D, I):
        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
        for r, sc, i in zip(res_list, dl, il):
            r["passage_id"] = int(i)
            r["score"] = float(sc)
        all_res_lists += [res_list[:]]
    return support_doc_lst, all_res_lists