Path: examples/retrain_with_new_seq2seq_params.py
# pylint: skip-file
###################
"""
IMPORTANT:
THE EXAMPLE IN THIS FILE IS CURRENTLY NOT FUNCTIONAL
BECAUSE THE `download_from_public_repository` FUNCTION
NO LONGER EXISTS. WE HAD TO MAKE A QUICK RELEASE TO
REMEDIATE AN ISSUE IN OUR PREVIOUS STORAGE SOLUTION.
THIS WILL BE FIXED IN A FUTURE RELEASE.

IN THE MEANTIME, IF YOU NEED ANY CLARIFICATION
REGARDING THE PACKAGE, PLEASE FEEL FREE TO OPEN AN ISSUE.
"""
import os

import poutyne

from deepparse import download_from_public_repository
from deepparse.dataset_container import PickleDatasetContainer
from deepparse.parser import AddressParser

# First, let's download the train and test data with the new prediction tags from the public repository.
saving_dir = "./data"
file_extension = "p"
training_dataset_name = "sample_incomplete_data_new_prediction_tags"
test_dataset_name = "test_sample_data_new_prediction_tags"
download_from_public_repository(training_dataset_name, saving_dir, file_extension=file_extension)
download_from_public_repository(test_dataset_name, saving_dir, file_extension=file_extension)

# Now let's create a training and a test container.
training_container = PickleDatasetContainer(os.path.join(saving_dir, training_dataset_name + "." + file_extension))
test_container = PickleDatasetContainer(os.path.join(saving_dir, test_dataset_name + "." + file_extension))

# We will retrain the BPEmb version of our pretrained model.
model = "bpemb"
address_parser = AddressParser(model_type=model, device=0)

# Now, let's retrain for 5 epochs using a batch size of 8, since the data is really small for this example.
# We start from the default learning rate of 0.01 and use a learning rate scheduler to lower the learning rate
# as training progresses.
lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1)  # reduce the LR by a factor of 10 each epoch

# We need an EOS tag in the dictionary. EOS -> End Of Sequence
tag_dictionary = {"ATag": 0, "AnotherTag": 1, "EOS": 2}

# The path where our checkpoints will be saved
logging_path = "./checkpoints"

# The new seq2seq parameters, here using a smaller hidden size than the default one.
# See the documentation for the complete list of tunable seq2seq parameters.
seq2seq_params = {"encoder_hidden_size": 512, "decoder_hidden_size": 512}

address_parser.retrain(
    training_container,
    train_ratio=0.8,
    epochs=5,
    batch_size=8,
    num_workers=2,
    callbacks=[lr_scheduler],
    prediction_tags=tag_dictionary,
    logging_path=logging_path,
    seq2seq_params=seq2seq_params,
)

# Now, let's test our fine-tuned model using the best checkpoint (the default parameter).
address_parser.test(test_container, batch_size=256)
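
# ---------------------------------------------------------------------------
# The lines below are a minimal usage sketch, not part of the original
# example. They assume the retrain above completed successfully and that
# retrain saved its checkpoint under logging_path using the default
# "retrained_<model_type>_address_parser.ckpt" naming pattern; adjust the
# filename to whatever is actually on disk.
# ---------------------------------------------------------------------------

# Once retrained, the parser can be called directly on an address string and
# will predict with the new tag dictionary ("ATag", "AnotherTag", "EOS").
parsed_address = address_parser("350 rue des Lilas Ouest Quebec Quebec G1L 1B6")
print(parsed_address)

# To reuse the retrained weights in a later session, point a fresh
# AddressParser at the saved checkpoint with path_to_retrained_model.
retrained_address_parser = AddressParser(
    model_type="bpemb",
    device=0,
    path_to_retrained_model=os.path.join(logging_path, "retrained_bpemb_address_parser.ckpt"),
)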