Path: blob/main/a3/utils/general_utils.py
995 views
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2021-2022: Homework 3
general_utils.py: General purpose utilities.
Sahil Chopra <[email protected]>
"""

import numpy as np


def get_minibatches(data, minibatch_size, shuffle=True):
    """
    Iterates through the provided data one minibatch at a time. You can use this function to
    iterate through data in minibatches as follows:

        for inputs_minibatch in get_minibatches(inputs, minibatch_size):
            ...

    Or with multiple data sources:

        for inputs_minibatch, labels_minibatch in get_minibatches([inputs, labels], minibatch_size):
            ...

    Args:
        data: there are two possible values:
            - a list or numpy array
            - a list where each element is either a list or numpy array; all elements
              must have the same length
        minibatch_size: the maximum number of items in a minibatch (a positive integer)
        shuffle: whether to randomize the order of returned data (uses the global
            ``np.random`` state)
    Yields:
        - If data is a list/array, the next minibatch of data (a numpy array when data is a
          numpy array, otherwise a list).
        - If data is a list of lists/arrays, a list holding the next minibatch of each element
          in the list. This can be used to iterate through multiple data sources
          (e.g., features and labels) at the same time; the same indices are used for every
          source, so aligned items stay aligned.
        An empty ``data`` yields nothing.
    Raises:
        ValueError: if ``minibatch_size`` is not positive, or if the sources in a multi-source
            ``data`` have different lengths. Because this is a generator, the error surfaces
            on the first iteration.
    """
    if minibatch_size <= 0:
        raise ValueError("minibatch_size must be positive, got {:}".format(minibatch_size))
    # `data` is treated as multiple parallel sources only when it is a non-empty list whose
    # first element is itself a list or array; checking emptiness first avoids an IndexError.
    list_data = (isinstance(data, list) and len(data) > 0
                 and isinstance(data[0], (list, np.ndarray)))
    if list_data:
        sizes = [len(d) for d in data]
        if len(set(sizes)) > 1:
            # Without this check, longer sources would be silently truncated and shorter
            # ones would raise IndexError partway through iteration.
            raise ValueError("all data sources must have the same length, got lengths {:}"
                             .format(sizes))
        data_size = sizes[0]
    else:
        data_size = len(data)
    indices = np.arange(data_size)
    if shuffle:
        np.random.shuffle(indices)
    for minibatch_start in range(0, data_size, minibatch_size):
        minibatch_indices = indices[minibatch_start:minibatch_start + minibatch_size]
        if list_data:
            yield [_minibatch(d, minibatch_indices) for d in data]
        else:
            yield _minibatch(data, minibatch_indices)


def _minibatch(data, minibatch_idx):
    """
    Select the items of ``data`` at the positions in ``minibatch_idx``.

    Numpy arrays are indexed with the index array directly (returning an array); any other
    indexable sequence is gathered element by element into a list.
    """
    if isinstance(data, np.ndarray):
        return data[minibatch_idx]
    return [data[i] for i in minibatch_idx]


def test_all_close(name, actual, expected, atol=1e-6):
    """
    Check that ``actual`` matches ``expected`` element-wise and print a pass message.

    Args:
        name: label used in the pass/fail messages
        actual: array (or array-like) produced by the code under test
        expected: reference array (or array-like)
        atol: maximum allowed absolute element-wise difference (default 1e-6)
    Raises:
        ValueError: if the shapes differ, if any element differs by more than ``atol``,
            or if the difference contains NaN.
    """
    actual = np.asarray(actual)
    expected = np.asarray(expected)
    if actual.shape != expected.shape:
        raise ValueError("{:} failed, expected output to have shape {:} but has shape {:}"
                         .format(name, expected.shape, actual.shape))
    # Test "all within tolerance" rather than "max > tolerance": NaN compares False either
    # way, so this form makes NaN a failure instead of a silent pass. np.all is also True
    # on zero-size arrays, where np.amax would raise.
    if not np.all(np.fabs(actual - expected) <= atol):
        raise ValueError("{:} failed, expected {:} but value is {:}".format(name, expected, actual))
    print(name, "passed!")


# This is a checking helper, not a test case: stop pytest from collecting it (and failing on
# its `name`/`actual`/`expected` "fixtures") in any module that imports it.
test_all_close.__test__ = False