# Source: sagemath/sagelib GitHub repository (as rendered by CoCalc).
# Path: blob/master/sage/media/wav.py
r"""
Work with WAV files.

A WAV file is a header specifying format information, followed by a
sequence of bytes representing the state of some audio signal over a
length of time.

A WAV file may have any number of channels. Typically it has 1
(mono) or 2 (stereo). The data of a WAV file is given as a sequence
of frames. A frame consists of samples: there is one sample per
channel, per frame. Every WAV file has a sample width, i.e., the
number of bytes per sample. Typically this is either 1 or 2 bytes.

This module supplies more convenient access to this data. In
particular, see the docstring for \code{Wave.channel_data()}.

The header contains information necessary for playing the WAV file,
including the number of frames per second, the number of bytes per
sample, and the number of channels in the file.

AUTHORS:
    -- Bobby Moretti and Gonzalo Tornaria (2007-07-01): First version
    -- William Stein (2007-07-03): add more
    -- Bobby Moretti (2007-07-03): add doctests
"""
26
27
import math
28
import os
29
import wave
30
31
from sage.plot.plot import list_plot
32
from sage.structure.sage_object import SageObject
33
from sage.misc.all import srange
34
from sage.misc.html import html
35
from sage.rings.all import RDF
36
37
class Wave(SageObject):
    """
    A class wrapping a wave audio file.

    INPUT:
        You must call Wave() with either data = filename, where
        filename is the name of a wave file, or with each of the
        following keyword options:

        name         -- a name for this wave object (shown in its
                        representation and by listen())
        nchannels    -- the number of channels in the wave file (1 for
                        mono, 2 for stereo, etc.)
        width        -- the number of bytes per sample
        framerate    -- the number of frames per second
        nframes      -- the number of frames in the data stream
        bytes        -- a string object containing the bytes of the
                        data stream
        channel_data -- a list with one entry per channel, each entry
                        being a list of nframes sample values

    Slicing:
        Slicing a Wave object returns a new Wave object that has been
        trimmed to the frames in the given range (slice steps are
        ignored).

    Indexing:
        Getting the $n$th item in a Wave object gives you the raw bytes
        of the $n$th frame (one sample per channel).
    """
    def __init__(self, data=None, **kwds):
        """
        Create a Wave object from a file name or from explicit wave data.

        See the class docstring for the accepted keyword arguments.

        Raises:
            KeyError   -- keyword arguments were given but a required
                          one is missing.
            ValueError -- neither a file name nor keyword arguments
                          were given.
        """
        if data is not None:
            # Local import (as before), written in absolute form so it
            # does not depend on Python 2's implicit relative imports.
            from sage.media.channels import _separate_channels
            self._filename = data
            self._name = os.path.split(data)[1]
            wv = wave.open(data, "rb")
            try:
                self._nchannels = wv.getnchannels()
                self._width = wv.getsampwidth()
                self._framerate = wv.getframerate()
                self._nframes = wv.getnframes()
                self._bytes = wv.readframes(self._nframes)
            finally:
                # Release the file handle even if reading fails.
                wv.close()
            self._channel_data = _separate_channels(self._bytes,
                                                    self._width,
                                                    self._nchannels)
        elif kwds:
            try:
                self._name = kwds['name']
                self._nchannels = kwds['nchannels']
                self._width = kwds['width']
                self._framerate = kwds['framerate']
                self._nframes = kwds['nframes']
                self._bytes = kwds['bytes']
                self._channel_data = kwds['channel_data']
            except KeyError as msg:
                # Format the missing key into the message; a KeyError
                # instance cannot be concatenated with a str.
                raise KeyError("missing %s: invalid input to Wave initializer" % msg)
        else:
            raise ValueError("Must give a filename")

    def save(self, filename='sage.wav'):
        r"""
        Save this wave file to disk, either as a Sage sobj or as a .wav file.

        INPUT:
            filename -- the path of the file to save. If filename ends
                        with '.wav', then save as a wave file,
                        otherwise, save a Sage object.

        If no input is given, save the file as 'sage.wav'.
        """
        if not filename.endswith('.wav'):
            SageObject.save(self, filename)
            return
        wv = wave.open(filename, 'wb')
        try:
            wv.setnchannels(self._nchannels)
            wv.setsampwidth(self._width)
            wv.setframerate(self._framerate)
            wv.setnframes(self._nframes)
            wv.writeframes(self._bytes)
        finally:
            # Always finalize/close the output file, even on error.
            wv.close()

    def listen(self):
        """
        Listen to (or download) this wave file.

        Saves this wave to the first unused file name of the form
        'sage<i>.wav' in the current directory and returns an HTML link
        to that file for the notebook.
        """
        i = 0
        fname = 'sage%s.wav' % i
        while os.path.exists(fname):
            i += 1
            fname = 'sage%s.wav' % i

        self.save(fname)
        return html('<a href="cell://%s">Click to listen to %s</a>' % (fname, self._name))

    def channel_data(self, n):
        """
        Get the data from a given channel.

        INPUT:
            n -- the channel number to get

        OUTPUT:
            A list of signed ints, each containing the value of a frame.
            This is the internal list itself, not a copy.
        """
        return self._channel_data[n]

    def getnchannels(self):
        """
        Returns the number of channels in this wave object.

        OUTPUT:
            The number of channels in this wave file.
        """
        return self._nchannels

    def getsampwidth(self):
        """
        Returns the number of bytes per sample in this wave object.

        OUTPUT:
            The number of bytes in each sample.
        """
        return self._width

    def getframerate(self):
        """
        Returns the number of frames per second in this wave object.

        OUTPUT:
            The frame rate of this sound file.
        """
        return self._framerate

    def getnframes(self):
        """
        The total number of frames in this wave object.

        OUTPUT:
            The number of frames in this WAV.
        """
        return self._nframes

    def readframes(self, n):
        """
        Reads out the raw data for the first $n$ frames of this wave
        object.

        INPUT:
            n -- the number of frames to return

        OUTPUT:
            A list of bytes (in string form) representing the raw wav data.
        """
        # A frame holds one sample of `width` bytes for every channel.
        frame_size = self._width * self._nchannels
        return self._bytes[:n * frame_size]

    def getlength(self):
        """
        Returns the length of this file (in seconds).

        OUTPUT:
            The running time of the entire WAV object, i.e. the number
            of frames divided by the number of frames per second.
        """
        # Each frame already contains every channel, so the channel
        # count does not enter into the duration.
        return float(self._nframes) / float(self._framerate)

    def _repr_(self):
        nc = self.getnchannels()
        return "Wave file %s with %s channel%s of length %s seconds%s" % \
            (self._name, nc, "" if nc == 1 else "s", self.getlength(), "" if nc == 1 else " each")

    def _normalize_npoints(self, npoints):
        """
        Used internally while plotting to normalize the number of
        points to plot: a falsy npoints (None or 0) means "use every
        frame".
        """
        return npoints if npoints else self._nframes

    def domain(self, npoints=None):
        """
        Used internally for plotting. Get the x-values (in seconds) for
        the various points to plot.

        The npoints values are evenly spaced over the duration of the
        wave, starting at 0.
        """
        npoints = self._normalize_npoints(npoints)
        if npoints <= 0:
            return []
        # duration / npoints, with duration = nframes / framerate seconds.
        step = float(self._nframes) / (float(npoints) * float(self._framerate))
        return [k * step for k in range(npoints)]

    def values(self, npoints=None, channel=0):
        """
        Used internally for plotting. Get the y-values for the various
        points to plot.

        Takes every (nframes // npoints)-th sample of the given channel
        and divides it by 2^(8*width - 1).
        """
        npoints = self._normalize_npoints(npoints)
        if npoints <= 0:
            return []

        # how many frames to skip between consecutive plotted samples
        frame_skip = self._nframes // npoints
        samples = self.channel_data(channel)

        # scale the values by the magnitude of the signed sample range
        scale = float(1 << (8 * self._width - 1))
        return [samples[frame_skip * k] / scale for k in range(npoints)]

    def set_values(self, values, channel=0):
        """
        Used internally. Set the sample values of a channel from scaled
        y-values (the inverse of values() with all points).

        INPUT:
            values  -- a sequence of numbers, one per frame; complex
                       entries contribute only their real part
            channel -- the channel to overwrite (default: 0)

        Raises ValueError if len(values) differs from the number of
        samples in the channel.

        NOTE(review): only the channel data list is updated; self._bytes
        is left unchanged, so save()/listen() still write the original
        audio. TODO: re-encode the bytes.
        """
        samples = self.channel_data(channel)
        npoints = len(samples)
        if len(values) != npoints:
            raise ValueError("values (of length %s) must have length %s" % (len(values), npoints))

        # unscale the values; keep the sign (abs() would fold the
        # negative half of the waveform). complex() accepts both real and
        # complex inputs, e.g. the output of an inverse FFT.
        scale = float(1 << (8 * self._width - 1))
        for k, s in enumerate(values):
            samples[k] = complex(s).real * scale

    def vector(self, npoints=None, channel=0):
        """
        Return the plotting y-values of a channel (see values()) as an
        element of RDF^npoints.
        """
        npoints = self._normalize_npoints(npoints)

        V = RDF**npoints
        return V(self.values(npoints=npoints, channel=channel))

    def plot(self, npoints=None, channel=0, plotjoined=True, **kwds):
        """
        Plots the audio data.

        INPUT:
            npoints -- number of sample points to take; if not given, draws
                       all known points.
            channel -- 0 or 1 (if stereo). default: 0
            plotjoined -- whether to just draw dots or draw lines between sample points

        OUTPUT:
            a plot object that can be shown.
        """
        domain = self.domain(npoints=npoints)
        values = self.values(npoints=npoints, channel=channel)
        points = list(zip(domain, values))

        L = list_plot(points, plotjoined=plotjoined, **kwds)
        L.xmin(0)
        if domain:
            L.xmax(domain[-1])
        return L

    def plot_fft(self, npoints=None, channel=0, half=True, **kwds):
        """
        Plot the magnitudes of the FFT of the given channel.

        INPUT:
            npoints -- number of sample points to take (see values())
            channel -- the channel to transform (default: 0)
            half    -- if True, plot only the first half of the
                       transform over [0, pi); otherwise the whole
                       transform over [0, 2*pi)
        """
        v = self.vector(npoints=npoints, channel=channel)
        w = v.fft()
        if half:
            w = w[:len(w) // 2]
        z = [abs(x) for x in w]
        r = math.pi if half else 2 * math.pi
        n = len(z)
        # exactly one x-value per magnitude, evenly spaced over [0, r)
        data = [(r * k / n, y) for k, y in enumerate(z)]
        L = list_plot(data, plotjoined=True, **kwds)
        L.xmin(0)
        L.xmax(r)
        return L

    def plot_raw(self, npoints=None, channel=0, plotjoined=True, **kwds):
        """
        Plot the unscaled sample values of a channel against time in
        seconds.
        """
        npoints = self._normalize_npoints(npoints)
        domain = self.domain(npoints)
        frame_skip = self._nframes // npoints if npoints > 0 else 0
        samples = self.channel_data(channel)
        values = [samples[frame_skip * k] for k in range(max(npoints, 0))]
        points = list(zip(domain, values))

        return list_plot(points, plotjoined=plotjoined, **kwds)

    def __getitem__(self, i):
        """
        Returns the `i`-th frame of data in the wave, in the form of a string,
        if `i` is an integer (negative indices count from the end).
        Returns a slice of self if `i` is a slice (the step is ignored).

        Raises IndexError if the integer index is out of range, so that
        iterating over a Wave terminates.
        """
        if isinstance(i, slice):
            start, stop, _ = i.indices(self._nframes)
            # an empty/reversed range yields an empty wave, not a
            # negative frame count
            return self._copy(start, max(start, stop))
        if i < 0:
            i += self._nframes
        if not 0 <= i < self._nframes:
            raise IndexError("frame index out of range")
        frame_size = self._width * self._nchannels
        offset = i * frame_size
        return self._bytes[offset:offset + frame_size]

    def slice_seconds(self, start, stop):
        """
        Slices the wave from start to stop.

        INPUT:
            start -- the time index from which to begin the slice (in seconds)
            stop -- the time index from which to end the slice (in seconds)

        OUTPUT:
            A Wave object whose data is this object's data,
            sliced between the given time indices
        """
        start = int(start * self.getframerate())
        stop = int(stop * self.getframerate())
        return self[start:stop]

    # start and stop are frame numbers
    def _copy(self, start, stop):
        """
        Return a new Wave holding frames start (inclusive) to stop
        (exclusive) of this one.
        """
        # byte offsets use the full frame size; channel data is per frame
        frame_size = self._width * self._nchannels
        channels_sliced = [self._channel_data[c][start:stop] for c in range(self._nchannels)]

        return Wave(nchannels=self._nchannels,
                    width=self._width,
                    framerate=self._framerate,
                    bytes=self._bytes[start * frame_size:stop * frame_size],
                    nframes=stop - start,
                    channel_data=channels_sliced,
                    name=self._name)

    def __copy__(self):
        return self._copy(0, self._nframes)

    def convolve(self, right, channel=0):
        """
        Convolution of self and right on the given channel, computed as
        the inverse FFT of the pointwise product of their FFTs
        (convolution theorem).

        The result is a circular convolution over self's number of
        frames. It is not renormalized, so it may exceed the sample
        range.

        NOTE(review): set_values() does not re-encode the raw bytes, so
        the returned wave's save()/listen() still use self's audio.
        """
        if not isinstance(right, Wave):
            raise TypeError("right must be a wave")
        npoints = self._nframes
        v = self.vector(npoints, channel=channel).fft()
        w = right.vector(npoints, channel=channel).fft()
        # pointwise product (for Sage vectors `*` is the dot product)
        k = v.parent()([a * b for a, b in zip(v, w)])
        i = k.inv_fft()
        conv = self.__copy__()
        conv.set_values(list(i), channel=channel)
        conv._name = "convolution of %s and %s" % (self._name, right._name)
        return conv
380
381