Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
machinecurve
GitHub Repository: machinecurve/extra_keras_datasets
Path: blob/master/assets/basic_template.py
153 views
1
'''
2
Import the <Dataset name>
3
Source: <URL to dataset>
4
Description: <Short dataset description>
5
6
~~~ Important note ~~~
7
Please cite the following work when using or referencing the dataset:
8
<Citation>
9
10
'''
11
12
import numpy as np
13
14
15
def load_data(path='<Dataset_slug>.npz', size='small'):
    """Loads the <Dataset name>

    This template returns small hard-coded dummy arrays instead of real
    data, so it can be used as a starting point for new dataset loaders.

    # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets). Accepted for interface
            compatibility with real loaders; this template does not use it.
        size: 'small' or 'large', indicating dummy dataset size to return.

    # Returns
        Tuple of Numpy arrays: `(input_train, target_train),
        (input_test, target_test)`.

    # Raises
        ValueError: if `size` is neither 'small' nor 'large'.
    """
    # Reject unknown sizes explicitly. Previously, any value other than
    # 'small' (e.g. a typo) silently returned the large dataset.
    if size not in ('small', 'large'):
        raise ValueError(
            "size must be 'small' or 'large', got {!r}".format(size))

    if size == 'small':
        input_train = np.array([1, 2])
        target_train = np.array([0, 1])
        input_test = np.array([2, 3])
        target_test = np.array([1, 0])
    else:
        input_train = np.array([1, 2, 84, 9, 1, 48, 2])
        target_train = np.array([0, 1, 0, 0, 0, 1, 1])
        input_test = np.array([2, 3, 32, 84, 99, 1, 2])
        target_test = np.array([1, 0, 0, 0, 1, 0, 1])

    return (input_train, target_train), (input_test, target_test)
38
39