Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
yiming-wange
GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a3/utils/general_utils.py
995 views
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
"""
4
CS224N 2021-2022: Homework 3
5
general_utils.py: General purpose utilities.
6
Sahil Chopra <[email protected]>
7
"""
8
9
import numpy as np
10
11
12
def get_minibatches(data, minibatch_size, shuffle=True):
    """
    Iterate through the provided data one minibatch at a time.

    You can use this function to iterate through data in minibatches as follows:

        for inputs_minibatch in get_minibatches(inputs, minibatch_size):
            ...

    Or with multiple data sources:

        for inputs_minibatch, labels_minibatch in get_minibatches([inputs, labels], minibatch_size):
            ...

    Args:
        data: there are two possible values:
            - a list or numpy array
            - a list where each element is either a list or numpy array
              (only the first element is inspected to decide this, and its
              length is used as the number of items for every element)
        minibatch_size: the maximum number of items in a minibatch
        shuffle: whether to randomize the order of returned data
    Yields:
        minibatches: the yielded value depends on data:
            - If data is a list/array it yields the next minibatch of data.
            - If data is a list of lists/arrays it yields a list holding the next
              minibatch of each element. This can be used to iterate through
              multiple data sources (e.g., features and labels) at the same time.
        Empty data yields nothing.
    """
    # An empty list has no first element to inspect; treat it as a single
    # (empty) data source instead of failing on data[0].
    list_data = (isinstance(data, list) and len(data) > 0
                 and isinstance(data[0], (list, np.ndarray)))
    data_size = len(data[0]) if list_data else len(data)
    indices = np.arange(data_size)
    if shuffle:
        np.random.shuffle(indices)
    for minibatch_start in np.arange(0, data_size, minibatch_size):
        # The final slice is simply shorter when data_size is not a multiple
        # of minibatch_size.
        minibatch_indices = indices[minibatch_start:minibatch_start + minibatch_size]
        if list_data:
            # Apply the same indices to every source so they stay aligned.
            yield [_minibatch(d, minibatch_indices) for d in data]
        else:
            yield _minibatch(data, minibatch_indices)
48
49
50
def _minibatch(data, minibatch_idx):
51
return data[minibatch_idx] if type(data) is np.ndarray else [data[i] for i in minibatch_idx]
52
53
54
def test_all_close(name, actual, expected):
    """
    Check that `actual` matches `expected` element-wise and report the result.

    Args:
        name: label used in the printed/raised message
        actual: numpy array produced by the code under test
        expected: numpy array holding the reference values
    Raises:
        ValueError: if the shapes differ, or if any element differs from the
            expected value by more than 1e-6 (a NaN never counts as close).
    """
    if actual.shape != expected.shape:
        raise ValueError("{:} failed, expected output to have shape {:} but has shape {:}"
                         .format(name, expected.shape, actual.shape))
    # rtol=0 keeps this a purely absolute 1e-6 tolerance. Unlike taking the max
    # of the absolute difference, allclose does not let NaN slip through (any
    # comparison with NaN is False) and does not fail on empty arrays.
    if not np.allclose(actual, expected, rtol=0.0, atol=1e-6):
        raise ValueError("{:} failed, expected {:} but value is {:}".format(name, expected, actual))
    print(name, "passed!")


# The "test_" prefix is part of the public name; mark the helper so pytest does
# not try to collect it as a test when it is imported into a test module.
test_all_close.__test__ = False
62
63