Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class SampleForwardChainCV(object):
    """Generate forward-chaining train/validation/test date ranges.

    The training window always starts at the same date and its end moves
    forward one date per split.  The validation window starts on the date
    right after the training window, and the test window on the date right
    after the validation window.  Validation and test windows are the
    shortest runs of consecutive dates whose observation counts add up to
    at least the configured minimum.
    """

    def __init__(self,
                 dates,
                 obs_count,
                 min_start_date=None,
                 max_end_date=None,
                 n_min_train_obs=20,
                 n_min_validate_obs=20,
                 n_min_test_obs=20):
        """Store the date axis and the minimum window sizes.

        Args:
            dates: Iterable of hashable, mutually comparable dates.
                Duplicates are dropped and the remainder is sorted.
            obs_count: Sliceable sequence of observation counts;
                ``obs_count[k]`` is used as the count for the k-th sorted
                unique date.  NOTE(review): counts are assumed to be
                non-negative -- confirm against callers.
            min_start_date: First date the training window may use.
                Defaults to the earliest date.  Does not have to be a
                member of ``dates``; the first date on or after it is used.
            max_end_date: Last date any window may use.  Defaults to the
                latest date.
            n_min_train_obs: Minimum observations in the training window.
            n_min_validate_obs: Minimum observations in the validation
                window.
            n_min_test_obs: Minimum observations in the test window.
        """
        self.dates = sorted(set(dates))
        self.obs_count = obs_count
        # Compare against None explicitly so that falsy-but-valid dates
        # (e.g. the integer 0) are honoured instead of being replaced.
        if min_start_date is None:
            min_start_date = self.dates[0] if self.dates else None
        if max_end_date is None:
            max_end_date = self.dates[-1] if self.dates else None
        self.min_start_date = min_start_date
        self.max_end_date = max_end_date
        self.n_min_train_obs = n_min_train_obs
        self.n_min_validate_obs = n_min_validate_obs
        self.n_min_test_obs = n_min_test_obs

    def _window_end(self, start_index, min_obs, last_index):
        """Return the smallest end index whose window holds ``min_obs``.

        The window is ``start_index..end`` inclusive and may not extend
        past ``last_index``.

        Raises:
            IndexError: If no such window fits before ``last_index``.
        """
        total = 0
        # Slice once and keep a running total instead of re-summing the
        # whole window every time it grows by one date.
        window = self.obs_count[start_index:last_index + 1]
        for offset, count in enumerate(window):
            total += count
            if total >= min_obs:
                return start_index + offset
        raise IndexError("not enough observations before the end date")

    def split(self):
        """Yield successive train/validation/test date ranges.

        Yields:
            A 3-tuple ``((train_start, train_end), (valid_start,
            valid_end), (test_start, test_end))`` of inclusive date pairs.
            Every split is distinct and no window extends past
            ``max_end_date``.  Nothing is yielded when the data cannot
            accommodate all three windows.
        """
        from bisect import bisect_left, bisect_right

        if not self.dates:
            return

        dates = self.dates
        # First usable date on/after min_start_date and last usable date
        # on/before max_end_date; neither bound has to be in ``dates``.
        first_index = bisect_left(dates, self.min_start_date)
        last_index = bisect_right(dates, self.max_end_date) - 1

        try:
            first_train_end = self._window_end(first_index,
                                               self.n_min_train_obs,
                                               last_index)
        except IndexError:
            return

        # Start at the first feasible training end so that shorter,
        # infeasible prefixes do not each re-emit the same split.
        for train_end in range(first_train_end, last_index + 1):
            valid_start = train_end + 1
            try:
                valid_end = self._window_end(valid_start,
                                             self.n_min_validate_obs,
                                             last_index)
                test_start = valid_end + 1
                test_end = self._window_end(test_start,
                                            self.n_min_test_obs,
                                            last_index)
            except IndexError:
                # Later training ends leave even less room, so stop here.
                break
            yield ((dates[first_index], dates[train_end]),
                   (dates[valid_start], dates[valid_end]),
                   (dates[test_start], dates[test_end]))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement