Path: blob/main/examples/fine_tuning_with_csv_dataset.py
1232 views
# pylint: skip-file
###################
"""
IMPORTANT:
THE EXAMPLE IN THIS FILE IS CURRENTLY NOT FUNCTIONAL
BECAUSE THE `download_from_public_repository` FUNCTION
NO LONGER EXISTS. WE HAD TO MAKE A QUICK RELEASE TO
REMEDIATE AN ISSUE IN OUR PREVIOUS STORAGE SOLUTION.
THIS WILL BE FIXED IN A FUTURE RELEASE.

IN THE MEANTIME, IF YOU NEED ANY CLARIFICATION
REGARDING THE PACKAGE, PLEASE FEEL FREE TO OPEN AN ISSUE.
"""
import os

import poutyne

from deepparse import download_from_public_repository
from deepparse.dataset_container import CSVDatasetContainer
from deepparse.parser import AddressParser

# Where the CSV datasets are downloaded to, and which datasets the example uses.
saving_dir = "./data"
file_extension = "csv"
training_dataset_name = "sample_incomplete_data"
test_dataset_name = "test_sample_data"


def _dataset_path(dataset_name):
    """Return the local path of the downloaded CSV file for ``dataset_name``."""
    return os.path.join(saving_dir, f"{dataset_name}.{file_extension}")


def main():
    """Download the sample CSV datasets, then fine-tune and test two FastText parsers.

    The first parser is the plain FastText model; the second one adds an attention
    mechanism and writes its checkpoints to a separate directory.
    """
    # First, let's download the train and test data from the public repository but using a CSV format dataset.
    download_from_public_repository(training_dataset_name, saving_dir, file_extension=file_extension)
    download_from_public_repository(test_dataset_name, saving_dir, file_extension=file_extension)

    # Now let's create a training and test container.
    training_container = CSVDatasetContainer(
        _dataset_path(training_dataset_name),
        column_names=["Address", "Tags"],
        separator=",",
    )
    test_container = CSVDatasetContainer(
        _dataset_path(test_dataset_name),
        column_names=["Address", "Tags"],
        separator=",",
    )

    # We will retrain the FastText version of our pretrained model.
    # NOTE(review): `device=0` presumably selects the first GPU -- confirm against the AddressParser docs.
    address_parser = AddressParser(model_type="fasttext", device=0)

    # Now, let's retrain for 5 epochs using a batch size of 8 since the data is really small for the example.
    # Let's start with the default learning rate of 0.01 and use a learning rate scheduler to lower the learning rate
    # as we progress.
    lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1)  # reduce LR by a factor of 10 each epoch

    # The checkpoints (ckpt) are saved in the default "./checkpoints" directory, so if you wish to retrain
    # another model (let's say BPEmb), you need to change the `logging_path` directory; otherwise, you will get
    # an error when retraining since Poutyne will try to use the last checkpoint.
    address_parser.retrain(
        training_container,
        train_ratio=0.8,
        epochs=5,
        batch_size=8,
        num_workers=2,
        callbacks=[lr_scheduler],
    )

    # Now, let's test our fine-tuned model using the best checkpoint (default parameter).
    address_parser.test(test_container, batch_size=256)

    # Now let's retrain the FastText version but with an attention mechanism.
    address_parser = AddressParser(model_type="fasttext", device=0, attention_mechanism=True)

    # Since the previous checkpoints were saved in the default "./checkpoints" directory, we need to use a new one.
    # Otherwise, poutyne will try to reload the previous checkpoints, and our model has changed.
    address_parser.retrain(
        training_container,
        train_ratio=0.8,
        epochs=5,
        batch_size=8,
        num_workers=2,
        callbacks=[lr_scheduler],
        logging_path="checkpoints_attention",
    )

    # Now, let's test our fine-tuned model using the best checkpoint (default parameter).
    address_parser.test(test_container, batch_size=256)


# `num_workers=2` requests data-loading worker processes (presumably forwarded to the
# PyTorch DataLoader -- confirm). On platforms that start processes with `spawn`
# (Windows, macOS), the workers re-import this module, so the download/training code
# must only run when the file is executed as a script.
if __name__ == "__main__":
    main()