Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
LukeTonin
GitHub Repository: LukeTonin/keras-seq-2-seq-signal-prediction
Path: blob/master/utils.py
58 views
1
# This file contains modified code licensed under the MIT License:
2
# Copyright (c) 2017 Guillaume Chevalier # For more information, visit:
3
# https://github.com/guillaume-chevalier/seq2seq-signal-prediction
4
# https://github.com/guillaume-chevalier/seq2seq-signal-prediction/blob/master/LICENSE
5
6
"""Contains functions to generate artificial data for predictions as well as
7
a function to plot predictions."""
8
9
import numpy as np
10
from matplotlib import pyplot as plt
11
12
def random_sine(batch_size, steps_per_epoch,
                input_sequence_length, target_sequence_length,
                min_frequency=0.1, max_frequency=10,
                min_amplitude=0.1, max_amplitude=1,
                min_offset=-0.5, max_offset=0.5,
                num_signals=3, seed=43):
    """Generate batches of signals made of summed random sine waves.

    Each signal is the sum of ``num_signals`` sine waves, each with a random
    amplitude, frequency, phase and constant offset. Every signal is split into
    an input part (fed to the encoder) and a target part (to be predicted).

    The generator is infinite. It yields ``steps_per_epoch`` batches per
    "epoch" and then restarts from the same seed, so every epoch replays the
    same sequence of batches. A private ``RandomState`` is used, so numpy's
    global random state is not modified.

    Arguments
    ---------
    batch_size: Number of signals to produce per batch.
    steps_per_epoch: Number of batches of size batch_size produced before the
        sequence of batches repeats.
    input_sequence_length: Length of the input signals to produce.
    target_sequence_length: Length of the target signals to produce.
    min_frequency: Minimum frequency of the base signals that are summed.
    max_frequency: Maximum frequency of the base signals that are summed.
    min_amplitude: Minimum amplitude of the base signals that are summed.
    max_amplitude: Maximum amplitude of the base signals that are summed.
    min_offset: Minimum offset of the base signals that are summed.
    max_offset: Maximum offset of the base signals that are summed.
    num_signals: Number of signals that are summed together.
    seed: The seed used for generating random numbers.

    Yields
    ------
    ([encoder_input, decoder_input], decoder_output) where
        encoder_input: array of shape (batch_size, input_sequence_length, 1)
        decoder_input: zeros of shape (batch_size, target_sequence_length, 1)
        decoder_output: array of shape (batch_size, target_sequence_length, 1)
    """
    num_points = input_sequence_length + target_sequence_length
    # Sample positions: 30 samples per 2*pi, i.e. per period of a
    # frequency-1 sine.
    x = np.arange(num_points) * 2*np.pi/30

    def _uniform(rng, low, high):
        # One uniform draw in [low, high) per signal, as a column so it
        # broadcasts against x.
        return rng.rand(batch_size, 1) * (high - low) + low

    while True:
        # Re-create the RNG from the same seed so each epoch yields the same
        # sequences, without resetting numpy's global random state.
        rng = np.random.RandomState(seed)

        for _ in range(steps_per_epoch):
            signals = np.zeros((batch_size, num_points))
            for _ in range(num_signals):
                # Draw order (amplitude, frequency, offset, phase) is fixed so
                # the random stream is consumed identically every batch.
                amplitude = _uniform(rng, min_amplitude, max_amplitude)
                frequency = _uniform(rng, min_frequency, max_frequency)
                offset = _uniform(rng, min_offset, max_offset)
                phase = rng.rand(batch_size, 1) * 2 * np.pi

                signals += amplitude * np.sin(frequency * x + phase) + offset
            # Add a trailing feature axis: (batch, time) -> (batch, time, 1).
            signals = np.expand_dims(signals, axis=2)

            encoder_input = signals[:, :input_sequence_length, :]
            decoder_output = signals[:, input_sequence_length:, :]

            # The output of the generator must be
            # ([encoder_input, decoder_input], [decoder_output]); the decoder
            # input is an all-zero placeholder of the target's shape.
            decoder_input = np.zeros_like(decoder_output)
            yield ([encoder_input, decoder_input], decoder_output)
74
75
def plot_prediction(x, y_true, y_pred):
    """Plot an input sequence followed by its true and predicted continuation.

    The seen values are drawn at time steps ``0 .. len(x) - 1``. The true and
    predicted values are drawn at the time steps immediately following them.
    Each signal dimension is drawn as its own line; only the first dimension
    is given legend labels.

    Arguments
    ---------
    x: Input sequence of shape (input_sequence_length,
        dimension_of_signal), or a 1D array of length input_sequence_length.
    y_true: True output sequence of shape (target_sequence_length,
        dimension_of_signal), or a 1D array.
    y_pred: Predicted output sequence of shape (target_sequence_length,
        dimension_of_signal), or a 1D array.
    """
    def _as_2d(values):
        # Treat a 1D sequence as a single-dimension signal so it can be
        # indexed as [:, j]; 2D inputs are returned unchanged.
        values = np.asarray(values)
        return values[:, np.newaxis] if values.ndim == 1 else values

    x, y_true, y_pred = _as_2d(x), _as_2d(y_true), _as_2d(y_pred)

    plt.figure(figsize=(12, 3))

    output_dim = x.shape[-1]
    for j in range(output_dim):
        past = x[:, j]
        true = y_true[:, j]
        pred = y_pred[:, j]

        # Label only the first dimension so the legend has one entry per kind.
        label1 = "Seen (past) values" if j == 0 else "_nolegend_"
        label2 = "True future values" if j == 0 else "_nolegend_"
        label3 = "Predictions" if j == 0 else "_nolegend_"

        plt.plot(range(len(past)), past, "o--b",
                 label=label1)
        # Future values start right after the last seen time step.
        plt.plot(range(len(past), len(true) + len(past)), true, "x--b",
                 label=label2)
        plt.plot(range(len(past), len(pred) + len(past)), pred, "o--y",
                 label=label3)
    plt.legend(loc='best')
    plt.title("Predictions v.s. true values")
    plt.show()
109
110
if __name__ == '__main__':

    # This is an example of the plot function and the signal generator.
    # The "prediction" passed to plot_prediction is simply the ground truth.
    gen = random_sine(3, 3, 15, 15)
    # Each item is ([encoder_input, decoder_input], decoder_output); unpack the
    # nested list so the arrays (not the list) are indexed below.
    for i, ((encoder_input, _decoder_input), decoder_output) in enumerate(gen):
        for j in range(encoder_input.shape[0]):
            plot_prediction(encoder_input[j, :, :],
                            decoder_output[j, :, :],
                            decoder_output[j, :, :])
        # random_sine never terminates, so stop after a few batches.
        if i > 2:
            break
123
124