# GitHub repository: prophesier/diff-svc (owner: prophesier)
# Path: blob/main/utils/hparams.py
import argparse
2
import os
3
import yaml
4
5
# Flipped to False by set_hparams() after it prints the resolved hparams, so
# later calls in the same process do not print them again.
global_print_hparams = True
# Process-wide hparams; set_hparams(global_hparams=True) clears and refills
# this same dict object in place, so references to it stay valid.
hparams = dict()


class Args:
    """Plain attribute bag: each keyword argument becomes an attribute.

    set_hparams() builds one of these in place of an argparse namespace when
    a config path is passed programmatically instead of via the command line.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)


def override_config(old_config: dict, new_config: dict):
    """Recursively merge ``new_config`` into ``old_config`` in place.

    Nested dicts present on both sides are merged key by key. Any other value
    from ``new_config`` replaces the old value outright. This includes a dict
    whose counterpart in ``old_config`` is missing or is not a dict.

    Args:
        old_config: Mapping to update; mutated in place.
        new_config: Mapping whose entries take precedence.
    """
    for k, v in new_config.items():
        # Recurse only when BOTH sides are dicts. Recursing into a non-dict
        # old value (e.g. None or a scalar) would raise TypeError.
        if isinstance(v, dict) and isinstance(old_config.get(k), dict):
            override_config(old_config[k], v)
        else:
            old_config[k] = v


def _load_config_chain(config_fn):
    """Load ``config_fn`` plus, recursively, every ``base_config`` it names.

    Base configs are merged first, in list order, and the including file is
    merged on top, so the most derived file wins. A ``base_config`` entry
    starting with ``.`` is resolved relative to the directory of the file
    that names it; other entries are used as given. Each file (compared by
    normalised path) is loaded at most once per call, which also stops
    inheritance cycles.

    Returns:
        ``(merged_hparams, config_chain)``. ``config_chain`` lists the loaded
        files in the order their loading finished (deepest base first).
    """
    config_chain = []
    loaded_config = set()

    def load(fn):  # depth first
        with open(fn, encoding='utf-8') as f:
            # An empty YAML file parses to None; treat it as an empty mapping.
            hparams_ = yaml.safe_load(f) or {}
        loaded_config.add(os.path.normpath(fn))
        if 'base_config' not in hparams_:
            config_chain.append(fn)
            return hparams_

        ret_hparams = {}
        base_configs = hparams_['base_config']
        if not isinstance(base_configs, list):
            base_configs = [base_configs]
            hparams_['base_config'] = base_configs
        for c in base_configs:
            if c.startswith('.'):
                # os.path.join avoids producing '/./x.yaml' (a filesystem-root
                # path) when ``fn`` has no directory component.
                c = os.path.normpath(os.path.join(os.path.dirname(fn), c))
            # Compare normalised paths so the same file is not merged twice.
            if os.path.normpath(c) not in loaded_config:
                override_config(ret_hparams, load(c))
        override_config(ret_hparams, hparams_)
        config_chain.append(fn)
        return ret_hparams

    return load(config_fn), config_chain


def _apply_hparams_str(hparams_, hparams_str):
    """Apply comma-separated ``key=value`` overrides to ``hparams_`` in place.

    A value is evaluated as a Python expression in these cases:
    - the key is new;
    - the value is literally ``True``/``False``;
    - the current value is None, a bool, or a list/tuple/dict.
    Otherwise the string is cast with the type of the current value
    (e.g. ``int('3')``).

    Raises:
        ValueError: if a non-empty segment contains no ``=``.
    """
    for new_hparam in hparams_str.split(","):
        if not new_hparam.strip():
            continue  # tolerate empty segments such as a trailing comma
        if "=" not in new_hparam:
            raise ValueError(f"hparams override {new_hparam!r} is not of the form key=value")
        # Split on the first '=' only, so values may themselves contain '='.
        k, v = new_hparam.split("=", 1)
        k, v = k.strip(), v.strip()
        # NOTE(review): eval() executes arbitrary code. It is tolerable only
        # because this string comes from the local user's own CLI/call.
        if k not in hparams_:
            hparams_[k] = eval(v)
            continue
        current = hparams_[k]
        if v in ('True', 'False') or current is None or isinstance(current, (bool, list, tuple, dict)):
            hparams_[k] = eval(v)
        else:
            hparams_[k] = type(current)(v)


def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True, reset=True, infer=True):
    '''
    Load hparams from multiple sources:
    1. config chain (i.e. first load base_config, then load config);
    2. unless reset == True, override with the (auto-saved) complete config
       file ('checkpoints/<exp_name>/config.yaml'), which contains all
       settings and does not rely on base_config;
    3. load from argument --hparams or hparams_str, as temporary modification.

    When ``config`` is '' the settings are read from the command line instead
    of from the keyword arguments.

    Side effects:
    - If an exp_name is set, ``infer`` is False, and the saved config is
      missing or ``reset`` is True, the resolved hparams are written to
      'checkpoints/<exp_name>/config.yaml'.
    - If ``global_hparams`` is True, the module-level ``hparams`` dict is
      cleared and refilled in place.

    Returns:
        The resolved hparams dict.
    '''
    global global_print_hparams

    if config == '':
        parser = argparse.ArgumentParser(description='neural music')
        parser.add_argument('--config', type=str, default='',
                            help='location of the data corpus')
        parser.add_argument('--exp_name', type=str, default='', help='exp_name')
        parser.add_argument('--hparams', type=str, default='',
                            help='location of the data corpus')
        parser.add_argument('--infer', action='store_true', help='infer')
        parser.add_argument('--validate', action='store_true', help='validate')
        parser.add_argument('--reset', action='store_true', help='reset hparams')
        parser.add_argument('--debug', action='store_true', help='debug')
        args, _ = parser.parse_known_args()
    else:
        args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
                    infer=infer, validate=False, reset=reset, debug=False)

    args_work_dir = ''
    if args.exp_name != '':
        args.work_dir = args.exp_name
        args_work_dir = f'checkpoints/{args.work_dir}'

    assert args.config != '' or args_work_dir != '', \
        'either a config file or an exp_name must be given'

    saved_hparams = {}
    ckpt_config_path = ''
    # Only look for a saved config when there is a work dir. Otherwise the
    # path would be '/config.yaml' at the filesystem root.
    if args_work_dir != '':
        ckpt_config_path = f'{args_work_dir}/config.yaml'
        if os.path.exists(ckpt_config_path):
            try:
                with open(ckpt_config_path, encoding='utf-8') as f:
                    saved_hparams.update(yaml.safe_load(f) or {})
            except (OSError, yaml.YAMLError, TypeError, ValueError) as e:
                # Best effort: an unreadable saved config is ignored, but say so.
                print(f'| WARNING: failed to load {ckpt_config_path}: {e}')
        if args.config == '':
            args.config = ckpt_config_path

    hparams_, config_chains = _load_config_chain(args.config)

    if not args.reset:
        hparams_.update(saved_hparams)
    hparams_['work_dir'] = args_work_dir

    if args.hparams != "":
        _apply_hparams_str(hparams_, args.hparams)

    if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
        os.makedirs(hparams_['work_dir'], exist_ok=True)
        with open(ckpt_config_path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(hparams_, f)

    # Run-mode flags are added after saving, so they never persist in config.yaml.
    hparams_['infer'] = args.infer
    hparams_['debug'] = args.debug
    hparams_['validate'] = args.validate
    if global_hparams:
        hparams.clear()
        hparams.update(hparams_)

    if print_hparams and global_print_hparams and global_hparams:
        print('| Hparams chains: ', config_chains)
        print('| Hparams: ')
        for i, (k, v) in enumerate(sorted(hparams_.items())):
            print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
        print("")
        global_print_hparams = False
    # NOTE(review): this fills exp_name into the module-level dict even when
    # global_hparams is False; kept as-is for compatibility.
    if hparams.get('exp_name') is None:
        hparams['exp_name'] = args.exp_name
    if hparams_.get('exp_name') is None:
        hparams_['exp_name'] = args.exp_name
    return hparams_
133