Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import random
- import numpy as np
- import pandas as pd
- from collections import Counter, defaultdict
def stratified_group_k_fold(X, y, groups, k, seed=None):
    """Yield ``(train_indices, test_indices)`` pairs for stratified group k-fold.

    Every distinct value in ``groups`` is assigned as a whole to exactly one of
    ``k`` folds, so no group appears on both sides of a split. Groups are
    placed greedily, one at a time, into the fold that gives the lowest mean
    (over labels) standard deviation, across folds, of the fraction of that
    label's samples held by each fold.

    Args:
        X: Feature data. Not used; accepted so the call shape mirrors
            scikit-learn style splitters.
        y: Sequence of labels, one per sample. Labels only need to be
            hashable and sortable by ``np.unique``; they do not have to be
            the integers ``0..n-1``.
        groups: Sequence of group ids, one per sample, aligned with ``y``.
            It is iterated several times, so it must be a re-iterable
            sequence (list, array, ...), not a one-shot iterator.
        k: Number of folds; must be >= 1.
        seed: Seed for the shuffle that orders groups with equal label spread.

    Yields:
        For each fold ``0..k-1`` in order, a pair of lists of positional
        indices into ``groups``: training indices, then test indices. A test
        list is empty if no group was assigned to that fold (e.g. fewer
        groups than ``k``).

    Raises:
        ValueError: if ``k < 1`` (raised when the generator is first advanced).
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    # Encode labels densely as 0..labels_num-1. Sizing the count vectors by
    # max(y) + 1 instead would require integer labels and would leave a slot
    # with a zero total for every value that never occurs; that slot's
    # 0/0 = NaN fraction poisons every fold score and sends all groups to
    # fold 0.
    label_values, encoded_y = np.unique(np.asarray(y), return_inverse=True)
    encoded_y = encoded_y.ravel()
    labels_num = len(label_values)
    # Total samples per label; each entry is >= 1 because only labels that
    # actually occur were encoded.
    y_distr = np.bincount(encoded_y, minlength=labels_num)

    # Label histogram of each group.
    y_counts_per_group = defaultdict(lambda: np.zeros(labels_num))
    for label, g in zip(encoded_y, groups):
        y_counts_per_group[g][label] += 1

    # Row f is the label histogram of fold f; groups_per_fold[f] its groups.
    y_counts_per_fold = np.zeros((k, labels_num))
    groups_per_fold = [set() for _ in range(k)]

    def eval_y_counts_per_fold(y_counts, fold):
        """Return the imbalance score if ``y_counts`` were added to ``fold``."""
        # Tentatively add, score, then undo so the fold state is unchanged.
        y_counts_per_fold[fold] += y_counts
        # Per label: std across folds of the fraction of that label per fold.
        std_per_label = np.std(y_counts_per_fold / y_distr, axis=0)
        y_counts_per_fold[fold] -= y_counts
        return np.mean(std_per_label)

    groups_and_y_counts = list(y_counts_per_group.items())
    # Seeded shuffle; the stable sort below keeps this order among groups
    # whose histograms have equal spread.
    random.Random(seed).shuffle(groups_and_y_counts)
    # Place groups in decreasing order of the std of their label histogram.
    for g, y_counts in sorted(groups_and_y_counts, key=lambda item: -np.std(item[1])):
        scores = [eval_y_counts_per_fold(y_counts, fold) for fold in range(k)]
        # index(min(...)) picks the first minimum, so ties go to the lowest fold.
        best_fold = scores.index(min(scores))
        y_counts_per_fold[best_fold] += y_counts
        groups_per_fold[best_fold].add(g)

    for test_groups in groups_per_fold:
        # A sample is in the test side exactly when its group is in this fold.
        train_indices = [idx for idx, g in enumerate(groups) if g not in test_groups]
        test_indices = [idx for idx, g in enumerate(groups) if g in test_groups]
        yield train_indices, test_indices
# Load the full training table (path assumes a "../input" directory layout).
train = pd.read_csv('../input/train/train.csv')
# Labels to stratify on, taken from the "Target" column.
y = train.Target.values
# Group key: rows sharing an "ID" value are kept in the same fold.
# NOTE(review): if ID is unique per row this degenerates to plain stratified
# k-fold; confirm ID is the intended grouping column.
groups = np.asarray(train.ID.values)
for fold_ind, (train_idx, test_idx) in enumerate(stratified_group_k_fold(train, y, groups, k=5)):
    y_train, y_test = y[train_idx], y[test_idx]
    # Select feature rows by position; the full DataFrame stays in `train`
    # so later folds still slice the complete data.
    x_train, x_test = train.iloc[train_idx], train.iloc[test_idx]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement