Path: a3/utils/parser_utils.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2021-2022: Homework 3
parser_utils.py: Utilities for training the dependency parser.
Sahil Chopra <[email protected]>
"""

import time
import os
import logging
from collections import Counter
from .general_utils import get_minibatches
from parser_transitions import minibatch_parse

from tqdm import tqdm
import torch
import numpy as np

P_PREFIX = '<p>:'
L_PREFIX = '<l>:'
UNK = '<UNK>'
NULL = '<NULL>'
ROOT = '<ROOT>'


class Config(object):
    language = 'english'
    with_punct = True
    unlabeled = True
    lowercase = True
    use_pos = True
    use_dep = True
    use_dep = use_dep and (not unlabeled)
    data_path = './data'
    train_file = 'train.conll'
    dev_file = 'dev.conll'
    test_file = 'test.conll'
    embedding_file = './data/en-cw.txt'


class Parser(object):
    """Contains everything needed for transition-based dependency parsing except for the model"""

    def __init__(self, dataset):
        root_labels = list([l for ex in dataset
                            for (h, l) in zip(ex['head'], ex['label']) if h == 0])
        counter = Counter(root_labels)
        if len(counter) > 1:
            logging.info('Warning: more than one root label')
            logging.info(counter)
        self.root_label = counter.most_common()[0][0]
        deprel = [self.root_label] + list(set([w for ex in dataset
                                               for w in ex['label']
                                               if w != self.root_label]))
        tok2id = {L_PREFIX + l: i for (i, l) in enumerate(deprel)}
        tok2id[L_PREFIX + NULL] = self.L_NULL = len(tok2id)

        config = Config()
        self.unlabeled = config.unlabeled
        self.with_punct = config.with_punct
        self.use_pos = config.use_pos
        self.use_dep = config.use_dep
        self.language = config.language

        if self.unlabeled:
            trans = ['L', 'R', 'S']
            self.n_deprel = 1
        else:
            trans = ['L-' + l for l in deprel] + ['R-' + l for l in deprel] + ['S']
            self.n_deprel = len(deprel)

        self.n_trans = len(trans)
        self.tran2id = {t: i for (i, t) in enumerate(trans)}
        self.id2tran = {i: t for (i, t) in enumerate(trans)}

        # logging.info('Build dictionary for part-of-speech tags.')
        tok2id.update(build_dict([P_PREFIX + w for ex in dataset for w in ex['pos']],
                                 offset=len(tok2id)))
        tok2id[P_PREFIX + UNK] = self.P_UNK = len(tok2id)
        tok2id[P_PREFIX + NULL] = self.P_NULL = len(tok2id)
        tok2id[P_PREFIX + ROOT] = self.P_ROOT = len(tok2id)

        # logging.info('Build dictionary for words.')
        tok2id.update(build_dict([w for ex in dataset for w in ex['word']],
                                 offset=len(tok2id)))
        tok2id[UNK] = self.UNK = len(tok2id)
        tok2id[NULL] = self.NULL = len(tok2id)
        tok2id[ROOT] = self.ROOT = len(tok2id)

        self.tok2id = tok2id
        self.id2tok = {v: k for (k, v) in tok2id.items()}

        self.n_features = 18 + (18 if config.use_pos else 0) + (12 if config.use_dep else 0)
        self.n_tokens = len(tok2id)

    def vectorize(self, examples):
        vec_examples = []
        for ex in examples:
            word = [self.ROOT] + [self.tok2id[w] if w in self.tok2id
                                  else self.UNK for w in ex['word']]
            pos = [self.P_ROOT] + [self.tok2id[P_PREFIX + w] if P_PREFIX + w in self.tok2id
                                   else self.P_UNK for w in ex['pos']]
            head = [-1] + ex['head']
            label = [-1] + [self.tok2id[L_PREFIX + w] if L_PREFIX + w in self.tok2id
                            else -1 for w in ex['label']]
            vec_examples.append({'word': word, 'pos': pos,
                                 'head': head, 'label': label})
        return vec_examples

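    # Feature layout (for reference): with the default unlabeled Config,
    # extract_features below produces 36 ids per parser state -- 18 word ids
    # (top 3 of stack and buffer, plus the two leftmost/rightmost children and
    # the leftmost-of-leftmost / rightmost-of-rightmost grandchildren of the
    # top two stack items) and 18 matching POS ids; use_dep would add 12
    # dependency-label ids for the child positions. This matches n_features
    # computed in __init__ and is essentially the template of
    # Chen & Manning (2014).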
    def extract_features(self, stack, buf, arcs, ex):
        if stack[0] == "ROOT":
            stack[0] = 0

        def get_lc(k):
            return sorted([arc[1] for arc in arcs if arc[0] == k and arc[1] < k])

        def get_rc(k):
            return sorted([arc[1] for arc in arcs if arc[0] == k and arc[1] > k],
                          reverse=True)

        p_features = []
        l_features = []
        features = [self.NULL] * (3 - len(stack)) + [ex['word'][x] for x in stack[-3:]]
        features += [ex['word'][x] for x in buf[:3]] + [self.NULL] * (3 - len(buf))
        if self.use_pos:
            p_features = [self.P_NULL] * (3 - len(stack)) + [ex['pos'][x] for x in stack[-3:]]
            p_features += [ex['pos'][x] for x in buf[:3]] + [self.P_NULL] * (3 - len(buf))

        for i in range(2):
            if i < len(stack):
                k = stack[-i-1]
                lc = get_lc(k)
                rc = get_rc(k)
                llc = get_lc(lc[0]) if len(lc) > 0 else []
                rrc = get_rc(rc[0]) if len(rc) > 0 else []

                features.append(ex['word'][lc[0]] if len(lc) > 0 else self.NULL)
                features.append(ex['word'][rc[0]] if len(rc) > 0 else self.NULL)
                features.append(ex['word'][lc[1]] if len(lc) > 1 else self.NULL)
                features.append(ex['word'][rc[1]] if len(rc) > 1 else self.NULL)
                features.append(ex['word'][llc[0]] if len(llc) > 0 else self.NULL)
                features.append(ex['word'][rrc[0]] if len(rrc) > 0 else self.NULL)

                if self.use_pos:
                    p_features.append(ex['pos'][lc[0]] if len(lc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][rc[0]] if len(rc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][lc[1]] if len(lc) > 1 else self.P_NULL)
                    p_features.append(ex['pos'][rc[1]] if len(rc) > 1 else self.P_NULL)
                    p_features.append(ex['pos'][llc[0]] if len(llc) > 0 else self.P_NULL)
                    p_features.append(ex['pos'][rrc[0]] if len(rrc) > 0 else self.P_NULL)

                if self.use_dep:
                    l_features.append(ex['label'][lc[0]] if len(lc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][rc[0]] if len(rc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][lc[1]] if len(lc) > 1 else self.L_NULL)
                    l_features.append(ex['label'][rc[1]] if len(rc) > 1 else self.L_NULL)
                    l_features.append(ex['label'][llc[0]] if len(llc) > 0 else self.L_NULL)
                    l_features.append(ex['label'][rrc[0]] if len(rrc) > 0 else self.L_NULL)
            else:
                features += [self.NULL] * 6
                if self.use_pos:
                    p_features += [self.P_NULL] * 6
                if self.use_dep:
                    l_features += [self.L_NULL] * 6

        features += p_features + l_features
        assert len(features) == self.n_features
        return features

    def get_oracle(self, stack, buf, ex):
        if len(stack) < 2:
            return self.n_trans - 1

        i0 = stack[-1]
        i1 = stack[-2]
        h0 = ex['head'][i0]
        h1 = ex['head'][i1]
        l0 = ex['label'][i0]
        l1 = ex['label'][i1]

        if self.unlabeled:
            if (i1 > 0) and (h1 == i0):
                return 0
            elif (i1 >= 0) and (h0 == i1) and \
                    (not any([x for x in buf if ex['head'][x] == i0])):
                return 1
            else:
                return None if len(buf) == 0 else 2
        else:
            if (i1 > 0) and (h1 == i0):
                return l1 if (l1 >= 0) and (l1 < self.n_deprel) else None
            elif (i1 >= 0) and (h0 == i1) and \
                    (not any([x for x in buf if ex['head'][x] == i0])):
                return l0 + self.n_deprel if (l0 >= 0) and (l0 < self.n_deprel) else None
            else:
                return None if len(buf) == 0 else self.n_trans - 1

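    # Worked oracle example (illustrative sketch, not executed): for a
    # two-word sentence with ex['head'] = [-1, 2, 0] (w1's head is w2, w2's
    # head is ROOT) and the unlabeled transition set ['L', 'R', 'S'],
    # get_oracle yields the sequence S, S, L, R:
    #   stack=[0],       buf=[1, 2] -> 2 (S)  shift w1
    #   stack=[0, 1],    buf=[2]    -> 2 (S)  shift w2
    #   stack=[0, 1, 2], buf=[]     -> 0 (L)  arc w2 -> w1
    #   stack=[0, 2],    buf=[]     -> 1 (R)  arc ROOT -> w2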
    def create_instances(self, examples):
        all_instances = []
        succ = 0
        for id, ex in enumerate(examples):
            n_words = len(ex['word']) - 1

            # arcs = {(h, t, label)}
            stack = [0]
            buf = [i + 1 for i in range(n_words)]
            arcs = []
            instances = []
            for i in range(n_words * 2):
                gold_t = self.get_oracle(stack, buf, ex)
                if gold_t is None:
                    break
                legal_labels = self.legal_labels(stack, buf)
                assert legal_labels[gold_t] == 1
                instances.append((self.extract_features(stack, buf, arcs, ex),
                                  legal_labels, gold_t))
                if gold_t == self.n_trans - 1:
                    stack.append(buf[0])
                    buf = buf[1:]
                elif gold_t < self.n_deprel:
                    arcs.append((stack[-1], stack[-2], gold_t))
                    stack = stack[:-2] + [stack[-1]]
                else:
                    arcs.append((stack[-2], stack[-1], gold_t - self.n_deprel))
                    stack = stack[:-1]
            else:
                succ += 1
                all_instances += instances

        return all_instances

    def legal_labels(self, stack, buf):
        labels = ([1] if len(stack) > 2 else [0]) * self.n_deprel
        labels += ([1] if len(stack) >= 2 else [0]) * self.n_deprel
        labels += [1] if len(buf) > 0 else [0]
        return labels

    def parse(self, dataset, eval_batch_size=5000):
        sentences = []
        sentence_id_to_idx = {}
        for i, example in enumerate(dataset):
            n_words = len(example['word']) - 1
            sentence = [j + 1 for j in range(n_words)]
            sentences.append(sentence)
            sentence_id_to_idx[id(sentence)] = i

        model = ModelWrapper(self, dataset, sentence_id_to_idx)
        dependencies = minibatch_parse(sentences, model, eval_batch_size)

        UAS = all_tokens = 0.0
        with tqdm(total=len(dataset)) as prog:
            for i, ex in enumerate(dataset):
                head = [-1] * len(ex['word'])
                for h, t in dependencies[i]:
                    head[t] = h
                for pred_h, gold_h, gold_l, pos in \
                        zip(head[1:], ex['head'][1:], ex['label'][1:], ex['pos'][1:]):
                    assert self.id2tok[pos].startswith(P_PREFIX)
                    pos_str = self.id2tok[pos][len(P_PREFIX):]
                    if (self.with_punct) or (not punct(self.language, pos_str)):
                        UAS += 1 if pred_h == gold_h else 0
                        all_tokens += 1
                prog.update(1)  # advance the progress bar by one sentence
        UAS /= all_tokens
        return UAS, dependencies


class ModelWrapper(object):
    def __init__(self, parser, dataset, sentence_id_to_idx):
        self.parser = parser
        self.dataset = dataset
        self.sentence_id_to_idx = sentence_id_to_idx

    def predict(self, partial_parses):
        mb_x = [self.parser.extract_features(p.stack, p.buffer, p.dependencies,
                                             self.dataset[self.sentence_id_to_idx[id(p.sentence)]])
                for p in partial_parses]
        mb_x = np.array(mb_x).astype('int32')
        mb_x = torch.from_numpy(mb_x).long()
        mb_l = [self.parser.legal_labels(p.stack, p.buffer) for p in partial_parses]

        pred = self.parser.model(mb_x)
        pred = pred.detach().numpy()
        pred = np.argmax(pred + 10000 * np.array(mb_l).astype('float32'), 1)
        pred = ["S" if p == 2 else ("LA" if p == 0 else "RA") for p in pred]
        return pred

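# read_conll below expects tab-separated, 10-column CoNLL-X rows, one token
# per line, with blank lines separating sentences; e.g. (hypothetical row):
#   1   The   the   DET   DT   _   2   det   _   _
# Only columns 2 (word form), 5 (POS tag), 7 (head index, 0 = ROOT) and
# 8 (dependency label) are kept; multi-word ranges like "1-2" are skipped.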
def read_conll(in_file, lowercase=False, max_example=None):
    examples = []
    with open(in_file) as f:
        word, pos, head, label = [], [], [], []
        for line in f.readlines():
            sp = line.strip().split('\t')
            if len(sp) == 10:
                if '-' not in sp[0]:
                    word.append(sp[1].lower() if lowercase else sp[1])
                    pos.append(sp[4])
                    head.append(int(sp[6]))
                    label.append(sp[7])
            elif len(word) > 0:
                examples.append({'word': word, 'pos': pos, 'head': head, 'label': label})
                word, pos, head, label = [], [], [], []
            if (max_example is not None) and (len(examples) == max_example):
                break
        if len(word) > 0:
            examples.append({'word': word, 'pos': pos, 'head': head, 'label': label})
    return examples


def build_dict(keys, n_max=None, offset=0):
    count = Counter()
    for key in keys:
        count[key] += 1
    ls = count.most_common() if n_max is None \
        else count.most_common(n_max)

    return {w[0]: index + offset for (index, w) in enumerate(ls)}


def punct(language, pos):
    if language == 'english':
        return pos in ["''", ",", ".", ":", "``", "-LRB-", "-RRB-"]
    elif language == 'chinese':
        return pos == 'PU'
    elif language == 'french':
        return pos == 'PUNC'
    elif language == 'german':
        return pos in ["$.", "$,", "$["]
    elif language == 'spanish':
        # http://nlp.stanford.edu/software/spanish-faq.shtml
        return pos in ["f0", "faa", "fat", "fc", "fd", "fe", "fg", "fh",
                       "fia", "fit", "fp", "fpa", "fpt", "fs", "ft",
                       "fx", "fz"]
    elif language == 'universal':
        return pos == 'PUNCT'
    else:
        raise ValueError('language: %s is not supported.' % language)


def minibatches(data, batch_size):
    x = np.array([d[0] for d in data])
    y = np.array([d[2] for d in data])
    one_hot = np.zeros((y.size, 3))
    one_hot[np.arange(y.size), y] = 1
    return get_minibatches([x, one_hot], batch_size)


def load_and_preprocess_data(reduced=True):
    config = Config()

    print("Loading data...")
    start = time.time()
    train_set = read_conll(os.path.join(config.data_path, config.train_file),
                           lowercase=config.lowercase)
    dev_set = read_conll(os.path.join(config.data_path, config.dev_file),
                         lowercase=config.lowercase)
    test_set = read_conll(os.path.join(config.data_path, config.test_file),
                          lowercase=config.lowercase)
    if reduced:
        train_set = train_set[:1000]
        dev_set = dev_set[:500]
        test_set = test_set[:500]
    print("took {:.2f} seconds".format(time.time() - start))

    print("Building parser...")
    start = time.time()
    parser = Parser(train_set)
    print("took {:.2f} seconds".format(time.time() - start))

    print("Loading pretrained embeddings...")
    start = time.time()
    word_vectors = {}
    for line in open(config.embedding_file).readlines():
        sp = line.strip().split()
        word_vectors[sp[0]] = [float(x) for x in sp[1:]]
    embeddings_matrix = np.asarray(np.random.normal(0, 0.9, (parser.n_tokens, 50)), dtype='float32')

    for token in parser.tok2id:
        i = parser.tok2id[token]
        if token in word_vectors:
            embeddings_matrix[i] = word_vectors[token]
        elif token.lower() in word_vectors:
            embeddings_matrix[i] = word_vectors[token.lower()]
    print("took {:.2f} seconds".format(time.time() - start))

    print("Vectorizing data...")
    start = time.time()
    train_set = parser.vectorize(train_set)
    dev_set = parser.vectorize(dev_set)
    test_set = parser.vectorize(test_set)
    print("took {:.2f} seconds".format(time.time() - start))

    print("Preprocessing training data...")
    start = time.time()
    train_examples = parser.create_instances(train_set)
    print("took {:.2f} seconds".format(time.time() - start))

    return parser, embeddings_matrix, train_examples, dev_set, test_set


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


if __name__ == '__main__':
    pass
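# Example usage (a sketch; assumes the CoNLL files and ./data/en-cw.txt
# embeddings named in Config are present, and that the module is imported
# where the general_utils / parser_transitions imports resolve):
#
#     parser, embeddings, train_examples, dev_set, test_set = \
#         load_and_preprocess_data(reduced=True)
#     print(parser.n_features, parser.n_tokens, len(train_examples))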