GitHub Repository: GRAAL-Research/deepparse
Path: blob/main/examples/retrain_with_new_seq2seq_params.py
# pylint: skip-file
###################
"""
IMPORTANT:
THE EXAMPLE IN THIS FILE IS CURRENTLY NOT FUNCTIONAL
BECAUSE THE `download_from_public_repository` FUNCTION
NO LONGER EXISTS. WE HAD TO MAKE A QUICK RELEASE TO
REMEDIATE AN ISSUE IN OUR PREVIOUS STORAGE SOLUTION.
THIS WILL BE FIXED IN A FUTURE RELEASE.

IN THE MEANTIME, IF YOU NEED ANY CLARIFICATION
REGARDING THE PACKAGE, PLEASE FEEL FREE TO OPEN AN ISSUE.
"""
import os

import poutyne

from deepparse import download_from_public_repository
from deepparse.dataset_container import PickleDatasetContainer
from deepparse.parser import AddressParser

# First, let's download the train and test data annotated with the new prediction tags from the public repository.
saving_dir = "./data"
file_extension = "p"
training_dataset_name = "sample_incomplete_data_new_prediction_tags"
test_dataset_name = "test_sample_data_new_prediction_tags"
download_from_public_repository(training_dataset_name, saving_dir, file_extension=file_extension)
download_from_public_repository(test_dataset_name, saving_dir, file_extension=file_extension)
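
# Added note: since `download_from_public_repository` is currently unavailable, here is a minimal,
# commented-out sketch of how one could build an equivalent pickled dataset by hand. It assumes the
# usual PickleDatasetContainer format, namely a pickled list of (address, list_of_tags) tuples where
# the tag list has one tag per whitespace-separated address token. The two addresses below are made-up
# illustrations, not the actual data from the public repository.
#
# import pickle
#
# handcrafted_data = [
#     ("350 rue des Lilas Ouest Quebec", ["ATag", "ATag", "ATag", "ATag", "AnotherTag", "AnotherTag"]),
#     ("777 brebeuf street Montreal", ["ATag", "AnotherTag", "AnotherTag", "AnotherTag"]),
# ]
#
# os.makedirs(saving_dir, exist_ok=True)
# with open(os.path.join(saving_dir, training_dataset_name + "." + file_extension), "wb") as file:
#     pickle.dump(handcrafted_data, file)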

# Now, let's create the training and test containers.
training_container = PickleDatasetContainer(os.path.join(saving_dir, training_dataset_name + "." + file_extension))
test_container = PickleDatasetContainer(os.path.join(saving_dir, test_dataset_name + "." + file_extension))

# We will retrain the BPEmb version of our pretrained model.
model = "bpemb"
# device=0 uses the first GPU; another device can be passed instead (e.g. device="cpu").
address_parser = AddressParser(model_type=model, device=0)

# Now, let's retrain for 5 epochs using a batch size of 8 since the data is really small for the example.
# Let's start with the default learning rate of 0.01 and use a learning rate scheduler to lower the learning rate
# as we progress.
lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1)  # reduce the learning rate by a factor of 10 each epoch

# We need an EOS (End Of Sequence) tag in the prediction tags dictionary, which maps each tag to an index.
tag_dictionary = {"ATag": 0, "AnotherTag": 1, "EOS": 2}

# The path to save our checkpoints
logging_path = "./checkpoints"

# The new seq2seq parameters, using a smaller hidden size for both the encoder and the decoder.
# See the documentation for the list of tunable seq2seq parameters.
seq2seq_params = {"encoder_hidden_size": 512, "decoder_hidden_size": 512}

address_parser.retrain(
    training_container,
    train_ratio=0.8,  # 80% of the data is used for training and 20% for validation
    epochs=5,
    batch_size=8,
    num_workers=2,
    callbacks=[lr_scheduler],
    prediction_tags=tag_dictionary,
    logging_path=logging_path,
    seq2seq_params=seq2seq_params,
)

# Now, let's test our fine-tuned model using the best checkpoint (the default parameter).
address_parser.test(test_container, batch_size=256)
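
# Added note: a minimal, commented-out sketch of how the retrained weights could be reused afterwards.
# It assumes the retrain call above exported a checkpoint into `logging_path`; the exact file name used
# below ("retrained_bpemb_address_parser.ckpt") is an assumption to verify against your deepparse version,
# and the sample address is only an illustration.
#
# retrained_address_parser = AddressParser(
#     model_type=model,
#     device=0,
#     path_to_retrained_model=os.path.join(logging_path, "retrained_bpemb_address_parser.ckpt"),
# )
# parsed_address = retrained_address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")
# print(parsed_address)  # the parsed components now use the new tags (ATag, AnotherTag)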