Advertisement
Guest User

Untitled

a guest
May 6th, 2015
247
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.92 KB | None | 0 0
  1. kmeanSVM.train = function(x,y,n,C,gamma)
  2. {
  3. # Cluster data points
  4. kmeans.result = stats::kmeans(x, centers = n, iter.max = nrow(x)/10)
  5. centers = kmeans.result$centers
  6. cluster = kmeans.result$cluster
  7.  
  8. Model = vector(n,mode = 'list')
  9. nullind = NULL
  10. for (i in 1:n)
  11. {
  12. ind = which(cluster==i)
  13.  
  14. # If there's no data falling in this region, keep Model[[i]] = NULL
  15. if (length(ind)==0)
  16. {
  17. nullind = c(nullind,i)
  18. next
  19. }
  20.  
  21. # Split x and y
  22. cx = x[ind,]
  23. cy = y[ind]
  24.  
  25. # no need to train when cy only has one value
  26. if (length(unique(cy))==1)
  27. Model[[i]] = as.numeric(cy[1])
  28. else
  29. {
  30. Model[[i]] = e1071::svm(x = cx, y = cy, kernel = "radial",
  31. cost = C, gamma = gamma)
  32. }
  33. }
  34.  
  35. # Delete useless centers
  36. if (length(nullind)>0)
  37. {
  38. Model = Model[-nullind]
  39. centers = centers[-nullind,]
  40. }
  41. kmeanSVM.learner = list(Model = Model,
  42. centers = centers,
  43. levels = levels(y))
  44. return(structure(kmeanSVM.learner,class='kmeanSVM.learner'))
  45. }
  46.  
  47. kmeanSVM.predict = function(learner, newdata, ...)
  48. {
  49. if (class(learner)!='kmeanSVM.learner')
  50. stop("The learner should have class of 'kmeansSVM.learner'")
  51. Model = learner$Model
  52. centers = learner$centers
  53. n = length(Model)
  54.  
  55. # check data type
  56. newdata = as.matrix(newdata)
  57. if (!is.matrix(newdata))
  58. stop('Input data must be a numeric matrix or an object that can be
  59. coerced to such a matrix.')
  60.  
  61. # Get cluster label for new data
  62. pred.kmeans = kmeans(newdata, centers)
  63. cluster = pred.kmeans$cluster
  64.  
  65. y = rep(0,nrow(newdata))
  66. for (i in 1:n)
  67. {
  68. ind = which(cluster==i)
  69. if (length(ind)==0) next
  70. if (class(Model[[i]])!='svm')
  71. y[ind] = Model[[i]]
  72. else
  73. {
  74. cx = x[ind,]
  75. y[ind] = predict(Model[[i]],cx,...)
  76. }
  77. }
  78. y = factor(y)
  79. levels(y) = learner$levels
  80. return(y)
  81. }
  82.  
  83. # =========================================
  84. # Test on Breast Cancer data
  85. # =========================================
  86.  
  87. require(MASS)
  88. data(biopsy)
  89. x = biopsy[,2:10]
  90. y = biopsy[,11]
  91.  
  92. # Split data
  93. ind = which(!complete.cases(x))
  94. x = x[-ind,]
  95. y = y[-ind]
  96. set.seed(1024)
  97. train.ind = sample(nrow(x),300)
  98.  
  99. # Training session
  100. trained.learner = kmeanSVM.train(x = x[train.ind,], y = y[train.ind],
  101. n = 3, C = 1, gamma = 1)
  102. # Predict session
  103. prediction = kmeanSVM.predict(learner = trained.learner,
  104. newdata = x[-train.ind,])
  105. # Check result
  106. table(prediction,y[-train.ind])
  107.  
  108. # Compare with single svm
  109. svm.model = e1071::svm(x = x[train.ind,], y = y[train.ind], C = 1, gamma = 1)
  110. svm.pred = predict(svm.model, x[-train.ind,])
  111. table(svm.pred,y[-train.ind])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement