Advertisement
Guest User

Untitled

a guest
Oct 19th, 2019
112
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.24 KB | None | 0 0
  1. class SampleForwardChainCV(object):
  2. def __init__(self,
  3. dates,
  4. obs_count,
  5. min_start_date=None,
  6. max_end_date=None,
  7. n_min_train_obs=20,
  8. n_min_validate_obs=20,
  9. n_min_test_obs=20):
  10.  
  11. self.dates = sorted(list(set(dates)))
  12. self.obs_count = obs_count
  13. self.min_start_date = min_start_date or self.dates[0]
  14. self.max_end_date = max_end_date or self.dates[-1]
  15. self.n_min_train_obs = n_min_train_obs
  16. self.n_min_validate_obs = n_min_validate_obs
  17. self.n_min_test_obs = n_min_test_obs
  18.  
  19. def split(self):
  20.  
  21. def get_end_index(start_index, i, min_obs):
  22. """Return the next index
  23. given the sum of obs in
  24. each element exceed the min
  25. required
  26. """
  27. n_obs = sum(self.obs_count[start_index:i + 1])
  28. while n_obs < min_obs:
  29. if self.dates[i] >= self.max_end_date:
  30. raise IndexError()
  31. i += 1
  32. n_obs = sum(self.obs_count[start_index:i + 1])
  33. return i
  34.  
  35. # Start at this index
  36. train_start_index = self.dates.index(self.min_start_date)
  37.  
  38. # Iterate over all dates
  39. for i in range(train_start_index, len(self.dates)):
  40. try:
  41. train_end_index = get_end_index(train_start_index, i, self.n_min_train_obs)
  42.  
  43. valid_start_index = train_end_index + 1
  44. valid_end_index = get_end_index(valid_start_index,
  45. valid_start_index,
  46. self.n_min_validate_obs)
  47.  
  48. test_start_index = valid_end_index + 1
  49. test_end_index = get_end_index(test_start_index,
  50. test_start_index,
  51. self.n_min_test_obs)
  52.  
  53. yield (self.dates[train_start_index], self.dates[train_end_index]), \
  54. (self.dates[valid_start_index], self.dates[valid_end_index]), \
  55. (self.dates[test_start_index], self.dates[test_end_index])
  56. except:
  57. break
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement