Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/cfr/infoset_saver.py
4959 views
1
import json
2
import pathlib
3
from typing import Dict
4
5
from labml import experiment
6
from labml_nn.cfr import InfoSet
7
8
9
class InfoSetSaver(experiment.ModelSaver):
    """Checkpoint saver that persists CFR information sets as a JSON file.

    ``save`` writes every information set (via ``InfoSet.to_dict``) into a
    single ``infosets.json`` file inside the checkpoint directory, and
    ``load`` reads that file back (via ``InfoSet.from_dict``) into the same
    dictionary that was passed to the constructor.
    """

    def __init__(self, infosets: Dict[str, InfoSet]):
        """
        :param infosets: information sets keyed by their string key. This
            exact dict object is kept and is updated in place by ``load``, so
            any other holder of a reference to it sees the restored values.
        """
        self.infosets = infosets

    def save(self, checkpoint_path: pathlib.Path) -> str:
        """Serialize all information sets to ``infosets.json``.

        :param checkpoint_path: directory to write the file into.
        :return: the file name, relative to ``checkpoint_path``. Presumably
            the labml checkpoint machinery hands this back to ``load`` as
            ``file_name`` (TODO confirm against ``experiment.ModelSaver``).
        """
        data = {key: infoset.to_dict() for key, infoset in self.infosets.items()}
        file_name = "infosets.json"

        with open(checkpoint_path / file_name, 'w', encoding='utf-8') as f:
            json.dump(data, f)

        return file_name

    def load(self, checkpoint_path: pathlib.Path, file_name: str):
        """Restore information sets previously written by ``save``.

        :param checkpoint_path: directory containing the checkpoint file.
        :param file_name: name of the JSON file inside ``checkpoint_path``.

        Entries are written into ``self.infosets`` in place. Keys that are
        present in the file overwrite existing entries. Keys that are not in
        the file are left untouched.
        """
        # Must open for reading: the previous mode 'w' truncated the
        # checkpoint and then failed on read().
        with open(checkpoint_path / file_name, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for key, d in data.items():
            self.infosets[key] = InfoSet.from_dict(d)
28
29