SHARE
TWEET

Untitled

a guest Oct 19th, 2019 89 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. class SampleForwardChainCV(object):
  2.     def __init__(self,
  3.                  dates,
  4.                  obs_count,
  5.                  min_start_date=None,
  6.                  max_end_date=None,
  7.                  n_min_train_obs=20,
  8.                  n_min_validate_obs=20,
  9.                  n_min_test_obs=20):
  10.  
  11.         self.dates = sorted(list(set(dates)))
  12.         self.obs_count = obs_count
  13.         self.min_start_date = min_start_date or self.dates[0]
  14.         self.max_end_date = max_end_date or self.dates[-1]
  15.         self.n_min_train_obs = n_min_train_obs
  16.         self.n_min_validate_obs = n_min_validate_obs
  17.         self.n_min_test_obs = n_min_test_obs
  18.  
  19.     def split(self):
  20.  
  21.         def get_end_index(start_index, i, min_obs):
  22.             """Return the next index
  23.                given the sum of obs in
  24.                each element exceed the min
  25.                required
  26.             """
  27.             n_obs = sum(self.obs_count[start_index:i + 1])
  28.             while n_obs < min_obs:
  29.                 if self.dates[i] >= self.max_end_date:
  30.                     raise IndexError()
  31.                 i += 1
  32.                 n_obs = sum(self.obs_count[start_index:i + 1])
  33.             return i
  34.  
  35.         # Start at this index
  36.         train_start_index = self.dates.index(self.min_start_date)
  37.        
  38.         # Iterate over all dates
  39.         for i in range(train_start_index, len(self.dates)):
  40.             try:
  41.                 train_end_index = get_end_index(train_start_index, i, self.n_min_train_obs)
  42.  
  43.                 valid_start_index = train_end_index + 1
  44.                 valid_end_index = get_end_index(valid_start_index,
  45.                                                 valid_start_index,
  46.                                                 self.n_min_validate_obs)
  47.  
  48.                 test_start_index = valid_end_index + 1
  49.                 test_end_index = get_end_index(test_start_index,
  50.                                                test_start_index,
  51.                                                self.n_min_test_obs)
  52.  
  53.                 yield (self.dates[train_start_index], self.dates[train_end_index]), \
  54.                       (self.dates[valid_start_index], self.dates[valid_end_index]), \
  55.                       (self.dates[test_start_index], self.dates[test_end_index])
  56.             except:
  57.                 break
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top