GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a4/utils.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CS224N 2022-23: Homework 4
utils.py: Utility Functions
Pengcheng Yin <[email protected]>
Sahil Chopra <[email protected]>
Vera Lin <[email protected]>
Siyan Li <[email protected]>
"""

import math
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nltk
import sentencepiece as spm

# nltk.download('punkt')  # uncomment on first use to fetch NLTK's tokenizer data


def pad_sents(sents, pad_token):
    """ Pad list of sentences according to the longest sentence in the batch.
        The paddings should be at the end of each sentence.
    @param sents (list[list[str]]): list of sentences, where each sentence
        is represented as a list of words
    @param pad_token (str): padding token
    @returns sents_padded (list[list[str]]): list of sentences where sentences shorter
        than the max length sentence are padded out with the pad_token, such that
        each sentence in the batch now has equal length.
    """
    sents_padded = []

    ### YOUR CODE HERE (~6 Lines)
    max_length = max(len(sent) for sent in sents)
    sents_padded = [sent + [pad_token] * (max_length - len(sent)) for sent in sents]
    ### END YOUR CODE

    return sents_padded
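
# Example usage (illustrative; any pad token string works):
#   pad_sents([['hello', 'world', '!'], ['hi']], '<pad>')
#   -> [['hello', 'world', '!'], ['hi', '<pad>', '<pad>']]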


def read_corpus(file_path, source, vocab_size=2500):
    """ Read file, where each sentence is delineated by a newline character.
    @param file_path (str): path to file containing corpus
    @param source (str): "tgt" or "src" indicating whether text
        is of the source language or target language
    @param vocab_size (int): number of unique subwords in
        vocabulary when reading and tokenizing (unused here: the vocabulary
        is fixed by the pretrained SentencePiece model loaded below)
    """
    data = []
    sp = spm.SentencePieceProcessor()
    sp.load('{}.model'.format(source))

    with open(file_path, 'r', encoding='utf8') as f:
        for line in f:
            subword_tokens = sp.encode_as_pieces(line)
            # only append <s> and </s> to the target sentence
            if source == 'tgt':
                subword_tokens = ['<s>'] + subword_tokens + ['</s>']
            data.append(subword_tokens)

    return data
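
# Example usage (illustrative; assumes a trained SentencePiece model file
# 'src.model' in the working directory, and 'train.src' is a hypothetical
# corpus path):
#   src_sents = read_corpus('train.src', source='src')
#   src_sents[0]  # a list of subword pieces; the exact pieces depend on the model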


def autograder_read_corpus(file_path, source):
    """ Read file, where each sentence is delineated by a newline character.
    @param file_path (str): path to file containing corpus
    @param source (str): "tgt" or "src" indicating whether text
        is of the source language or target language
    """
    data = []
    with open(file_path, 'r', encoding='utf8') as f:
        for line in f:
            sent = nltk.word_tokenize(line)
            # only append <s> and </s> to the target sentence
            if source == 'tgt':
                sent = ['<s>'] + sent + ['</s>']
            data.append(sent)

    return data
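
# Example usage (illustrative; requires the NLTK 'punkt' tokenizer data,
# and 'dev.en' is a hypothetical path):
#   autograder_read_corpus('dev.en', 'tgt')
#   -> [['<s>', 'Hello', ',', 'world', '!', '</s>'], ...]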


def batch_iter(data, batch_size, shuffle=False):
    """ Yield batches of source and target sentences reverse sorted by length (largest to smallest).
    @param data (list of (src_sent, tgt_sent)): list of tuples containing source and target sentence
    @param batch_size (int): batch size
    @param shuffle (boolean): whether to randomly shuffle the dataset
    """
    batch_num = math.ceil(len(data) / batch_size)
    index_array = list(range(len(data)))

    if shuffle:
        np.random.shuffle(index_array)

    for i in range(batch_num):
        indices = index_array[i * batch_size: (i + 1) * batch_size]
        examples = [data[idx] for idx in indices]

        examples = sorted(examples, key=lambda e: len(e[0]), reverse=True)
        src_sents = [e[0] for e in examples]
        tgt_sents = [e[1] for e in examples]

        yield src_sents, tgt_sents
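

if __name__ == '__main__':
    # Illustrative smoke test (an addition, not part of the assignment scaffold):
    # exercises pad_sents and batch_iter on a toy parallel corpus. Within each
    # batch, examples come out sorted by source length, longest first, which is
    # the ordering that pack_padded_sequence-style RNN code typically expects.
    toy_src = [['a', 'b', 'c'], ['d'], ['e', 'f']]
    toy_tgt = [['<s>', 'x', '</s>']] * 3
    print(pad_sents(toy_src, '<pad>'))
    # -> [['a', 'b', 'c'], ['d', '<pad>', '<pad>'], ['e', 'f', '<pad>']]
    for src_batch, tgt_batch in batch_iter(list(zip(toy_src, toy_tgt)), batch_size=2):
        print([len(s) for s in src_batch])  # lengths are non-increasing within a batch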