Advertisement
Guest User

Untitled

a guest
Aug 31st, 2015
96
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.59 KB | None | 0 0
  1. # Update for SuperLearner::CVFolds function that enables stratification by outcome and cluster ID
  2.  
  3. CVFolds2 <- function (N, id, Y, cvControl) {
  4. if (!is.null(cvControl$validRows)) {
  5. return(cvControl$validRows)
  6. }
  7. stratifyCV <- cvControl$stratifyCV
  8. shuffle <- cvControl$shuffle
  9. V <- cvControl$V
  10. if (!stratifyCV) { ### Not Stratified
  11. if (shuffle) { ## Not Stratified, Shuffled
  12. if (is.null(id)) { #Not stratified, Shuffled, Not by ID
  13. validRows <- split(sample(1:N), rep(1:V, length = N))
  14. }
  15. else { #Not stratified, Shuffled, by ID
  16. n.id <- length(unique(id))
  17. id.split <- split(sample(1:n.id), rep(1:V, length = n.id))
  18. validRows <- vector("list", V)
  19. for (v in seq(V)) {
  20. validRows[[v]] <- which(id %in% unique(id)[id.split[[v]]])
  21. }
  22. }
  23. }
  24. else { ## Not Stratified, Not Shuffled
  25. if (is.null(id)) { #Not Stratified, Not Shuffled, Not by ID
  26. validRows <- split(1:N, rep(1:V, length = N))
  27. }
  28. else { #Not Stratified, Not Shuffled, by ID
  29. n.id <- length(unique(id))
  30. id.split <- split(1:n.id, rep(1:V, length = n.id))
  31. validRows <- vector("list", V)
  32. for (v in seq(V)) {
  33. validRows[[v]] <- which(id %in% unique(id)[id.split[[v]]])
  34. }
  35. }
  36. }
  37. }
  38. else { ### Stratified
  39. if (length(unique(Y)) != 2) {
  40. stop("stratifyCV only implemented for binary Y")
  41. }
  42. if (sum(Y) < V | sum(!Y) < V) {
  43. stop("number of (Y=1) or (Y=0) is less than the number of folds")
  44. }
  45. if (shuffle) { ## Stratified, Shuffled
  46. if (is.null(id)) { #Stratified, Shuffled, not by ID
  47. wiY0 <- which(Y == 0)
  48. wiY1 <- which(Y == 1)
  49. rowsY0 <- split(sample(wiY0), rep(1:V, length = length(wiY0)))
  50. rowsY1 <- split(sample(wiY1), rep(1:V, length = length(wiY1)))
  51. validRows <- vector("list", length = V)
  52. names(validRows) <- paste(seq(V))
  53. for (vv in seq(V)) {
  54. validRows[[vv]] <- c(rowsY0[[vv]], rowsY1[[vv]])
  55. }
  56. }
  57. else { #Stratified, Shuffled, by ID
  58. within.split <- suppressWarnings(tapply(1:N,
  59. INDEX = Y, FUN = split, 1))
  60. id.Y1 <- unique(id[within.split[[2]]])
  61. id.notY1 <- setdiff(unique(id),id.Y1)
  62. n.id.Y1 <- length(id.Y1)
  63. n.id.notY1 <- length(id.notY1)
  64. id.Y1.split <- split(sample(1:n.id.Y1), rep(1:V, length = n.id.Y1))
  65. id.notY1.split <- split(sample(1:n.id.notY1), rep(1:V, length = n.id.notY1))
  66. validRows <- vector("list", V)
  67. for (v in seq(V)) {
  68. validRows[[v]] <- c(which(id %in% id.Y1[id.Y1.split[[v]]]),
  69. which(id %in% id.notY1[id.notY1.split[[v]]]))
  70. }
  71.  
  72. }
  73. }
  74. else { ## Stratified, Not Shuffled
  75. if (is.null(id)) {
  76. within.split <- suppressWarnings(tapply(1:N,
  77. INDEX = Y, FUN = split, rep(1:V)))
  78. validRows <- vector("list", length = V)
  79. names(validRows) <- paste(seq(V))
  80. for (vv in seq(V)) {
  81. validRows[[vv]] <- c(within.split[[1]][[vv]],
  82. within.split[[2]][[vv]])
  83. }
  84. }
  85. else { #Stratified, Not Shuffled, by ID
  86. within.split <- suppressWarnings(tapply(1:N,
  87. INDEX = Y, FUN = split, 1))
  88. id.Y1 <- unique(id[within.split[[2]]])
  89. id.notY1 <- setdiff(unique(id),id.Y1)
  90. n.id.Y1 <- length(id.Y1)
  91. n.id.notY1 <- length(id.notY1)
  92. id.Y1.split <- split(1:n.id.Y1, rep(1:V, length = n.id.Y1))
  93. id.notY1.split <- split(1:n.id.notY1, rep(1:V, length = n.id.notY1))
  94. validRows <- vector("list", V)
  95. for (v in seq(V)) {
  96. validRows[[v]] <- c(which(id %in% id.Y1[id.Y1.split[[v]]]),
  97. which(id %in% id.notY1[id.notY1.split[[v]]]))
  98. }
  99. }
  100. }
  101. }
  102. return(validRows)
  103. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement