GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a4/vocab.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CS224N 2022-23: Homework 4
vocab.py: Vocabulary Generation
Pengcheng Yin <[email protected]>
Sahil Chopra <[email protected]>
Vera Lin <[email protected]>
Siyan Li <[email protected]>

Usage:
    vocab.py --train-src=<file> --train-tgt=<file> [options] VOCAB_FILE

Options:
    -h --help                  Show this screen.
    --train-src=<file>         File of training source sentences
    --train-tgt=<file>         File of training target sentences
    --size=<int>               vocab size [default: 50000]
    --freq-cutoff=<int>        frequency cutoff [default: 2]
"""

from collections import Counter
from itertools import chain
import json
from typing import List

from docopt import docopt
import sentencepiece as spm
import torch

from utils import read_corpus, pad_sents


class VocabEntry(object):
    """ Vocabulary Entry, i.e. structure containing either
    src or tgt language terms.
    """
    def __init__(self, word2id=None):
        """ Init VocabEntry Instance.
        @param word2id (dict): dictionary mapping words to indices
        """
        if word2id:
            self.word2id = word2id
        else:
            self.word2id = dict()
            self.word2id['<pad>'] = 0   # Pad Token
            self.word2id['<s>'] = 1     # Start Token
            self.word2id['</s>'] = 2    # End Token
            self.word2id['<unk>'] = 3   # Unknown Token
        self.unk_id = self.word2id['<unk>']
        self.id2word = {v: k for k, v in self.word2id.items()}

    def __getitem__(self, word):
        """ Retrieve word's index. Return the index for the unk
        token if the word is out of vocabulary.
        @param word (str): word to look up.
        @returns index (int): index of word
        """
        return self.word2id.get(word, self.unk_id)

    def __contains__(self, word):
        """ Check if word is captured by VocabEntry.
        @param word (str): word to look up
        @returns contains (bool): whether word is contained
        """
        return word in self.word2id

    def __setitem__(self, key, value):
        """ Raise an error if one tries to edit the VocabEntry.
        """
        raise ValueError('vocabulary is readonly')

    def __len__(self):
        """ Compute number of words in VocabEntry.
        @returns len (int): number of words in VocabEntry
        """
        return len(self.word2id)

    def __repr__(self):
        """ Representation of VocabEntry to be used
        when printing the object.
        """
        return 'Vocabulary[size=%d]' % len(self)

    def id2word(self, wid):
        """ Return mapping of index to word.
        Note: on instances this method is shadowed by the `id2word` dict
        attribute set in __init__, so lookups normally go through that dict.
        @param wid (int): word index
        @returns word (str): word corresponding to index
        """
        return self.id2word[wid]

    def add(self, word):
        """ Add word to VocabEntry, if it is previously unseen.
        @param word (str): word to add to VocabEntry
        @returns index (int): index that the word has been assigned
        """
        if word not in self:
            wid = self.word2id[word] = len(self)
            self.id2word[wid] = word
            return wid
        else:
            return self[word]

    def words2indices(self, sents):
        """ Convert list of words or list of sentences of words
        into list or list of list of indices.
        @param sents (list[str] or list[list[str]]): sentence(s) in words
        @returns word_ids (list[int] or list[list[int]]): sentence(s) in indices
        """
        if type(sents[0]) == list:
            return [[self[w] for w in s] for s in sents]
        else:
            return [self[w] for w in sents]

    def indices2words(self, word_ids):
        """ Convert list of indices into words.
        @param word_ids (list[int]): list of word ids
        @returns sents (list[str]): list of words
        """
        return [self.id2word[w_id] for w_id in word_ids]

    def to_input_tensor(self, sents: List[List[str]], device: torch.device) -> torch.Tensor:
        """ Convert list of sentences (words) into tensor with necessary padding for
        shorter sentences.

        @param sents (List[List[str]]): list of sentences (words)
        @param device: device on which to load the tensor, i.e. CPU or GPU

        @returns sents_var: tensor of (max_sentence_length, batch_size)
        """
        word_ids = self.words2indices(sents)
        sents_t = pad_sents(word_ids, self['<pad>'])
        sents_var = torch.tensor(sents_t, dtype=torch.long, device=device)
        return torch.t(sents_var)

    @staticmethod
    def from_corpus(corpus, size, freq_cutoff=2):
        """ Given a corpus construct a Vocab Entry.
        @param corpus (list[str]): corpus of text produced by read_corpus function
        @param size (int): # of words in vocabulary
        @param freq_cutoff (int): if word occurs n < freq_cutoff times, drop the word
        @returns vocab_entry (VocabEntry): VocabEntry instance produced from provided corpus
        """
        vocab_entry = VocabEntry()
        word_freq = Counter(chain(*corpus))
        valid_words = [w for w, v in word_freq.items() if v >= freq_cutoff]
        print('number of word types: {}, number of word types w/ frequency >= {}: {}'
              .format(len(word_freq), freq_cutoff, len(valid_words)))
        top_k_words = sorted(valid_words, key=lambda w: word_freq[w], reverse=True)[:size]
        for word in top_k_words:
            vocab_entry.add(word)
        return vocab_entry

    @staticmethod
    def from_subword_list(subword_list):
        """ Construct a VocabEntry from a list of subwords produced by SentencePiece.
        @param subword_list (list[str]): list of subword tokens
        @returns vocab_entry (VocabEntry): VocabEntry instance built from the subword list
        """
        vocab_entry = VocabEntry()
        for subword in subword_list:
            vocab_entry.add(subword)
        return vocab_entry

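# A minimal usage sketch of VocabEntry (illustrative only; the words, ids, and
# device below are made-up examples, not part of the assignment):
#
#     entry = VocabEntry()
#     entry.add('hello')        # -> 4, the first id after the four special tokens
#     entry['hello']            # -> 4
#     entry['missing']          # -> 3, the <unk> id
#     entry.words2indices([['hello'], ['hello', 'missing']])
#                               # -> [[4], [4, 3]]
#     entry.to_input_tensor([['hello'], ['hello', 'missing']], torch.device('cpu'))
#                               # -> LongTensor of shape (2, 2), padded with <pad> id 0
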
class Vocab(object):
    """ Vocab encapsulating src and target languages.
    """
    def __init__(self, src_vocab: VocabEntry, tgt_vocab: VocabEntry):
        """ Init Vocab.
        @param src_vocab (VocabEntry): VocabEntry for source language
        @param tgt_vocab (VocabEntry): VocabEntry for target language
        """
        self.src = src_vocab
        self.tgt = tgt_vocab

    @staticmethod
    def build(src_sents, tgt_sents) -> 'Vocab':
        """ Build Vocabulary.
        @param src_sents (list[str]): Source subwords provided by SentencePiece
        @param tgt_sents (list[str]): Target subwords provided by SentencePiece
        @returns vocab (Vocab): Vocab instance wrapping the source and target VocabEntry objects
        """
        # assert len(src_sents) == len(tgt_sents)

        print('initialize source vocabulary ..')
        # src = VocabEntry.from_corpus(src_sents, vocab_size, freq_cutoff)
        src = VocabEntry.from_subword_list(src_sents)

        print('initialize target vocabulary ..')
        # tgt = VocabEntry.from_corpus(tgt_sents, vocab_size, freq_cutoff)
        tgt = VocabEntry.from_subword_list(tgt_sents)

        return Vocab(src, tgt)

    def save(self, file_path):
        """ Save Vocab to file as JSON dump.
        @param file_path (str): file path to vocab file
        """
        with open(file_path, 'w') as f:
            json.dump(dict(src_word2id=self.src.word2id, tgt_word2id=self.tgt.word2id), f, indent=2)

    @staticmethod
    def load(file_path):
        """ Load vocabulary from JSON dump.
        @param file_path (str): file path to vocab file
        @returns Vocab object loaded from JSON dump
        """
        with open(file_path, 'r') as f:
            entry = json.load(f)
        src_word2id = entry['src_word2id']
        tgt_word2id = entry['tgt_word2id']

        return Vocab(VocabEntry(src_word2id), VocabEntry(tgt_word2id))

    def __repr__(self):
        """ Representation of Vocab to be used
        when printing the object.
        """
        return 'Vocab(source %d words, target %d words)' % (len(self.src), len(self.tgt))

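# A minimal sketch of the save/load round trip (the file name and subword
# lists are hypothetical):
#
#     vocab = Vocab.build(['▁hel', 'lo'], ['▁bon', 'jour'])
#     vocab.save('vocab.json')
#     reloaded = Vocab.load('vocab.json')
#     assert reloaded.src.word2id == vocab.src.word2id
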

def get_vocab_list(file_path, source, vocab_size):
    """ Use SentencePiece to tokenize and acquire list of unique subwords.
    @param file_path (str): file path to corpus
    @param source (str): 'src' or 'tgt', used as the SentencePiece model prefix
    @param vocab_size (int): desired vocabulary size
    @returns sp_list (list[str]): list of unique subwords in the trained vocabulary
    """
    # Train the SentencePiece model; training writes {source}.model and {source}.vocab
    spm.SentencePieceTrainer.Train(input=file_path, model_prefix=source, vocab_size=vocab_size)
    sp = spm.SentencePieceProcessor()
    sp.Load('{}.model'.format(source))  # loads src.model or tgt.model
    sp_list = [sp.IdToPiece(piece_id) for piece_id in range(sp.GetPieceSize())]  # enumerate all subwords
    return sp_list

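# Once get_vocab_list has trained a model, the saved model file can tokenize
# raw text. A minimal sketch (the sentence is made up; 'src.model' is the file
# written when model_prefix='src'):
#
#     sp = spm.SentencePieceProcessor()
#     sp.Load('src.model')
#     pieces = sp.EncodeAsPieces('the quick brown fox')  # e.g. ['▁the', '▁quick', ...]
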
if __name__ == '__main__':
    args = docopt(__doc__)

    print('read in source sentences: %s' % args['--train-src'])
    print('read in target sentences: %s' % args['--train-tgt'])

    src_sents = get_vocab_list(args['--train-src'], source='src', vocab_size=21000)  # EDIT: NEW VOCAB SIZE
    tgt_sents = get_vocab_list(args['--train-tgt'], source='tgt', vocab_size=8000)
    vocab = Vocab.build(src_sents, tgt_sents)
    print('generated vocabulary, source %d words, target %d words' % (len(vocab.src), len(vocab.tgt)))

    # src_sents = read_corpus(args['--train-src'], source='src')
    # tgt_sents = read_corpus(args['--train-tgt'], source='tgt')

    # vocab = Vocab.build(src_sents, tgt_sents, int(args['--size']), int(args['--freq-cutoff']))
    # print('generated vocabulary, source %d words, target %d words' % (len(vocab.src), len(vocab.tgt)))

    vocab.save(args['VOCAB_FILE'])
    print('vocabulary saved to %s' % args['VOCAB_FILE'])
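# Example invocation (the data paths are hypothetical):
#
#     python vocab.py --train-src=./train.src --train-tgt=./train.tgt vocab.json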