Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/utils/indexed_datasets.py
694 views
1
import pickle
2
from copy import deepcopy
3
4
import numpy as np
5
6
7
class IndexedDataset:
    """Random-access reader for a dataset written by ``IndexedDatasetBuilder``.

    Items are stored as consecutive pickled blobs in ``{path}.data``. The
    start offset of every blob, plus one final end offset, is stored under
    the ``'offsets'`` key of the dict saved with ``np.save`` in ``{path}.idx``.

    NOTE: items are decoded with ``pickle.loads`` and the index is loaded with
    ``allow_pickle=True``; only open files that come from a trusted source.
    """

    def __init__(self, path, num_cache=1):
        """Open the dataset stored under the path prefix *path*.

        Args:
            path: Path prefix; ``{path}.idx`` and ``{path}.data`` are read.
            num_cache: Maximum number of most recently loaded items kept in
                memory (0 disables caching).
        """
        super().__init__()
        self.path = path
        # Assigned before anything that can raise so close()/__del__ are safe.
        self.data_file = None
        self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets']
        self.data_file = open(f"{path}.data", 'rb', buffering=-1)
        # List of (index, item) pairs, most recently loaded first.
        self.cache = []
        self.num_cache = num_cache

    def check_index(self, i):
        """Raise ``IndexError`` unless ``0 <= i < len(self)``."""
        if i < 0 or i >= len(self.data_offsets) - 1:
            raise IndexError('index out of range')

    def close(self):
        """Close the underlying data file; safe to call more than once."""
        data_file = getattr(self, 'data_file', None)
        if data_file:
            data_file.close()
        self.data_file = None

    def __del__(self):
        self.close()

    def __getitem__(self, i):
        """Return item *i*, served from the cache when possible.

        Cache hits return a deep copy so callers may mutate the result
        without corrupting the cached copy.

        Raises:
            IndexError: if *i* is out of range (negative indices included).
        """
        self.check_index(i)
        if self.num_cache > 0:
            for cached_index, cached_item in self.cache:
                if cached_index == i:
                    return deepcopy(cached_item)
        self.data_file.seek(self.data_offsets[i])
        b = self.data_file.read(self.data_offsets[i + 1] - self.data_offsets[i])
        item = pickle.loads(b)
        if self.num_cache > 0:
            # Newest entry first; keep at most num_cache entries in total.
            self.cache = [(i, deepcopy(item))] + self.cache[:self.num_cache - 1]
        return item

    def __len__(self):
        """Number of items (the index stores one extra end offset)."""
        return len(self.data_offsets) - 1
40
41
class IndexedDatasetBuilder:
    """Append-only writer for the ``{path}.data`` / ``{path}.idx`` file pair.

    Call ``add_item`` once per item, then ``finalize`` to close the data file
    and write the offset index.
    """

    def __init__(self, path):
        """Create (or truncate) ``{path}.data`` for writing."""
        self.path = path
        self.out_file = open(f"{path}.data", 'wb')
        # byte_offsets[k] is where item k starts; the last entry is the end
        # of the most recently written item.
        self.byte_offsets = [0]

    def add_item(self, item):
        """Pickle *item* and append it to the data file."""
        n_written = self.out_file.write(pickle.dumps(item))
        self.byte_offsets.append(self.byte_offsets[-1] + n_written)

    def finalize(self):
        """Close the data file and write the offset index to ``{path}.idx``."""
        self.out_file.close()
        # Context manager so the index is flushed and closed deterministically
        # rather than whenever the file object happens to be garbage-collected.
        with open(f"{self.path}.idx", 'wb') as idx_file:
            np.save(idx_file, {'offsets': self.byte_offsets})
55
56
57
if __name__ == "__main__":
    import os
    import random
    import tempfile

    from tqdm import tqdm

    # Round-trip smoke test: write random items, then read random indices back.
    # Use the platform temp dir instead of a hard-coded '/tmp' so this runs
    # on any OS.
    ds_path = os.path.join(tempfile.gettempdir(), 'indexed_ds_example')
    size = 100
    items = [{"a": np.random.normal(size=[10000, 10]),
              "b": np.random.normal(size=[10000, 10])} for _ in range(size)]
    builder = IndexedDatasetBuilder(ds_path)
    for item in tqdm(items):
        builder.add_item(item)
    builder.finalize()
    ds = IndexedDataset(ds_path)
    for _ in tqdm(range(10000)):
        idx = random.randint(0, size - 1)
        loaded = ds[idx]
        # Verify every stored array, not just one key.
        assert np.array_equal(loaded['a'], items[idx]['a'])
        assert np.array_equal(loaded['b'], items[idx]['b'])
72
73