Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
GRAAL-Research
GitHub Repository: GRAAL-Research/deepparse
Path: blob/main/examples/fine_tuning.py
1231 views
1
# pylint: skip-file
2
###################
3
"""
4
IMPORTANT:
5
THE EXAMPLE IN THIS FILE IS CURRENTLY NOT FUNCTIONAL
6
BECAUSE THE `download_from_public_repository` FUNCTION
7
NO LONGER EXISTS. WE HAD TO MAKE A QUICK RELEASE TO
8
REMEDIATE AN ISSUE IN OUR PREVIOUS STORAGE SOLUTION.
9
THIS WILL BE FIXED IN A FUTURE RELEASE.
10
11
IN THE MEANTIME, IF YOU NEED ANY CLARIFICATION
12
REGARDING THE PACKAGE PLEASE FEEL FREE TO OPEN AN ISSUE.
13
"""
14
import os
15
import poutyne
16
17
from deepparse import download_from_public_repository
18
from deepparse.dataset_container import PickleDatasetContainer
19
from deepparse.parser import AddressParser
20
21
def _retrain_and_test(address_parser, training_container, test_container, **retrain_kwargs):
    """Fine-tune ``address_parser`` on ``training_container``, then evaluate it on ``test_container``.

    Each run uses a train ratio of 0.8, 5 epochs and a batch size of 8 (the data is really small for the
    example), starting from the default learning rate of 0.01. Extra keyword arguments (e.g. ``logging_path``)
    are forwarded unchanged to ``AddressParser.retrain``.
    """
    # Lower the learning rate as we progress: reduce it by a factor of 10 each epoch. A new callback is built
    # for every run so that separate trainings never share one scheduler callback object.
    lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1)

    address_parser.retrain(
        training_container,
        train_ratio=0.8,
        epochs=5,
        batch_size=8,
        num_workers=2,
        callbacks=[lr_scheduler],
        **retrain_kwargs,
    )

    # Now, let's test our fine-tuned model using the best checkpoint (default parameter).
    address_parser.test(test_container, batch_size=256)


def main():
    """Download the example data, then fine-tune and test the FastText parser without and with attention."""
    # First, let's download the train and test data from the public repository.
    saving_dir = "./data"
    file_extension = "p"
    training_dataset_name = "sample_incomplete_data"
    test_dataset_name = "test_sample_data"
    download_from_public_repository(training_dataset_name, saving_dir, file_extension=file_extension)
    download_from_public_repository(test_dataset_name, saving_dir, file_extension=file_extension)

    # Now let's create a training and test container.
    training_container = PickleDatasetContainer(os.path.join(saving_dir, training_dataset_name + "." + file_extension))
    test_container = PickleDatasetContainer(os.path.join(saving_dir, test_dataset_name + "." + file_extension))

    # We will retrain the FastText version of our pretrained model. The checkpoints (ckpt) are saved in the
    # default "./checkpoints" directory, so if you wish to retrain another model (let's say BPEmb), you need to
    # change the `logging_path` directory; otherwise, you will get an error when retraining since Poutyne will
    # try to use the last checkpoint.
    address_parser = AddressParser(model_type="fasttext", device=0)
    _retrain_and_test(address_parser, training_container, test_container)

    # Now let's retrain the FastText version but with an attention mechanism. Since the previous checkpoints were
    # saved in the default "./checkpoints" directory and our model has changed, we use a new logging directory.
    address_parser = AddressParser(model_type="fasttext", device=0, attention_mechanism=True)
    _retrain_and_test(address_parser, training_container, test_container, logging_path="checkpoints_attention")


if __name__ == "__main__":
    # retrain() is called with num_workers=2, i.e. data loading in worker processes. With the "spawn" start
    # method (the default on Windows and macOS), each worker re-imports this module, so the example must only
    # run under this guard; otherwise every worker would re-execute the whole script.
    main()
73
74