ptrelford

Decision Trees

Jul 7th, 2013
277
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
F# 3.57 KB | None | 0 0
  1. open System.Collections.Generic
  2.  
  3. module internal Tuple =
  4.     open Microsoft.FSharp.Reflection
  5.     let toArray = FSharpValue.GetTupleFields
  6.  
  7. module internal Array =
  8.     let removeAt i (xs:'a[]) = [|yield! xs.[..i-1];yield! xs.[i+1..]|]
  9.  
  10. let splitDataSet(dataSet:obj[][], axis, value) = [|
  11.    for featVec in dataSet do
  12.        if featVec.[axis] = value then
  13.            yield featVec |> Array.removeAt axis
  14.    |]
  15.  
  16. let calcShannonEnt(dataSet:obj[][]) =
  17.    let numEntries = dataSet.Length
  18.    dataSet
  19.    |> Seq.countBy (fun featVec -> featVec.[featVec.Length-1])
  20.    |> Seq.sumBy (fun (key,count) ->
  21.        let prob = float count / float numEntries
  22.        -prob * log(prob)/log(2.0)
  23.    )
  24.  
  25. let chooseBestFeatureToSplit(dataSet:obj[][]) =
  26.    let numFeatures = dataSet.[0].Length
  27.    let baseEntropy = calcShannonEnt(dataSet)
  28.    [0..numFeatures-1] |> List.map (fun i ->
  29.        let featList = [for example in dataSet -> example.[i]]
  30.        let newEntropy =
  31.            let uniqueValues = Seq.distinct featList
  32.            uniqueValues |> Seq.sumBy (fun value ->
  33.                let subDataSet = splitDataSet(dataSet, i, value)
  34.                let prob = float subDataSet.Length / float dataSet.Length
  35.                prob * calcShannonEnt(subDataSet)
  36.            )
  37.        let infoGain = baseEntropy - newEntropy
  38.        i, infoGain
  39.    )
  40.    |> List.maxBy snd |> fst
  41.  
  42. let majorityCnt(classList:obj[]) =
  43.    let classCount = Dictionary()
  44.    for vote in classList do
  45.        if classCount.ContainsKey(vote) then classCount.Add(vote,0)
  46.        classCount.[vote] <- classCount.[vote] + 1
  47.    [for kvp in classCount -> kvp.Key, kvp.Value]
  48.    |> List.sortBy (snd >> (~-))
  49.    |> List.head
  50.    |> fst
  51.  
  52. type Label = string
  53. type Value = obj
  54. type Tree = Leaf of Value | Branch of Label * (Value * Tree)[]
  55.  
  56. let rec createTree(dataSet:obj[][], labels:string[]) =
  57.    let classList = [|for example in dataSet -> example.[example.Length-1]|]
  58.    if classList |> Seq.forall((=) classList.[0])
  59.    then Leaf(classList.[0])
  60.    elif dataSet.[0].Length = 1
  61.    then Leaf(majorityCnt(classList))
  62.    else
  63.    let bestFeat = chooseBestFeatureToSplit(dataSet)
  64.    let bestFeatLabel = labels.[bestFeat]
  65.    let labels = labels |> Array.removeAt bestFeat
  66.    let featValues = [|for example in dataSet -> example.[bestFeat]|]
  67.    let uniqueVals = featValues |> Seq.distinct
  68.    let subTrees =
  69.        [|for value in uniqueVals ->
  70.            let subLabels = labels.[*]
  71.            value, createTree(splitDataSet(dataSet, bestFeat, value), subLabels)|]
  72.    Branch(bestFeatLabel, subTrees)
  73.  
  74. let rec classify(inputTree, featLabels:string[], testVec:obj[]) =
  75.    match inputTree with
  76.    | Leaf(x) -> x
  77.    | Branch(s,xs) ->
  78.        let featIndex = featLabels |> Array.findIndex ((=) s)
  79.        xs |> Array.pick (fun (value,tree) ->
  80.            if testVec.[featIndex] = value
  81.            then classify(tree, featLabels,testVec) |> Some
  82.            else None
  83.        )
  84.  
  85. let myDat =
  86.    [|(1, 1, "yes"); (1, 1, "yes"); (1, 0, "no"); (0, 1, "no"); (0, 1, "no")|]
  87.    |> Array.map Tuple.toArray
  88.    
  89. let Assert condition = if not condition then failwith "Failed"
  90.  
  91. let expected : obj[][] = [|[|1; "yes"|]; [|1; "yes"|]; [|0; "no"|]|]
  92. Assert (splitDataSet(myDat,0,1) = expected)
  93. Assert (round(calcShannonEnt(myDat)*100.) = round(0.9709505945*100.))
  94. Assert (chooseBestFeatureToSplit(myDat) = 0)
  95. let labels = [|"no surfacing";"flippers"|]
  96. let myTree = createTree(myDat,labels)
  97. Assert (classify(myTree,labels,[|1;0|]) = box "no")
  98. Assert (classify(myTree,labels,[|1;1|]) = box "yes")
Advertisement
Add Comment
Please, Sign In to add comment