# Path: blob/master/deep_learning/multi_label/fasttext_module/split.py
import random
from typing import Tuple

__all__ = ['train_test_split_file']


def train_test_split_file(input_path: str,
                          output_path_train: str,
                          output_path_test: str,
                          test_size: float = 0.1,
                          random_state: int = 1234,
                          encoding: str = 'utf-8') -> Tuple[int, int]:
    """
    Perform train and test split on a text file without reading the
    whole file into memory.

    Each line is assigned independently: it goes to the test split with
    probability ``test_size`` and to the train split otherwise. The
    realized proportion of test lines therefore only approximates
    ``test_size``; it is not an exact split.

    Parameters
    ----------
    input_path : str
        Path to the original full text file.

    output_path_train : str
        Path of the train split. Overwritten if it exists.

    output_path_test : str
        Path of the test split. Overwritten if it exists.

    test_size : float, 0.0 ~ 1.0, default 0.1
        Probability that any given line is written to the test split.

    random_state : int, default 1234
        Seed for the random split. The same seed reproduces the same split.
        The module-level ``random`` generator is not affected.

    encoding : str, default 'utf-8'
        Encoding for reading and writing the file.

    Returns
    -------
    count_train, count_test : int
        Number of records (lines) written to the training and test set.

    Raises
    ------
    ValueError
        If ``test_size`` is not within [0.0, 1.0].
    """
    # Reject out-of-range (and NaN) sizes before any output file is created.
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(
            'test_size must be between 0.0 and 1.0, got {!r}'.format(test_size))

    # Use a private generator instead of random.seed() so that calling this
    # function does not reset the caller's global random state. A
    # Random(seed) instance yields the same sequence as seeding the
    # module-level generator, so the produced split is unchanged.
    rng = random.Random(random_state)

    # accumulate the number of records in the training and test set
    count_train = 0
    count_test = 0
    train_range = 1 - test_size

    # The input is opened first, so a missing input file raises before
    # either output file is created/truncated.
    with open(input_path, encoding=encoding) as f_in, \
            open(output_path_train, 'w', encoding=encoding) as f_train, \
            open(output_path_test, 'w', encoding=encoding) as f_test:

        # Stream line by line to keep memory usage independent of file size.
        for line in f_in:
            if rng.random() < train_range:
                f_train.write(line)
                count_train += 1
            else:
                f_test.write(line)
                count_test += 1

    return count_train, count_test