Path: blob/main/a3/parser_transitions.py
984 views
#!/usr/bin/env python31# -*- coding: utf-8 -*-2"""3CS224N 2021-2022: Homework 34parser_transitions.py: Algorithms for completing partial parsess.5Sahil Chopra <[email protected]>6Haoshen Hong <[email protected]>7"""89import sys1011class PartialParse(object):12def __init__(self, sentence):13"""Initializes this partial parse.1415@param sentence (list of str): The sentence to be parsed as a list of words.16Your code should not modify the sentence.17"""18# The sentence being parsed is kept for bookkeeping purposes. Do NOT alter it in your code.19self.sentence = sentence2021### YOUR CODE HERE (3 Lines)22### Your code should initialize the following fields:23### self.stack: The current stack represented as a list with the top of the stack as the24### last element of the list.25### self.buffer: The current buffer represented as a list with the first item on the26### buffer as the first item of the list27### self.dependencies: The list of dependencies produced so far. Represented as a list of28### tuples where each tuple is of the form (head, dependent).29### Order for this list doesn't matter.30###31### Note: The root token should be represented with the string "ROOT"32### Note: If you need to use the sentence object to initialize anything, make sure to not directly33### reference the sentence object. That is, remember to NOT modify the sentence object.34self.stack = ['ROOT']35self.buffer = sentence.copy()36self.dependencies = []37### END YOUR CODE383940def parse_step(self, transition):41"""Performs a single parse step by applying the given transition to this partial parse4243@param transition (str): A string that equals "S", "LA", or "RA" representing the shift,44left-arc, and right-arc transitions. You can assume the provided45transition is a legal transition.46"""47### YOUR CODE HERE (~7-12 Lines)48### TODO:49### Implement a single parsing step, i.e. the logic for the following as50### described in the pdf handout:51### 1. Shift52### 2. Left Arc53### 3. Right Arc54if transition == 'S':55self.stack.append(self.buffer[0])56del self.buffer[0]57elif transition == 'LA':58self.dependencies.append((self.stack[-1], self.stack[-2]))59del self.stack[-2]60elif transition == 'RA':61self.dependencies.append((self.stack[-2], self.stack[-1]))62del self.stack[-1]63### END YOUR CODE6465def parse(self, transitions):66"""Applies the provided transitions to this PartialParse6768@param transitions (list of str): The list of transitions in the order they should be applied6970@return dependencies (list of string tuples): The list of dependencies produced when71parsing the sentence. Represented as a list of72tuples where each tuple is of the form (head, dependent).73"""74for transition in transitions:75self.parse_step(transition)76return self.dependencies777879def minibatch_parse(sentences, model, batch_size):80"""Parses a list of sentences in minibatches using a model.8182@param sentences (list of list of str): A list of sentences to be parsed83(each sentence is a list of words and each word is of type string)84@param model (ParserModel): The model that makes parsing decisions. It is assumed to have a function85model.predict(partial_parses) that takes in a list of PartialParses as input and86returns a list of transitions predicted for each parse. That is, after calling87transitions = model.predict(partial_parses)88transitions[i] will be the next transition to apply to partial_parses[i].89@param batch_size (int): The number of PartialParses to include in each minibatch909192@return dependencies (list of dependency lists): A list where each element is the dependencies93list for a parsed sentence. Ordering should be the94same as in sentences (i.e., dependencies[i] should95contain the parse for sentences[i]).96"""97dependencies = []9899### YOUR CODE HERE (~8-10 Lines)100### TODO:101### Implement the minibatch parse algorithm. Note that the pseudocode for this algorithm is given in the pdf handout.102###103### Note: A shallow copy (as denoted in the PDF) can be made with the "=" sign in python, e.g.104### unfinished_parses = partial_parses[:].105### Here `unfinished_parses` is a shallow copy of `partial_parses`.106### In Python, a shallow copied list like `unfinished_parses` does not contain new instances107### of the object stored in `partial_parses`. Rather both lists refer to the same objects.108### In our case, `partial_parses` contains a list of partial parses. `unfinished_parses`109### contains references to the same objects. Thus, you should NOT use the `del` operator110### to remove objects from the `unfinished_parses` list. This will free the underlying memory that111### is being accessed by `partial_parses` and may cause your code to crash.112partial_parses = [PartialParse(sentence) for sentence in sentences]113unfinished_parses = partial_parses[:]114while len(unfinished_parses) > 0:115batch = unfinished_parses[:batch_size]116transitions = model.predict(batch)117#print(batch_size,transitions)118for partial_parse, transition in zip(batch, transitions):119#the transition is one step for the partial_parse120trans = partial_parse.parse_step(transition)121if not partial_parse.buffer and len(partial_parse.stack) == 1:122unfinished_parses.remove(partial_parse)123dependencies = [parse.dependencies for parse in partial_parses]124### END YOUR CODE125126return dependencies127128129def test_step(name, transition, stack, buf, deps,130ex_stack, ex_buf, ex_deps):131"""Tests that a single parse step returns the expected output"""132pp = PartialParse([])133pp.stack, pp.buffer, pp.dependencies = stack, buf, deps134135pp.parse_step(transition)136stack, buf, deps = (tuple(pp.stack), tuple(pp.buffer), tuple(sorted(pp.dependencies)))137assert stack == ex_stack, \138"{:} test resulted in stack {:}, expected {:}".format(name, stack, ex_stack)139assert buf == ex_buf, \140"{:} test resulted in buffer {:}, expected {:}".format(name, buf, ex_buf)141assert deps == ex_deps, \142"{:} test resulted in dependency list {:}, expected {:}".format(name, deps, ex_deps)143print("{:} test passed!".format(name))144145146def test_parse_step():147"""Simple tests for the PartialParse.parse_step function148Warning: these are not exhaustive149"""150test_step("SHIFT", "S", ["ROOT", "the"], ["cat", "sat"], [],151("ROOT", "the", "cat"), ("sat",), ())152test_step("LEFT-ARC", "LA", ["ROOT", "the", "cat"], ["sat"], [],153("ROOT", "cat",), ("sat",), (("cat", "the"),))154test_step("RIGHT-ARC", "RA", ["ROOT", "run", "fast"], [], [],155("ROOT", "run",), (), (("run", "fast"),))156157158def test_parse():159"""Simple tests for the PartialParse.parse function160Warning: these are not exhaustive161"""162sentence = ["parse", "this", "sentence"]163dependencies = PartialParse(sentence).parse(["S", "S", "S", "LA", "RA", "RA"])164dependencies = tuple(sorted(dependencies))165expected = (('ROOT', 'parse'), ('parse', 'sentence'), ('sentence', 'this'))166assert dependencies == expected, \167"parse test resulted in dependencies {:}, expected {:}".format(dependencies, expected)168assert tuple(sentence) == ("parse", "this", "sentence"), \169"parse test failed: the input sentence should not be modified"170print("parse test passed!")171172173class DummyModel(object):174"""Dummy model for testing the minibatch_parse function175"""176def __init__(self, mode = "unidirectional"):177self.mode = mode178179def predict(self, partial_parses):180if self.mode == "unidirectional":181return self.unidirectional_predict(partial_parses)182elif self.mode == "interleave":183return self.interleave_predict(partial_parses)184else:185raise NotImplementedError()186187def unidirectional_predict(self, partial_parses):188"""First shifts everything onto the stack and then does exclusively right arcs if the first word of189the sentence is "right", "left" if otherwise.190"""191return [("RA" if pp.stack[1] is "right" else "LA") if len(pp.buffer) == 0 else "S"192for pp in partial_parses]193194def interleave_predict(self, partial_parses):195"""First shifts everything onto the stack and then interleaves "right" and "left".196"""197return [("RA" if len(pp.stack) % 2 == 0 else "LA") if len(pp.buffer) == 0 else "S"198for pp in partial_parses]199200def test_dependencies(name, deps, ex_deps):201"""Tests the provided dependencies match the expected dependencies"""202deps = tuple(sorted(deps))203assert deps == ex_deps, \204"{:} test resulted in dependency list {:}, expected {:}".format(name, deps, ex_deps)205206207def test_minibatch_parse():208"""Simple tests for the minibatch_parse function209Warning: these are not exhaustive210"""211212# Unidirectional arcs test213sentences = [["right", "arcs", "only"],214["right", "arcs", "only", "again"],215["left", "arcs", "only"],216["left", "arcs", "only", "again"]]217deps = minibatch_parse(sentences, DummyModel(), 2)218test_dependencies("minibatch_parse", deps[0],219(('ROOT', 'right'), ('arcs', 'only'), ('right', 'arcs')))220test_dependencies("minibatch_parse", deps[1],221(('ROOT', 'right'), ('arcs', 'only'), ('only', 'again'), ('right', 'arcs')))222test_dependencies("minibatch_parse", deps[2],223(('only', 'ROOT'), ('only', 'arcs'), ('only', 'left')))224test_dependencies("minibatch_parse", deps[3],225(('again', 'ROOT'), ('again', 'arcs'), ('again', 'left'), ('again', 'only')))226227# Out-of-bound test228sentences = [["right"]]229deps = minibatch_parse(sentences, DummyModel(), 2)230test_dependencies("minibatch_parse", deps[0], (('ROOT', 'right'),))231232# Mixed arcs test233sentences = [["this", "is", "interleaving", "dependency", "test"]]234deps = minibatch_parse(sentences, DummyModel(mode="interleave"), 1)235test_dependencies("minibatch_parse", deps[0],236(('ROOT', 'is'), ('dependency', 'interleaving'),237('dependency', 'test'), ('is', 'dependency'), ('is', 'this')))238print("minibatch_parse test passed!")239240241if __name__ == '__main__':242args = sys.argv243if len(args) != 2:244raise Exception("You did not provide a valid keyword. Either provide 'part_c' or 'part_d', when executing this script")245elif args[1] == "part_c":246test_parse_step()247test_parse()248elif args[1] == "part_d":249test_minibatch_parse()250else:251raise Exception("You did not provide a valid keyword. Either provide 'part_c' or 'part_d', when executing this script")252253254