Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Update for SuperLearner::CVFolds function that enables stratification by
# outcome and cluster ID.
#
# Drop-in replacement for SuperLearner::CVFolds(): builds the list of
# validation-fold row indices used for cross-validation. In addition to the
# usual options it supports cvControl$stratifyCV = TRUE together with a
# cluster `id`.
#
# Arguments
#   N          Number of observations (rows).
#   id         NULL, or a vector with one cluster identifier per row. Rows
#              that share an id are always placed in the same fold.
#   Y          Outcome vector. Only used when stratifying, in which case it
#              must take exactly two distinct values (0/1 coding is assumed
#              by the count check on sum(Y) / sum(!Y)).
#   cvControl  List with elements V (number of folds), shuffle (logical),
#              stratifyCV (logical) and, optionally, validRows. A non-NULL
#              validRows is returned unchanged.
#
# Value
#   A list of integer vectors of row indices, one per validation fold.
#   Without `id` the list is named "1", "2", ... (and has fewer than V
#   elements when N < V and stratifyCV is FALSE); with `id` it is unnamed
#   and has exactly V elements.
#
# Stratifying with `id`: a cluster counts as a "Y = 1" cluster if ANY of its
# rows has the upper outcome value. Those clusters and the remaining
# (all Y = 0) clusters are each dealt round-robin across the V folds, so
# each stratum must contain at least V clusters.
#
# Reproducibility: random draws use sample.int() in the same order as the
# earlier sample(1:n) / sample(x) calls, which gives identical permutations
# for a given seed.
CVFolds2 <- function(N, id, Y, cvControl) {
  if (!is.null(cvControl$validRows)) {
    return(cvControl$validRows)
  }
  stratifyCV <- cvControl$stratifyCV
  shuffle <- cvControl$shuffle
  V <- cvControl$V

  # Deal 1..n into V groups: the k-th entry of the (optionally shuffled)
  # sequence goes to group ((k - 1) %% V) + 1. sample.int(n) is the same
  # draw as sample(1:n), but returns integer(0) for n = 0 instead of
  # permuting c(1, 0).
  deal <- function(n) {
    pos <- if (shuffle) sample.int(n) else seq_len(n)
    split(pos, rep(seq_len(V), length.out = n))
  }

  # Row indices of each of the V folds when `clusters` (a subset of
  # unique(id)) are assigned to folds as whole units.
  cluster_folds <- function(clusters) {
    groups <- deal(length(clusters))
    lapply(seq_len(V), function(v) which(id %in% clusters[groups[[v]]]))
  }

  if (!stratifyCV) {
    if (is.null(id)) {
      return(deal(N))
    }
    clusters <- unique(id)
    # With fewer than V clusters, groups[[v]] in cluster_folds() would be
    # out of bounds; fail with a message that names the actual problem.
    if (length(clusters) < V) {
      stop("number of unique ids is less than the number of folds")
    }
    return(cluster_folds(clusters))
  }

  if (length(unique(Y)) != 2) {
    stop("stratifyCV only implemented for binary Y")
  }
  if (sum(Y) < V || sum(!Y) < V) {
    stop("number of (Y=1) or (Y=0) is less than the number of folds")
  }
  # Rows in the upper outcome level, i.e. Y = 1 under 0/1 coding.
  is_y1 <- Y == max(Y)

  if (is.null(id)) {
    # Deal Y = 0 rows and Y = 1 rows separately (Y = 0 first, matching the
    # original draw order), then join fold v of each stratum.
    rows_y0 <- which(!is_y1, useNames = FALSE)
    rows_y1 <- which(is_y1, useNames = FALSE)
    folds_y0 <- lapply(deal(length(rows_y0)), function(p) rows_y0[p])
    folds_y1 <- lapply(deal(length(rows_y1)), function(p) rows_y1[p])
    valid_rows <- lapply(
      seq_len(V),
      function(v) c(folds_y0[[v]], folds_y1[[v]])
    )
    names(valid_rows) <- paste(seq_len(V))
    return(valid_rows)
  }

  # Stratify whole clusters: any Y = 1 row makes it a Y = 1 cluster.
  clusters_y1 <- unique(id[which(is_y1)])
  clusters_y0 <- setdiff(unique(id), clusters_y1)
  # The row counts checked above do not guarantee V clusters per stratum;
  # without this check cluster_folds() would index past the last group.
  if (length(clusters_y1) < V || length(clusters_y0) < V) {
    stop("number of ids with any (Y=1) or with only (Y=0) ",
         "is less than the number of folds")
  }
  # Y = 1 clusters are dealt (and shuffled) first, matching the original
  # order of random draws.
  folds_y1 <- cluster_folds(clusters_y1)
  folds_y0 <- cluster_folds(clusters_y0)
  lapply(seq_len(V), function(v) c(folds_y1[[v]], folds_y0[[v]]))
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement