Guest User

Untitled

a guest
May 26th, 2018
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.38 KB | None | 0 0
  1. import unittest
  2. import tempfile
  3. import os
  4.  
  5. import torch
  6.  
  7.  
  8. def models_eq(a, b):
  9. """
  10. Not the most elegant method. Compares two models' parameters.
  11. """
  12. sa = a.state_dict()
  13. sb = b.state_dict()
  14.  
  15. sak = set(sa.keys())
  16. sbk = set(sb.keys())
  17.  
  18. eq = len(sak) == len(sbk)
  19.  
  20. if not eq:
  21. return eq
  22.  
  23. eq = len(set.intersection(sak, sbk)) == len(sak)
  24.  
  25. if not eq:
  26. return eq
  27.  
  28. for k in sak:
  29. eq = sa[k].shape == sb[k].shape
  30.  
  31. if not eq:
  32. return eq
  33.  
  34. eq = sa[k].eq(sb[k]).all()
  35.  
  36. if not eq:
  37. return eq
  38.  
  39. return eq
  40.  
  41.  
  42. class TestCheckpoints(unittest.TestCase):
  43. def test_save_load(self):
  44. shape = (3, 3)
  45. expected = torch.nn.Linear(*shape)
  46. actual = torch.nn.Linear(*shape)
  47.  
  48. _, path = tempfile.mkstemp()
  49.  
  50. try:
  51. assert models_eq(expected, actual) == False, "Models are randomly initialized and should not be equal."
  52.  
  53. torch.save(expected.state_dict(), path)
  54. state_dict = torch.load(path)
  55. actual.load_state_dict(state_dict)
  56.  
  57. assert models_eq(expected, actual) == True, "Models should be equal after loading checkpoint."
  58. except AssertionError as e:
  59. raise e
  60. except:
  61. raise RuntimeError('Test failed!')
  62. finally:
  63. os.remove(path)
  64.  
  65.  
  66. if __name__ == '__main__':
  67. unittest.main()
Add Comment
Please, Sign In to add comment