# Path: blob/main/examples/retrain_attention_model.py
# (1233 views)
# pylint: skip-file
###################
"""
IMPORTANT:
THE EXAMPLE IN THIS FILE IS CURRENTLY NOT FUNCTIONAL
BECAUSE THE `download_from_public_repository` FUNCTION
NO LONGER EXISTS. WE HAD TO MAKE A QUICK RELEASE TO
REMEDIATE AN ISSUE IN OUR PREVIOUS STORAGE SOLUTION.
THIS WILL BE FIXED IN A FUTURE RELEASE.

IN THE MEANTIME IF YOU NEED ANY CLARIFICATION
REGARDING THE PACKAGE PLEASE FEEL FREE TO OPEN AN ISSUE.
"""
import os

import poutyne

from deepparse import download_from_public_repository
from deepparse.dataset_container import PickleDatasetContainer
from deepparse.parser import AddressParser

# First, let's download the train and test data with the new tags from the public repository.
saving_dir = "./data"
file_extension = "p"
training_dataset_name = "sample_incomplete_data"
test_dataset_name = "test_sample_data"
download_from_public_repository(training_dataset_name, saving_dir, file_extension=file_extension)
download_from_public_repository(test_dataset_name, saving_dir, file_extension=file_extension)

# Now let's create a training and test container.
training_container = PickleDatasetContainer(os.path.join(saving_dir, training_dataset_name + "." + file_extension))
test_container = PickleDatasetContainer(os.path.join(saving_dir, test_dataset_name + "." + file_extension))

# We will retrain the BPEmb attention version of our pretrained model.
# NOTE(review): the original comment said "FastText", but `model` is set to "bpemb";
# the comment has been corrected to match the code.
model = "bpemb"
address_parser = AddressParser(model_type=model, device=0, attention_mechanism=True)

# Now, let's retrain for 5 epochs using a batch size of 8 since the data is really small for the example.
# Let's start with the default learning rate of 0.01 and use a learning rate scheduler to lower the learning rate
# as we progress.
lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1)  # reduce LR by a factor of 10 each epoch

# The path to save our checkpoints
logging_path = "./checkpoints"

address_parser.retrain(
    training_container,
    train_ratio=0.8,
    epochs=5,
    batch_size=8,
    num_workers=2,
    callbacks=[lr_scheduler],
    logging_path=logging_path,
    layers_to_freeze="seq2seq",
)

# Now, let's test our fine-tuned model using the best checkpoint (default parameter).
address_parser.test(test_container, batch_size=256)