Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import unittest
- import tempfile
- import os
- import torch
def models_eq(a, b):
    """Return True iff two models have identical state dicts.

    Two models are considered equal when their state dicts share the
    same key set and every corresponding tensor matches in both shape
    and values.

    Args:
        a: First ``torch.nn.Module``.
        b: Second ``torch.nn.Module``.

    Returns:
        bool: True when the models' parameters/buffers are identical.
        (Always a plain Python bool, never a tensor.)
    """
    sa = a.state_dict()
    sb = b.state_dict()
    # Key sets must match exactly before comparing any tensors.
    if set(sa.keys()) != set(sb.keys()):
        return False
    # torch.equal is True only when sizes AND elements match, which
    # covers both the shape check and the value check in one call.
    return all(torch.equal(sa[k], sb[k]) for k in sa)
class TestCheckpoints(unittest.TestCase):
    """Round-trips a model's state dict through torch.save / torch.load."""

    def test_save_load(self):
        """Saving one model's weights and loading them into another makes them equal."""
        shape = (3, 3)
        expected = torch.nn.Linear(*shape)
        actual = torch.nn.Linear(*shape)
        # mkstemp returns an OPEN OS-level file descriptor alongside the
        # path; close it immediately so it does not leak.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        try:
            # unittest assertions (not bare `assert`) so the checks
            # survive `python -O` and report proper failure messages.
            self.assertFalse(
                models_eq(expected, actual),
                "Models are randomly initialized and should not be equal.",
            )
            torch.save(expected.state_dict(), path)
            actual.load_state_dict(torch.load(path))
            self.assertTrue(
                models_eq(expected, actual),
                "Models should be equal after loading checkpoint.",
            )
        finally:
            # Clean up the temp file whether the test passes or fails.
            os.remove(path)
# Discover and run this module's tests when executed as a script.
if __name__ == "__main__":
    unittest.main()
Add Comment
Please, Sign In to add comment