Path: blob/main/apex/tests/distributed/amp_master_params/compare.py
import torch

model_params_rank0 = torch.load("rank0model.pth",
                                map_location=lambda storage, loc: storage.cuda(0))
model_params_rank1 = torch.load("rank1model.pth",
                                map_location=lambda storage, loc: storage.cuda(0))
master_params_rank0 = torch.load("rank0master.pth",
                                 map_location=lambda storage, loc: storage.cuda(0))
master_params_rank1 = torch.load("rank1master.pth",
                                 map_location=lambda storage, loc: storage.cuda(0))

for model_rank0, model_rank1, master_rank0, master_rank1 in zip(
        model_params_rank0,
        model_params_rank1,
        master_params_rank0,
        master_params_rank1):
    assert torch.allclose(model_rank0, model_rank1), "Model param mismatch"
    assert torch.allclose(master_rank0, master_rank1), "Master param mismatch"
    # Some debugging/investigation assistance code:
    # maxval, maxind = torch.max(((torch.abs(model_rank0).float())/torch.abs(master_rank0)).view(-1), 0)
    # offending_val_half = model_rank0.view(-1)[maxind.item()]
    # offending_val_float = master_rank0.view(-1)[maxind.item()]
    # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(),
    #       offending_val_float.half().item())
    # rtol needs to be > 2^-11 because of denormals...
    assert torch.allclose(model_rank0, master_rank0.half(), rtol=.005), "Model-master mismatch"

print("OK: Model and master params match across ranks.")
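
# For reference, a minimal sketch (hypothetical, not the actual companion test script)
# of how the rank{0,1}model.pth / rank{0,1}master.pth inputs consumed above could be
# produced with apex.amp at opt_level O2, where model parameters are fp16 and
# amp.master_params(optimizer) yields the fp32 master copies:
#
# import torch
# import torch.distributed as dist
# from apex import amp
#
# dist.init_process_group(backend="nccl")
# rank = dist.get_rank()
# torch.cuda.set_device(rank)
#
# model = torch.nn.Linear(10, 10).cuda()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
#
# # ... run a few training iterations, calling scaled_loss.backward() inside
# # a `with amp.scale_loss(loss, optimizer):` block, then optimizer.step() ...
#
# torch.save(list(model.parameters()), "rank{}model.pth".format(rank))
# torch.save(list(amp.master_params(optimizer)), "rank{}master.pth".format(rank))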