#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CS224N 2022-23: Homework 4
model_embeddings.py: Embeddings for the NMT model
Pencheng Yin <[email protected]>
Sahil Chopra <[email protected]>
Anand Dhoot <[email protected]>
Vera Lin <[email protected]>
Siyan Li <[email protected]>
"""

import torch.nn as nn

class ModelEmbeddings(nn.Module):
    """
    Class that converts input words to their embeddings.
    """
    def __init__(self, embed_size, vocab):
        """
        Init the Embedding layers.

        @param embed_size (int): Embedding size (dimensionality)
        @param vocab (Vocab): Vocabulary object containing src and tgt languages
                              See vocab.py for documentation.
        """
        super(ModelEmbeddings, self).__init__()
        self.embed_size = embed_size

        # default values
        self.source = None
        self.target = None

        src_pad_token_idx = vocab.src['<pad>']
        tgt_pad_token_idx = vocab.tgt['<pad>']

        ### YOUR CODE HERE (~2 Lines)
        ### TODO - Initialize the following variables:
        ###     self.source (Embedding Layer for source language)
        ###     self.target (Embedding Layer for target language)
        ###
        ### Note:
        ###     1. `vocab` object contains two vocabularies:
        ###            `vocab.src` for source
        ###            `vocab.tgt` for target
        ###     2. You can get the length of a specific vocabulary by running:
        ###            `len(vocab.<specific_vocabulary>)`
        ###     3. Remember to include the padding token for the specific vocabulary
        ###        when creating your Embedding.
        ###
        ### Use the following docs to properly initialize these variables:
        ###     Embedding Layer:
        ###         https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding
        # One embedding table per language; padding_idx keeps the '<pad>' row
        # at zero and excludes it from gradient updates.
        self.source = nn.Embedding(len(vocab.src), embed_size, padding_idx=src_pad_token_idx)
        self.target = nn.Embedding(len(vocab.tgt), embed_size, padding_idx=tgt_pad_token_idx)

        ### END YOUR CODE
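# ----------------------------------------------------------------------
# Illustrative sketch, not part of the original assignment code.
# Note 3 above asks for the padding token to be passed to each Embedding.
# The helper below is a minimal demonstration of what `padding_idx` does.
# Assumptions: `demo_padding_idx` and `toy_vocab` are hypothetical names
# introduced only for this example; a plain dict stands in for the real
# Vocab class from vocab.py, and the token ids are made up.
# ----------------------------------------------------------------------
import torch
import torch.nn as nn


def demo_padding_idx(embed_size=4):
    """Embed one padded sentence and return (embeddings, weight gradient)."""
    # Hypothetical toy vocabulary mapping tokens to ids.
    toy_vocab = {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, 'hello': 4}
    emb = nn.Embedding(len(toy_vocab), embed_size,
                       padding_idx=toy_vocab['<pad>'])

    # A batch with one sentence: "hello </s>" followed by two pad positions.
    ids = torch.tensor([[4, 2, 0, 0]])
    out = emb(ids)            # shape: (batch=1, sentence_len=4, embed_size)

    # Back-propagate a dummy loss to inspect which rows receive gradient.
    out.sum().backward()
    return out, emb.weight.grad

# Usage, under the assumptions above:
#   out, grad = demo_padding_idx()
#   out[0, 2] and out[0, 3] (the pad positions) are all-zero vectors, and
#   grad[0] (the '<pad>' row) is all zeros, so padding never gets trained.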