Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aamini
GitHub Repository: aamini/introtodeeplearning
Path: blob/master/mitdeeplearning/lab1.py
547 views
1
import os
2
import regex as re
3
import subprocess
4
import urllib
5
import numpy as np
6
import tensorflow as tf
7
8
from IPython.display import Audio
9
10
11
# Directory that contains this module (NOT the process's working directory,
# despite the name); the "data" and "bin" paths used below are resolved
# relative to it so they work regardless of where the caller runs from.
cwd = os.path.split(__file__)[0]
12
13
14
def load_training_data():
    """Load the bundled song dataset and split it into individual songs.

    Reads ``data/irish.abc`` from this module's directory (``cwd``) in full and
    hands the text to ``extract_song_snippet``.

    Returns:
        list of str: the songs found by ``extract_song_snippet``.
    """
    dataset_path = os.path.join(cwd, "data", "irish.abc")
    with open(dataset_path, "r") as abc_handle:
        raw_text = abc_handle.read()
    return extract_song_snippet(raw_text)
19
20
21
def extract_song_snippet(text):
    """Split text into songs that are separated by blank lines.

    A song is the text between either the start of the string or a blank line
    and the next blank line (two consecutive newlines). A trailing song that is
    not followed by a blank line is therefore not returned.

    ``re`` here is the third-party ``regex`` package (imported under that
    alias at the top of the module); its ``findall`` accepts
    ``overlapped=True``. That lets the blank line closing one song also act as
    the opening separator of the next song. ``DOTALL`` lets a song span
    multiple lines.

    Args:
        text: Raw text, e.g. the training file or model-generated output.

    Returns:
        list of str: The song bodies, without the surrounding separators.
        The number found is also printed.
    """
    song_pattern = "(^|\n\n)(.*?)\n\n"
    matches = re.findall(song_pattern, text, overlapped=True, flags=re.DOTALL)
    # Each match is a (separator, body) tuple of the two capture groups.
    songs = [body for _separator, body in matches]
    print("Found {} songs in text".format(len(songs)))
    return songs
27
28
29
def save_song_to_abc(song, filename="tmp"):
    """Write a song's text to ``<filename>.abc``.

    The path is relative, so the file lands in the process's current working
    directory unless ``filename`` includes a directory. Any existing file is
    overwritten.

    Args:
        song: The song text to write.
        filename: Output path without the ``.abc`` extension.

    Returns:
        The ``filename`` argument unchanged (without the extension).
    """
    abc_path = "{name}.abc".format(name=filename)
    with open(abc_path, "w") as out_file:
        out_file.write(song)
    return filename
34
35
36
def abc2wav(abc_file):
    """Convert an ``.abc`` file to audio with the bundled ``bin/abc2wav`` tool.

    Runs ``<module dir>/bin/abc2wav <abc_file>`` through the system shell via
    ``os.system``. The tool presumably writes ``<name>.wav`` next to the input
    (``play_song`` assumes so); confirm against the tool itself.

    Args:
        abc_file: Path of the ``.abc`` file to convert. It is converted with
            ``str()``, the same text the previous ``"{}".format`` produced.

    Returns:
        The value returned by ``os.system``: 0 if the command succeeded,
        non-zero otherwise.
    """
    import shlex

    path_to_tool = os.path.join(cwd, "bin", "abc2wav")
    # Quote both arguments (POSIX shell rules). A module directory or file name
    # containing spaces or shell metacharacters is then passed as one literal
    # argument, instead of being split or interpreted by the shell.
    cmd = "{} {}".format(shlex.quote(path_to_tool), shlex.quote(str(abc_file)))
    return os.system(cmd)
40
41
42
def play_wav(wav_file):
    """Wrap a wav file in an ``IPython.display.Audio`` object.

    Args:
        wav_file: Path to the wav file; passed positionally to ``Audio``.

    Returns:
        IPython.display.Audio: A player object the caller can display.
    """
    player = Audio(wav_file)
    return player
44
45
46
def play_song(song):
    """Render one song to audio and return a playable object.

    The song is written to ``tmp.abc`` in the current working directory (via
    ``save_song_to_abc`` with its default name) and converted by ``abc2wav``.
    On success the resulting ``tmp.wav`` is wrapped by ``play_wav``.

    Args:
        song: The text of a single song.

    Returns:
        The object returned by ``play_wav`` (an ``IPython.display.Audio``) if
        the conversion command returned 0, otherwise ``None``.
    """
    basename = save_song_to_abc(song)
    ret = abc2wav(basename + ".abc")
    if ret == 0:  # conversion succeeded; any non-zero status means failure
        return play_wav(basename + ".wav")
    return None
52
53
54
def play_generated_song(generated_text):
    """Extract songs from model output and return audio for the first playable one.

    Songs found by ``extract_song_snippet`` are tried in order with
    ``play_song``. The first one that converts successfully is returned, so the
    caller (e.g. a notebook cell) can display it. Previously every result was
    discarded and the "none valid" message was printed unconditionally.

    Args:
        generated_text: Raw text sampled from the model.

    Returns:
        The audio object returned by ``play_song`` for the first song that
        converts successfully, or ``None`` if no complete song was found or
        none of them could be converted. A hint is printed in both failure
        cases.
    """
    songs = extract_song_snippet(generated_text)
    if not songs:
        print(
            "No valid songs found in generated text. Try training the "
            "model longer or increasing the amount of generated music to "
            "ensure complete songs are generated!"
        )
        return None

    for song in songs:
        audio = play_song(song)
        # play_song returns None when conversion fails (presumably because the
        # generated notation is malformed); keep trying the remaining songs.
        if audio is not None:
            return audio

    print(
        "None of the songs were valid, try training longer to improve "
        "syntax."
    )
    return None
69
70
71
def test_batch_func_types(func, args):
72
ret = func(*args)
73
assert len(ret) == 2, "[FAIL] get_batch must return two arguments (input and label)"
74
assert type(ret[0]) == np.ndarray, "[FAIL] test_batch_func_types: x is not np.array"
75
assert type(ret[1]) == np.ndarray, "[FAIL] test_batch_func_types: y is not np.array"
76
print("[PASS] test_batch_func_types")
77
return True
78
79
80
def test_batch_func_shapes(func, args):
81
dataset, seq_length, batch_size = args
82
x, y = func(*args)
83
correct = (batch_size, seq_length)
84
assert (
85
x.shape == correct
86
), "[FAIL] test_batch_func_shapes: x {} is not correct shape {}".format(
87
x.shape, correct
88
)
89
assert (
90
y.shape == correct
91
), "[FAIL] test_batch_func_shapes: y {} is not correct shape {}".format(
92
y.shape, correct
93
)
94
print("[PASS] test_batch_func_shapes")
95
return True
96
97
98
def test_batch_func_next_step(func, args):
99
x, y = func(*args)
100
assert (
101
x[:, 1:] == y[:, :-1]
102
).all(), "[FAIL] test_batch_func_next_step: x_{t} must equal y_{t-1} for all t"
103
print("[PASS] test_batch_func_next_step")
104
return True
105
106
107
def test_custom_dense_layer_output(y):
    """Compare a custom dense layer's output with hard-coded reference values.

    Args:
        y: The layer output. It must support ``tf.shape(y)`` and ``y.numpy()``
            (e.g. an eager TensorFlow tensor).

    Returns:
        True when shape and values match. Otherwise an AssertionError with a
        "[FAIL] ..." message is raised. Values are compared with
        ``np.testing.assert_almost_equal`` at ``decimal=7``.
    """
    # Ground-truth output hard-coded for the lab's test case.
    expected = np.array([[0.27064407, 0.1826951, 0.50374055]], dtype="float32")

    observed_dims = tf.shape(y).numpy().tolist()
    assert observed_dims == list(expected.shape), (
        "[FAIL] output is of incorrect shape. expected {} but got {}".format(
            expected.shape, y.numpy().shape
        )
    )

    observed = y.numpy()
    value_msg = "[FAIL] output is of incorrect value. expected {} but got {}".format(
        expected, observed
    )
    np.testing.assert_almost_equal(
        observed, expected, decimal=7, err_msg=value_msg, verbose=True
    )

    print("[PASS] test_custom_dense_layer_output")
    return True
126
127