Advertisement
Dundre32

Untitled

Apr 25th, 2020
107
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.28 KB | None | 0 0
  /**
   * Builds a vocabulary from the token sequences in the input column and returns a
   * [[CountVectorizerModel]] holding it.
   *
   * Each row of the input column is read as a `Seq[String]`. For every distinct term the
   * method accumulates the total number of occurrences and the number of rows containing
   * it (its document frequency). When `minDF` or `maxDF` has been explicitly set, terms
   * whose document frequency lies outside `[minDF, maxDF]` are dropped. A threshold
   * `>= 1.0` is used as an absolute row count; a threshold `< 1.0` is multiplied by the
   * number of input rows. The vocabulary is then the `vocabSize` terms with the highest
   * total occurrence count (or all surviving terms if there are fewer).
   *
   * Any RDD persisted by this method is released before it returns, including when it
   * fails part-way through.
   *
   * @param dataset input dataset; the column named by `inputCol` must hold string sequences
   * @return a model whose parent is this estimator and which carries copies of its params
   * @throws IllegalArgumentException if the resolved `maxDF` is smaller than the resolved
   *                                  `minDF`, or if the resulting vocabulary is empty
   */
  override def fit(dataset: Dataset[_]): CountVectorizerModel = {
    transformSchema(dataset.schema, logging = true)

    val minDfParam = $(minDF)
    val maxDfParam = $(maxDF)

    // The raw thresholds are only directly comparable when both are absolute counts or
    // both are fractions; the mixed case is validated once the row count is known.
    if ((minDfParam >= 1.0 && maxDfParam >= 1.0) || (minDfParam < 1.0 && maxDfParam < 1.0)) {
      require(maxDfParam >= minDfParam, "maxDF must be >= minDF.")
    }

    val vocSize = $(vocabSize)
    val input = dataset.select($(inputCol)).rdd.map(_.getSeq[String](0))

    try {
      // Row count, computed at most once and only when a fractional threshold needs it.
      // `input` is persisted first (unless the dataset itself is already cached) so the
      // counting pass and the later term-counting pass do not both recompute it.
      lazy val inputSize: Long = {
        if (dataset.storageLevel == StorageLevel.NONE) {
          input.persist(StorageLevel.MEMORY_AND_DISK)
        }
        input.count()
      }

      // Converts a threshold parameter into an absolute document-frequency bound.
      def absoluteDf(threshold: Double): Double =
        if (threshold >= 1.0) threshold else threshold * inputSize

      val minDf = absoluteDf(minDfParam)
      val maxDf = absoluteDf(maxDfParam)
      require(maxDf >= minDf, "maxDF must be >= minDF.")

      // Per row: count each term locally, then emit (term, (occurrences, 1)); the trailing
      // 1 marks "one row contains this term" so the reduce yields the document frequency.
      val allWordCounts = input.flatMap { tokens =>
        val wc = new OpenHashMap[String, Long]
        tokens.foreach { w => wc.changeValue(w, 1L, _ + 1L) }
        wc.map { case (word, count) => (word, (count, 1)) }
      }.reduceByKey { case ((wc1, df1), (wc2, df2)) => (wc1 + wc2, df1 + df2) }

      // Filtering is skipped entirely unless the user explicitly set one of the thresholds.
      val filteringRequired = isSet(minDF) || isSet(maxDF)
      val maybeFilteredWordCounts = if (filteringRequired) {
        allWordCounts.filter { case (_, (_, df)) => df >= minDf && df <= maxDf }
      } else {
        allWordCounts
      }

      // Persisted because it is traversed twice below: once by count() and once by top().
      val wordCounts = maybeFilteredWordCounts
        .map { case (word, (count, _)) => (word, count) }
        .persist(StorageLevel.MEMORY_AND_DISK)

      val vocab = try {
        val fullVocabSize = wordCounts.count()
        wordCounts
          .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
          .map(_._1)
      } finally {
        wordCounts.unpersist()
      }

      require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
      copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
    } finally {
      // Release `input` on every exit path, including a failed `require` or Spark action.
      if (input.getStorageLevel != StorageLevel.NONE) {
        input.unpersist()
      }
    }
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement