finalmail

Quick-and-dirty ID3 decision tree for the play-tennis dataset

Dec 27th, 2019
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. const fs = require("fs")
  2.  
  3. // const NO = "no-recurrence-events"
  4. // const YES = "recurrence-events"
  5. const NO = "No"
  6. const YES = "Yes"
  7. const NAMES = ["", "Outlook", "Temperature", "Humidity", "Wind"]
  8. // utils
  9. const splitBy = del => str => str.split(del)
  10. const nonEmpty = el => Boolean(el)
  11. const atIndexEquals = (ind, value) => el => el[ind] === value
  12. const prop = p => obj => obj[p]
  13. const uniq = collection => [...new Set(collection)]
  14. //
  15. const entropyFn = (p, n) =>
  16.   p === 0 || n === 0
  17.     ? 0
  18.     : (-p / (p + n)) * Math.log2(p / (p + n)) +
  19.       (-n / (p + n)) * Math.log2(n / (p + n))
  20. //
  21.  
  22. const findBestNode = (trainingData, attributeIndices) => {
  23.   const pos = trainingData.filter(atIndexEquals(0, YES)).length
  24.   const neg = trainingData.filter(atIndexEquals(0, NO)).length
  25.   const globalEntropy = entropyFn(pos, neg)
  26.  
  27.   let bestNode = undefined
  28.   if (attributeIndices.length === 1) {
  29.     const attrIndex = attributeIndices[0]
  30.     const classes = trainingData.map(prop(0))
  31.     console.log("classes", classes)
  32.   } else {
  33.     attributeIndices.forEach(attrIndex => {
  34.       //
  35.       const currAttrValues = uniq(trainingData.map(prop(attrIndex)))
  36.       const entropyForValues = currAttrValues.reduce((acc, value) => {
  37.         const samplesWithValue = trainingData.filter(
  38.           atIndexEquals(attrIndex, value)
  39.         )
  40.         const pos = samplesWithValue.filter(atIndexEquals(0, YES)).length
  41.         const neg = samplesWithValue.filter(atIndexEquals(0, NO)).length
  42.         const entropy = entropyFn(pos, neg)
  43.  
  44.         return acc.concat({ entropy, pos, neg, value })
  45.       }, [])
  46.       const avgInfEntropy = entropyForValues.reduce(
  47.         (acc, curr) =>
  48.           acc + ((curr.pos + curr.neg) / (pos + neg)) * curr.entropy,
  49.         0
  50.       )
  51.       const gain = globalEntropy - avgInfEntropy
  52.       // console.log(pos, neg, globalEntropy)
  53.       // console.log(`----${NAMES[attrIndex]}----`)
  54.       // console.log(entropyForValues)
  55.       // console.log("avgInfEntropy", avgInfEntropy)
  56.       // console.log(gain)
  57.       if (!bestNode || bestNode.gain < gain) {
  58.         bestNode = {
  59.           attrIndex,
  60.           attrName: NAMES[attrIndex],
  61.           gain,
  62.           values: gain !== 0 ? currAttrValues : trainingData[0][0]
  63.         }
  64.       }
  65.     })
  66.   }
  67.  
  68.   if (Array.isArray(bestNode.values)) {
  69.     for (value of bestNode.values) {
  70.       if (!bestNode.children) {
  71.         bestNode.children = []
  72.       }
  73.  
  74.       const child = findBestNode(
  75.         trainingData.filter(atIndexEquals(bestNode.attrIndex, value)),
  76.         attributeIndices.filter(el => el !== bestNode.attrIndex)
  77.       )
  78.       bestNode.children.push(child)
  79.     }
  80.   }
  81.  
  82.   // const rl = findBestNode(
  83.   //   trainingData.filter(atIndexEquals(1, "Sunny")),
  84.   //   attributeIndices.filter(el => el !== 1)
  85.   // )
  86.  
  87.   return bestNode
  88. }
  89.  
  90. const predict = (root, entry) => {
  91.   // if (!root) ??
  92.   if (Array.isArray(root.values)) {
  93.     const classValue = entry[root.attrIndex]
  94.     const valueIndex = root.values.indexOf(classValue)
  95.     const nextChild = root.children[valueIndex]
  96.  
  97.     return predict(nextChild, entry)
  98.   }
  99.  
  100.   return root.values
  101. }
  102.  
  103. //
  104.  
  105. // read file
  106. const file = fs.readFileSync("./play-tennis.data", "utf8")
  107. // const file = fs.readFileSync("./breast-cancer.data", "utf8")
  108. const data = file
  109.   .split("\n")
  110.   .filter(nonEmpty)
  111.   .map(splitBy(","))
  112. const step = 2
  113. // const step = ~~(data.length / 10)
  114. //
  115.  
  116. // solution
  117. const results = []
  118. const i = 0
  119. // for (let i = 0; i < 10; i++) {
  120. const testData =
  121.   i === 9 ? data.slice(i * step) : data.slice(i * step, i * step + step)
  122. const trainingData = data.slice(0, i * step).concat(data.slice(i * step + step))
  123. // const trainingData = data
  124. const attributeIndices = Array.from(
  125.   { length: trainingData[0].length - 1 },
  126.   (_, i) => i + 1
  127. )
  128.  
  129. let root = findBestNode(trainingData, attributeIndices)
  130. // const rl = findBestNode(
  131. //   trainingData.filter(atIndexEquals(1, "Sunny")),
  132. //   attributeIndices.filter(el => el !== 1)
  133. // )
  134. // const rm = findBestNode(
  135. //   trainingData.filter(atIndexEquals(1, "Overcast")),
  136. //   attributeIndices.filter(el => el !== 1)
  137. // )
  138. // const rll = findBestNode(
  139. //   trainingData
  140. //     .filter(atIndexEquals(1, "Sunny"))
  141. //     .filter(atIndexEquals(3, "High")),
  142. //   attributeIndices.filter(el => el !== 1 && el !== 3)
  143. // )
  144. // const rlr = findBestNode(
  145. //   trainingData
  146. //     .filter(atIndexEquals(1, "Sunny"))
  147. //     .filter(atIndexEquals(3, "Normal")),
  148. //   attributeIndices.filter(el => el !== 1 && el !== 3)
  149. // )
  150. const stringifiedRoot = JSON.stringify(root)
  151. const correctAnswers = testData.filter(currTest => {
  152.   const res = currTest[0]
  153.   const prediction = predict(root, currTest)
  154.  
  155.   return res === prediction
  156. }).length
  157. const accuracy = correctAnswers / testData.length
  158. console.log(accuracy)
  159. // console.log(prediction)
  160. // console.log(stringifiedRoot)
  161. // fs.writeFile("myjsonfile.json", stringifiedRoot, "utf8", () => {})
  162. // console.log(rl)
  163. // console.log(rm)
  164. // console.log(rll)
  165. // console.log(rlr)
  166.  
  167. // global
  168.  
  169. // results.push(accuracy)
  170. // }
  171.  
  172. // const classifierAccuracy =
  173. //   results.reduce((acc, res, i) => {
  174. //     if (i === 9) {
  175. //       return res * (data.length - 9 * step) + acc
  176. //     }
  177.  
  178. //     return res * step + acc
  179. //   }, 0) / data.length
  180. // console.log("Classifier: ", classifierAccuracy)
Add Comment
Please, Sign In to add comment