Advertisement
Guest User

Untitled

a guest
Nov 21st, 2014
146
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Scala 2.31 KB | None | 0 0
  1. package dk.nezbo.ir.ass2
  2.  
  3. import scala.collection.immutable.Set
  4. import ch.ethz.dal.classifier.processing.XMLDocument
  5. import scala.util.Random
  6. import scala.collection.mutable.HashMap
  7.  
  8. class LogisticRegression(topic: String, trainSet: Seq[(Set[String],Map[String,Double])]) extends Classifier {
  9.  
  10.   // FIELDS
  11.  
  12.   val bias = 1.0 // TODO: What it should be?
  13.   val random = new Random(1337)
  14.  
  15.   var theta: Map[String,Double] = Map[String,Double]()
  16.  
  17.   // INHERITED FUNCTIONS
  18.  
  19.   def getTopic() : String = {
  20.     topic
  21.   }
  22.  
  23.   def train(iteration: Int) : Unit = {
  24.     val lRate = 1.0/iteration.toDouble
  25.    
  26.     for(doc <- trainSet){
  27.         println("["+topic+"] theta size: "+theta.size)
  28.         val positive = doc._1.contains(topic)
  29.         theta = updateTheta(doc._2, theta, lRate, positive)
  30.     }
  31.     println("["+topic+"] theta size: "+theta.size)
  32.        
  33.   }
  34.   def classify(doc: XMLDocument) : Double = {
  35.     val prob = probRelevant(Main.lrFeatures(doc),theta)
  36.     //println("Probability: "+prob)
  37.     prob
  38.   }
  39.  
  40.   override def toString = theta.toList.toString()
  41.  
  42.   // HELPER FUNCTIONS
  43.  
  44.   def innerProduct(v1: Map[String,Double], v2: Map[String,Double]) : Double = {
  45.     (v1.keySet & v2.keySet).map(k => (v1.getOrElse(k, 0.0) * v2.getOrElse(k, 0.0))).sum
  46.   }
  47.  
  48.   def probRelevant(dFeature: Map[String,Double], theta: Map[String,Double]) : Double = {
  49.     val result = 1.0 / (1.0 + Math.exp(-1.0 * bias - innerProduct(dFeature,theta)))
  50.     //println(result)
  51.     result
  52.   }
  53.  
  54.   def scalarMultVector(scalar: Double, vector: Map[String,Double]) : Map[String,Double] = {
  55.     vector.mapValues(i => i*scalar)
  56.   }
  57.  
  58.   def vectorAdd(v1: Map[String,Double], v2: Map[String,Double]) : Map[String,Double] = {
  59.     (v1.keySet ++ v2.keySet).map(k => ((k -> (v1.getOrElse(k, 0.0) + v2.getOrElse(k, 0.0)) )) ).toMap
  60.   }
  61.  
  62.   def deltaTheta(dFeature: Map[String,Double], theta: Map[String,Double], rel: Boolean) : Map[String,Double] = {
  63.     if(rel){
  64.       scalarMultVector(1.0 - probRelevant(dFeature,theta), dFeature)
  65.     }else{
  66.       scalarMultVector(-1.0*probRelevant(dFeature,theta), dFeature)
  67.     }
  68.   }
  69.  
  70.   def updateTheta(doc: Map[String,Double], theta: Map[String,Double], lRate: Double, rel: Boolean) : Map[String,Double] = {
  71.     vectorAdd(scalarMultVector(lRate, deltaTheta(doc,theta,rel)), theta)
  72.   }
  73.  
  74. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement