Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/longform-qa/lfqa_utils.py
Views: 2535
import functools1import math2import os # noqa: F4013from random import choice, randint4from time import time56import numpy as np7import torch8import torch.utils.checkpoint as checkpoint9from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler10from tqdm import tqdm1112import faiss # noqa: F40113import nlp # noqa: F40114import pandas as pd15from elasticsearch import Elasticsearch # noqa: F40116from elasticsearch.helpers import bulk, streaming_bulk # noqa: F40117from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup181920pd.set_option("display.max_colwidth", None)212223###############24# Sparse index25###############26def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_kilt_snippets_100w"):27index_config = {28"settings": {29"number_of_shards": 1,30"analysis": {"analyzer": {"stop_standard": {"type": "standard", " stopwords": "_english_"}}},31},32"mappings": {33"properties": {34"article_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},35"section_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},36"passage_text": {"type": "text", "analyzer": "standard", "similarity": "BM25"},37}38},39}40es_client.indices.create(index=index_name, body=index_config)41number_of_docs = passages_dset.num_rows42progress = tqdm(unit="docs", total=number_of_docs)43successes = 04445def passage_generator():46for passage in passages_dset:47yield passage4849# create the ES index50for ok, action in streaming_bulk(client=es_client, index=index_name, actions=passage_generator(),):51progress.update(1)52successes += ok53print("Indexed %d documents" % (successes,))545556def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_100w", n_results=10, min_length=20):57q = question.lower()58banned = ["how", "why", "what", "where", "which", "do", "does", "is", "?", "eli5", "eli5:"]59q = " ".join([w for w in q.split() if w not in banned])60response = es_client.search(61index=index_name,62body={63"query": {64"multi_match": {65"query": q,66"fields": ["article_title", "section_title", "passage_text^2"],67"type": "cross_fields",68}69},70"size": 2 * n_results,71},72)73hits = response["hits"]["hits"]74support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])75res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]76for r, hit in zip(res_list, hits):77r["passage_id"] = hit["_id"]78r["score"] = hit["_score"]79r["passage_text"] = hit["_source"]["passage_text"]80res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]81return support_doc, res_list828384###############85# ELI5 retriever training86###############87class ELI5DatasetQARetriver(Dataset):88def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None):89self.data = examples_array90self.answer_thres = extra_answer_threshold91self.min_length = min_answer_length92self.training = training93self.n_samples = self.data.num_rows if n_samples is None else n_samples9495def __len__(self):96return self.n_samples9798def make_example(self, idx):99example = self.data[idx]100question = example["title"]101if self.training:102answers = [a for i, (a, sc) in enumerate(zip(example["answers"]["text"], example["answers"]["score"]))]103answer_tab = choice(answers).split(" ")104start_idx = randint(0, max(0, len(answer_tab) - self.min_length))105answer_span = " ".join(answer_tab[start_idx:])106else:107answer_span = example["answers"]["text"][0]108return (question, answer_span)109110def __getitem__(self, idx):111return self.make_example(idx % self.data.num_rows)112113114class RetrievalQAEmbedder(torch.nn.Module):115def __init__(self, sent_encoder, dim):116super(RetrievalQAEmbedder, self).__init__()117self.sent_encoder = sent_encoder118self.output_dim = 128119self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)120self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)121self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")122123def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):124# reproduces BERT forward pass with checkpointing125if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:126return self.sent_encoder(input_ids, attention_mask=attention_mask)[1]127else:128# prepare implicit variables129device = input_ids.device130input_shape = input_ids.size()131token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)132head_mask = [None] * self.sent_encoder.config.num_hidden_layers133extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(134attention_mask, input_shape, device135)136137# define function for checkpointing138def partial_encode(*inputs):139encoder_outputs = self.sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)140sequence_output = encoder_outputs[0]141pooled_output = self.sent_encoder.pooler(sequence_output)142return pooled_output143144# run embedding layer on everything at once145embedding_output = self.sent_encoder.embeddings(146input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None147)148# run encoding and pooling on one mini-batch at a time149pooled_output_list = []150for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):151b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]152b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]153pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)154pooled_output_list.append(pooled_output)155return torch.cat(pooled_output_list, dim=0)156157def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1):158q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size)159return self.project_q(q_reps)160161def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1):162a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size)163return self.project_a(a_reps)164165def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1):166device = q_ids.device167q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size)168a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size)169compare_scores = torch.mm(q_reps, a_reps.t())170loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))171loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))172loss = (loss_qa + loss_aq) / 2173return loss174175176def make_qa_retriever_model(model_name="google/bert_uncased_L-8_H-512_A-8", from_file=None, device="cuda:0"):177tokenizer = AutoTokenizer.from_pretrained(model_name)178bert_model = AutoModel.from_pretrained(model_name).to(device)179# run bert_model on a dummy batch to get output dimension180d_ids = torch.LongTensor(181[[bert_model.config.bos_token_id if bert_model.config.bos_token_id is not None else 1]]182).to(device)183d_mask = torch.LongTensor([[1]]).to(device)184sent_dim = bert_model(d_ids, attention_mask=d_mask)[1].shape[-1]185qa_embedder = RetrievalQAEmbedder(bert_model, sent_dim).to(device)186if from_file is not None:187param_dict = torch.load(from_file) # has model weights, optimizer, and scheduler states188qa_embedder.load_state_dict(param_dict["model"])189return tokenizer, qa_embedder190191192def make_qa_retriever_batch(qa_list, tokenizer, max_len=64, device="cuda:0"):193q_ls = [q for q, a in qa_list]194a_ls = [a for q, a in qa_list]195q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)196q_ids, q_mask = (197torch.LongTensor(q_toks["input_ids"]).to(device),198torch.LongTensor(q_toks["attention_mask"]).to(device),199)200a_toks = tokenizer.batch_encode_plus(a_ls, max_length=max_len, pad_to_max_length=True)201a_ids, a_mask = (202torch.LongTensor(a_toks["input_ids"]).to(device),203torch.LongTensor(a_toks["attention_mask"]).to(device),204)205return (q_ids, q_mask, a_ids, a_mask)206207208def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0):209model.train()210# make iterator211train_sampler = RandomSampler(dataset)212model_collate_fn = functools.partial(213make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"214)215data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)216epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)217# accumulate loss since last print218loc_steps = 0219loc_loss = 0.0220st_time = time()221for step, batch in enumerate(epoch_iterator):222q_ids, q_mask, a_ids, a_mask = batch223pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)224loss = pre_loss.sum()225# optimizer226loss.backward()227optimizer.step()228scheduler.step()229model.zero_grad()230# some printing within the epoch231loc_loss += loss.item()232loc_steps += 1233if step % args.print_freq == 0 or step == 1:234print(235"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(236e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,237)238)239loc_loss = 0240loc_steps = 0241242243def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, scheduler, args, e=0):244model.train()245model_collate_fn = functools.partial(246make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"247)248# make iterator249train_samplers = [RandomSampler(dataset) for dataset in dataset_list]250data_loaders = [251DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)252for dataset, train_sampler in zip(dataset_list, train_samplers)253]254iterators = [iter(dloader) for dloader in data_loaders]255joint_iter = zip(*iterators)256# accumulate loss since last print257loc_steps = 0258loc_loss = 0.0259st_time = time()260for step, (batches,) in enumerate(zip(joint_iter)):261for batch in batches:262q_ids, q_mask, a_ids, a_mask = batch263loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)264# optimizer265loss.backward()266optimizer.step()267scheduler.step()268model.zero_grad()269# some printing within the epoch270loc_loss += loss.item()271loc_steps += 1272if step % args.print_freq == 0:273print(274"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(275e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time,276)277)278loc_loss = 0279loc_steps = 0280281282def evaluate_qa_retriever(model, dataset, tokenizer, args):283model.eval()284# make iterator285eval_sampler = SequentialSampler(dataset)286model_collate_fn = functools.partial(287make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"288)289data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn)290epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)291tot_loss = 0.0292with torch.no_grad():293for step, batch in enumerate(epoch_iterator):294q_ids, q_mask, a_ids, a_mask = batch295loss = model(q_ids, q_mask, a_ids, a_mask)296tot_loss += loss.item()297return tot_loss / (step + 1)298299300def train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args):301qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)302qar_scheduler = get_linear_schedule_with_warmup(303qar_optimizer,304num_warmup_steps=100,305num_training_steps=(qar_args.num_epochs + 1) * math.ceil(len(qar_train_dset) / qar_args.batch_size),306)307for e in range(qar_args.num_epochs):308train_qa_retriever_epoch(qar_model, qar_train_dset, qar_tokenizer, qar_optimizer, qar_scheduler, qar_args, e)309m_save_dict = {310"model": qar_model.state_dict(),311"optimizer": qar_optimizer.state_dict(),312"scheduler": qar_scheduler.state_dict(),313}314print("Saving model {}".format(qar_args.model_save_name))315torch.save(m_save_dict, "{}_{}.pth".format(qar_args.model_save_name, e))316eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)317print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))318319320###############321# ELI5 seq2seq model training322###############323class ELI5DatasetS2S(Dataset):324def __init__(325self, examples_array, make_doc_fun=None, extra_answer_threshold=3, document_cache=None, training=True326):327self.training = training328self.data = examples_array329self.make_doc_function = make_doc_fun330self.document_cache = {} if document_cache is None else document_cache331assert not (make_doc_fun is None and document_cache is None)332# make index of specific question-answer pairs from multi-answers333if self.training:334self.qa_id_list = [335(i, j)336for i, qa in enumerate(self.data)337for j, (a, sc) in enumerate(zip(qa["answers"]["text"], qa["answers"]["score"]))338if j == 0 or sc >= extra_answer_threshold339]340else:341self.qa_id_list = [(i, 0) for i in range(self.data.num_rows)]342343def __len__(self):344return len(self.qa_id_list)345346def make_example(self, idx):347i, j = self.qa_id_list[idx]348example = self.data[i]349question = example["title"] + " " + example["selftext"]350answer = example["answers"]["text"][j]351q_id = example["q_id"]352if self.make_doc_function is not None:353self.document_cache[q_id] = self.document_cache.get(q_id, self.make_doc_function(example["title"]))354document = self.document_cache[q_id]355in_st = "question: {} context: {}".format(356question.lower().replace(" --t--", "").strip(), document.lower().strip(),357)358out_st = answer359return (in_st, out_st)360361def __getitem__(self, idx):362return self.make_example(idx)363364365def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):366tokenizer = AutoTokenizer.from_pretrained(model_name)367model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)368if from_file is not None:369param_dict = torch.load(from_file) # has model weights, optimizer, and scheduler states370model.load_state_dict(param_dict["model"])371return tokenizer, model372373374def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):375q_ls = [q for q, a in qa_list]376a_ls = [a for q, a in qa_list]377q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)378q_ids, q_mask = (379torch.LongTensor(q_toks["input_ids"]).to(device),380torch.LongTensor(q_toks["attention_mask"]).to(device),381)382a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)383a_ids, a_mask = (384torch.LongTensor(a_toks["input_ids"]).to(device),385torch.LongTensor(a_toks["attention_mask"]).to(device),386)387lm_labels = a_ids[:, 1:].contiguous().clone()388lm_labels[a_mask[:, 1:].contiguous() == 0] = -100389model_inputs = {390"input_ids": q_ids,391"attention_mask": q_mask,392"decoder_input_ids": a_ids[:, :-1].contiguous(),393"lm_labels": lm_labels,394}395return model_inputs396397398def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):399model.train()400# make iterator401if curriculum:402train_sampler = SequentialSampler(dataset)403else:404train_sampler = RandomSampler(dataset)405model_collate_fn = functools.partial(406make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"407)408data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)409epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)410# accumulate loss since last print411loc_steps = 0412loc_loss = 0.0413st_time = time()414for step, batch_inputs in enumerate(epoch_iterator):415pre_loss = model(**batch_inputs)[0]416loss = pre_loss.sum() / pre_loss.shape[0]417loss.backward()418# optimizer419if step % args.backward_freq == 0:420optimizer.step()421scheduler.step()422model.zero_grad()423# some printing within the epoch424loc_loss += loss.item()425loc_steps += 1426if step % args.print_freq == 0 or step == 1:427print(428"{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(429e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,430)431)432loc_loss = 0433loc_steps = 0434435436def eval_qa_s2s_epoch(model, dataset, tokenizer, args):437model.eval()438# make iterator439train_sampler = SequentialSampler(dataset)440model_collate_fn = functools.partial(441make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"442)443data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)444epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)445# accumulate loss since last print446loc_steps = 0447loc_loss = 0.0448st_time = time()449with torch.no_grad():450for step, batch_inputs in enumerate(epoch_iterator):451pre_loss = model(**batch_inputs)[0]452loss = pre_loss.sum() / pre_loss.shape[0]453loc_loss += loss.item()454loc_steps += 1455if step % args.print_freq == 0:456print(457"{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(458step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,459)460)461print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time,))462463464def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):465s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)466s2s_scheduler = get_linear_schedule_with_warmup(467s2s_optimizer,468num_warmup_steps=400,469num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),470)471for e in range(s2s_args.num_epochs):472train_qa_s2s_epoch(473qa_s2s_model,474s2s_train_dset,475qa_s2s_tokenizer,476s2s_optimizer,477s2s_scheduler,478s2s_args,479e,480curriculum=(e == 0),481)482m_save_dict = {483"model": qa_s2s_model.state_dict(),484"optimizer": s2s_optimizer.state_dict(),485"scheduler": s2s_scheduler.state_dict(),486}487print("Saving model {}".format(s2s_args.model_save_name))488eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)489torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e))490491492# generate answer from input "question: ... context: <p> ..."493def qa_s2s_generate(494question_doc,495qa_s2s_model,496qa_s2s_tokenizer,497num_answers=1,498num_beams=None,499min_len=64,500max_len=256,501do_sample=False,502temp=1.0,503top_p=None,504top_k=None,505max_input_length=512,506device="cuda:0",507):508model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device,)509n_beams = num_answers if num_beams is None else max(num_beams, num_answers)510generated_ids = qa_s2s_model.generate(511input_ids=model_inputs["input_ids"],512attention_mask=model_inputs["attention_mask"],513min_length=min_len,514max_length=max_len,515do_sample=do_sample,516early_stopping=True,517num_beams=1 if do_sample else n_beams,518temperature=temp,519top_k=top_k,520top_p=top_p,521eos_token_id=qa_s2s_tokenizer.eos_token_id,522no_repeat_ngram_size=3,523num_return_sequences=num_answers,524decoder_start_token_id=qa_s2s_tokenizer.bos_token_id,525)526return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]527528529###############530# ELI5-trained retrieval model usage531###############532def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=128, device="cuda:0"):533a_toks = tokenizer.batch_encode_plus(passages, max_length=max_length, pad_to_max_length=True)534a_ids, a_mask = (535torch.LongTensor(a_toks["input_ids"]).to(device),536torch.LongTensor(a_toks["attention_mask"]).to(device),537)538with torch.no_grad():539a_reps = qa_embedder.embed_answers(a_ids, a_mask).cpu().type(torch.float)540return a_reps.numpy()541542543def embed_questions_for_retrieval(q_ls, tokenizer, qa_embedder, device="cuda:0"):544q_toks = tokenizer.batch_encode_plus(q_ls, max_length=128, pad_to_max_length=True)545q_ids, q_mask = (546torch.LongTensor(q_toks["input_ids"]).to(device),547torch.LongTensor(q_toks["attention_mask"]).to(device),548)549with torch.no_grad():550q_reps = qa_embedder.embed_questions(q_ids, q_mask).cpu().type(torch.float)551return q_reps.numpy()552553554def make_qa_dense_index(555qa_embedder,556tokenizer,557passages_dset,558batch_size=512,559max_length=128,560index_name="kilt_passages_reps.dat",561dtype="float32",562device="cuda:0",563):564st_time = time()565fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))566n_batches = math.ceil(passages_dset.num_rows / batch_size)567for i in range(n_batches):568passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]569reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)570fp[i * batch_size : (i + 1) * batch_size] = reps571if i % 50 == 0:572print(i, time() - st_time)573574575def evaluate_retriever(qa_list, retriever_func, scoring_func, n_ret=10, verbose=False):576total_retriever_time = 0.0577total_retriever_score = 0.0578st_time = time()579for i, (question, answer) in enumerate(qa_list):580r_time = time()581retrieved_passages = retriever_func(question, n_ret)582total_retriever_time += time() - r_time583total_retriever_score += scoring_func(retrieved_passages, answer)584if verbose and ((i + 1) % 500 == 0 or i <= 1):585print(586"{:03d}: S-{:.4f} T-{:.4f} | {:.2f}".format(587i + 1, total_retriever_score / (i + 1), total_retriever_time / (i + 1), time() - st_time588)589)590return {"idf_recall": total_retriever_score / (i + 1), "retrieval_time": total_retriever_time / (i + 1)}591592593# build a support document for the question out of Wikipedia snippets594def query_qa_dense_index(595question, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20, device="cuda:0"596):597q_rep = embed_questions_for_retrieval([question], tokenizer, qa_embedder, device=device)598D, I = wiki_index.search(q_rep, 2 * n_results)599res_passages = [wiki_passages[int(i)] for i in I[0]]600support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])601res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]602res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]603for r, sc in zip(res_list, D[0]):604r["score"] = float(sc)605return support_doc, res_list606607608def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):609q_rep = embed_questions_for_retrieval(questions, tokenizer, qa_embedder)610D, I = wiki_index.search(q_rep, n_results)611res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]612support_doc_lst = [613"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst614]615all_res_lists = []616for (res_passages, dl) in zip(res_passages_lst, D):617res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]618for r, sc in zip(res_list, dl):619r["score"] = float(sc)620all_res_lists += [res_list[:]]621return support_doc_lst, all_res_lists622623624# find nearest neighbors of an answer or declarative text in Wikipedia snippets625def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20):626a_rep = embed_passages_for_retrieval([passage], tokenizer, qa_embedder)627D, I = wiki_index.search(a_rep, 2 * n_results)628res_passages = [wiki_passages[int(i)] for i in I[0]]629support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])630res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]631res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]632for r, sc, i in zip(res_list, D[0], I[0]):633r["passage_id"] = int(i)634r["score"] = float(sc)635return support_doc, res_list636637638def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):639a_reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder)640D, I = wiki_index.search(a_reps, n_results)641res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]642support_doc_lst = [643"<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst644]645all_res_lists = []646for (res_passages, dl, il) in zip(res_passages_lst, D, I):647res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]648for r, sc, i in zip(res_list, dl, il):649r["passage_id"] = int(i)650r["score"] = float(sc)651all_res_lists += [res_list[:]]652return support_doc_lst, all_res_lists653654655