Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ethen8181
GitHub Repository: ethen8181/machine-learning
Path: blob/master/deep_learning/seq2seq/translation_mt5/translation_utils.py
2593 views
1
import os
2
import tarfile
3
import zipfile
4
import requests
5
import subprocess
6
from tqdm import tqdm
7
from urllib.parse import urlparse
8
9
10
def download_file(url: str, directory: str):
    """
    Download the file at ``url`` to ``directory``.
    The downloaded file is then handed to ``extract_compressed_file``,
    which extracts its content to ``directory`` when it recognizes the
    file as an archive.

    Parameters
    ----------
    url : str
        url of the file.

    directory : str
        Directory to download the file. Created if it does not exist.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status code.
    """
    # The response is used as a context manager so that the streamed
    # connection is always released, even when writing to disk fails
    # part-way through the download.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()

        # Content-Length is optional; fall back to 0 (unknown size) for
        # the progress bar when the server does not send it.
        content_len = response.headers.get('Content-Length')
        total = int(content_len) if content_len is not None else 0

        os.makedirs(directory, exist_ok=True)
        file_name = get_file_name_from_url(url)
        file_path = os.path.join(directory, file_name)

        with tqdm(unit='B', total=total) as pbar, open(file_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024):
                # skip empty keep-alive chunks
                if chunk:
                    pbar.update(len(chunk))
                    f.write(chunk)

    extract_compressed_file(file_path, directory)
41
42
43
def extract_compressed_file(compressed_file_path: str, directory: str):
    """
    Extract a compressed file to ``directory``. Supports zip, tar.gz, tgz,
    tar extensions. Files with any other extension are left untouched.

    Parameters
    ----------
    compressed_file_path : str
        Path of the archive to extract. The archive type is decided from
        the file name's extension (case-insensitive).

    directory : str
        File will be extracted to this directory.
    """
    # Match on the extension rather than on a substring of the name, so that
    # e.g. ``gzip_corpus.tar.gz`` is not mistaken for a zip archive.
    basename = os.path.basename(compressed_file_path).lower()

    if basename.endswith(('.tar.gz', '.tgz', '.tar')):
        # tarfile.open's default mode detects the compression transparently,
        # so plain and gzip-compressed tarballs share the same code path.
        # NOTE: extractall trusts the member paths stored in the archive;
        # only use this on archives from a trusted source.
        with tarfile.open(compressed_file_path) as tar_f:
            tar_f.extractall(directory)
    elif basename.endswith('.zip'):
        with zipfile.ZipFile(compressed_file_path, "r") as zip_f:
            zip_f.extractall(directory)
62
63
64
def get_file_name_from_url(url: str) -> str:
    """
    Return the file name found at the end of a URL's path component.

    Parameters
    ----------
    url : str
        URL to extract the file name from. Only the path part is used, the
        query string and fragment are ignored.

    Returns
    -------
    file_name : str
        Last segment of the URL path, empty when the path is empty or ends
        with a slash.
    """
    url_path = urlparse(url).path
    _, file_name = os.path.split(url_path)
    return file_name
79
80
81
82
def create_translation_data(
    source_input_path: str,
    target_input_path: str,
    output_path: str,
    delimiter: str = "\t",
    encoding: str = "utf-8"
):
    """
    Build a paired dataset file out of two line-aligned files.
    e.g. creates `train.tsv` from `train.de` and `train.en`

    Line ``i`` of the source file is paired with line ``i`` of the target
    file. Both sides are stripped of surrounding whitespace and a pair is
    written as ``source + delimiter + target`` only when neither side is
    empty after stripping.

    Parameters
    ----------
    source_input_path : str
        Path of the file holding the source side, one example per line.

    target_input_path : str
        Path of the file holding the target side, one example per line.

    output_path : str
        Path of the paired file to write. Overwritten if it exists.

    delimiter : str, default "\\t"
        Separator placed between the source and the target text.

    encoding : str, default "utf-8"
        Encoding used to read both inputs and to write the output.
    """
    with open(source_input_path, encoding=encoding) as src_file, \
            open(target_input_path, encoding=encoding) as tgt_file, \
            open(output_path, "w", encoding=encoding) as paired_file:

        # Walk both files in lockstep; once either side runs out there is
        # nothing left that could form a complete pair.
        for src_line, tgt_line in zip(src_file, tgt_file):
            src_text = src_line.strip()
            tgt_text = tgt_line.strip()
            if not src_text or not tgt_text:
                continue

            paired_file.write(src_text + delimiter + tgt_text + "\n")
103
104