Advertisement
Guest User

Untitled

a guest
Jun 18th, 2019
63
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 2.54 KB | None | 0 0
  1. import random
  2. import tempfile
  3. from typing import Type
  4.  
  5. from dffml.repo import Repo, RepoData
  6. from dffml.model.model import ModelConfig
  7. from dffml.source.source import Sources
  8. from dffml.source.memory import MemorySource, MemorySourceConfig
  9. from dffml.feature import Data, DefFeature, Features
  10. from dffml.util.asynctestcase import AsyncTestCase
  11.  
  12. from dffml_model_model_name.model.misc import Misc
  13.  
  14. FEATURE_DATA = [
  15. [1.1, 42393.0],
  16. [1.3, 49255.0],
  17. [1.5, 40781.0],
  18. [2.0, 46575.0],
  19. [2.2, 42941.0],
  20. [2.9, 59692.0],
  21. [3.0, 63200.0],
  22. [3.2, 57495.0],
  23. [3.2, 67495.0],
  24. [3.7, 60239.0],
  25. [3.9, 66268.0],
  26. [4.0, 58844.0],
  27. [4.0, 60007.0],
  28. [4.1, 60131.0],
  29. [4.5, 64161.0],
  30. [4.9, 70988.0],
  31. [5.1, 69079.0],
  32. [5.3, 86138.0],
  33. [5.9, 84413.0],
  34. [6.0, 96990.0],
  35. [6.8, 94788.0],
  36. [7.1, 101323.0],
  37. [7.9, 104352.0],
  38. [8.2, 116862.0],
  39. [8.7, 112481.0],
  40. [9.0, 108632.0],
  41. [9.5, 120019.0],
  42. [9.6, 115685.0],
  43. [10.3, 125441.0],
  44. [10.5, 124922.0]
  45. ]
  46.  
  47. class TestMisc(AsyncTestCase):
  48.  
  49. @classmethod
  50. def setUpClass(cls):
  51. cls.model_dir = tempfile.TemporaryDirectory()
  52. cls.model = Misc(ModelConfig(directory=cls.model_dir.name, predict='Salary'))
  53. cls.feature = DefFeature('YearsExperience', float, 1)
  54. cls.features = Features(cls.feature)
  55. cls.classifications = []
  56. YearsExperience, Salary = list(zip(*FEATURE_DATA))
  57. cls.repos = [
  58. Repo(str(i),
  59. data={'features': {
  60. 'YearsExperience': YearsExperience[i],
  61. 'Salary': Salary[i],
  62. }}
  63. )
  64. for i in range(0, len(Salary))
  65. ]
  66. cls.sources = \
  67. Sources(MemorySource(MemorySourceConfig(repos=cls.repos)))
  68.  
  69. @classmethod
  70. def tearDownClass(cls):
  71. cls.model_dir.cleanup()
  72.  
  73. async def test_context(self):
  74. async with self.sources as sources, self.features as features, \
  75. self.model as model:
  76. async with sources() as sctx, model() as mctx:
  77. # Test train
  78. await mctx.train(sctx, features,
  79. self.classifications)
  80. # Test accuracy
  81. res = await mctx.accuracy(sctx, features,
  82. self.classifications)
  83. self.assertGreater(res, 0.9)
  84. # Test predict
  85. res = [repo async for repo in mctx.predict(sctx.repos(),
  86. features, self.classifications)]
  87. self.assertEqual(len(res), 1)
  88. self.assertEqual(res[0][0].src_url, a.src_url)
  89. self.assertTrue(res[0][1])
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement