Path: blob/master/labml_nn/cfr/infoset_saver.py
4959 views
import json
import pathlib
from typing import Dict

from labml import experiment
from labml_nn.cfr import InfoSet


class InfoSetSaver(experiment.ModelSaver):
    """Save and restore a dictionary of `InfoSet` objects as one JSON file.

    The saver holds a reference to the `infosets` dictionary it is given.
    `save` serializes every entry with `InfoSet.to_dict`, and `load` rebuilds
    entries with `InfoSet.from_dict` and stores them back into that same
    dictionary.
    """

    def __init__(self, infosets: Dict[str, InfoSet]):
        """Keep a reference to the dictionary that is saved and restored.

        Args:
            infosets: Mapping from key to `InfoSet`. It is not copied, so
                `load` mutates the caller's dictionary in place.
        """
        self.infosets = infosets

    def save(self, checkpoint_path: pathlib.Path) -> str:
        """Write all info sets to `infosets.json` inside `checkpoint_path`.

        Args:
            checkpoint_path: Directory in which the JSON file is created.

        Returns:
            The name of the written file, relative to `checkpoint_path`.
        """
        data = {key: infoset.to_dict() for key, infoset in self.infosets.items()}
        file_name = "infosets.json"

        with open(str(checkpoint_path / file_name), 'w', encoding='utf-8') as f:
            json.dump(data, f)

        return file_name

    def load(self, checkpoint_path: pathlib.Path, file_name: str):
        """Read info sets from a JSON file and store them in `self.infosets`.

        Entries whose keys appear in the file are replaced; entries that are
        only present in `self.infosets` are left untouched.

        Args:
            checkpoint_path: Directory that contains the saved file.
            file_name: Name of the JSON file, as returned by `save`.
        """
        # The file must be opened for reading: opening it with 'w' would
        # truncate the checkpoint and make the subsequent read fail.
        with open(str(checkpoint_path / file_name), 'r', encoding='utf-8') as f:
            data = json.load(f)

        for key, d in data.items():
            self.infosets[key] = InfoSet.from_dict(d)