Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/simplify.py
446 views
1
from argparse import ArgumentParser
2
3
import torch
4
5
6
def simplify_pth(pth_name, project_name, checkpoint_root='./checkpoints', output_dir='.'):
    """Write a slimmed-down copy of a training checkpoint.

    Loads ``<checkpoint_root>/<project_name>/<pth_name>``. It then saves a new
    checkpoint that keeps only ``epoch`` and ``state_dict`` from the original.
    The entries ``global_step``, ``checkpoint_callback_best``,
    ``optimizer_states`` and ``lr_schedulers`` are set to ``None``.
    The result is written to ``<output_dir>/clean_<pth_name>``.

    Args:
        pth_name: File name of the checkpoint inside the project directory.
        project_name: Sub-directory of ``checkpoint_root`` that holds the
            checkpoint.
        checkpoint_root: Directory containing the per-project checkpoint
            folders. Defaults to ``./checkpoints``, relative to the current
            working directory.
        output_dir: Directory the cleaned checkpoint is written to. Defaults
            to the current working directory.

    Returns:
        The path of the cleaned checkpoint file.

    Raises:
        FileNotFoundError: If the source checkpoint does not exist.
        KeyError: If the checkpoint has no ``epoch`` or ``state_dict`` entry.
    """
    import os

    model_path = os.path.join(checkpoint_root, project_name)
    # map_location='cpu' lets checkpoints saved from GPU tensors be cleaned on
    # a machine without CUDA. It also avoids copying the optimizer state onto
    # the GPU only to throw it away.
    # NOTE: torch.load unpickles the file, so only run this on trusted
    # checkpoints.
    checkpoint_dict = torch.load(os.path.join(model_path, pth_name), map_location='cpu')
    out_path = os.path.join(output_dir, f'clean_{pth_name}')
    torch.save({'epoch': checkpoint_dict['epoch'],
                'state_dict': checkpoint_dict['state_dict'],
                # Training-only bookkeeping is dropped; the keys are kept
                # (set to None) so the layout of the saved dict is unchanged.
                'global_step': None,
                'checkpoint_callback_best': None,
                'optimizer_states': None,
                'lr_schedulers': None,
                }, out_path)
    return out_path
16
17
18
def main(argv=None):
    """Command-line entry point: clean one step checkpoint of a project.

    Builds the file name ``model_ckpt_steps_<steps>.ckpt`` from ``--steps``.
    It then passes that name and ``--proj`` to ``simplify_pth``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behaviour.
    """
    parser = ArgumentParser(
        description='Strip optimizer/scheduler state from a training checkpoint.')
    # Both options are required. Without them the file name would be built
    # from None ("model_ckpt_steps_None.ckpt" under a "None" project) and the
    # run would fail later with an unhelpful FileNotFoundError.
    parser.add_argument('--proj', type=str, required=True,
                        help='project name (sub-directory of ./checkpoints)')
    parser.add_argument('--steps', type=str, required=True,
                        help='training step of the checkpoint, as in '
                             'model_ckpt_steps_<steps>.ckpt')
    args = parser.parse_args(argv)
    model_name = f"model_ckpt_steps_{args.steps}.ckpt"
    simplify_pth(model_name, args.proj)
25
26
27
# Run the command-line tool only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
29
30