(ns mira
  {:doc "Implements margin-infused relaxation algorithm (MIRA). Fairly optimized."
   :author "Aria Haghighi <me@aria42.com>"}
  (:gen-class)
  (:use [clojure.string :only [join]]
        [clojure.java.io :only [reader]]))

(defn dot-product
  "dot-product between two maps (sum over matching values)
   Bottleneck: written to be efficient"
  [x y]
  (loop [sum 0.0 y y]
    (let [f (first y)]
      (if-not f sum
        (let [k (first f) v (second f)]
          (recur (+ sum (* (get x k 0.0) v))
                 (rest y)))))))

(defn sum [f xs]
  (reduce + (map f xs)))

(defn norm-sq
  "||x||^2 over values in map x"
  [x] (sum #(* % %) (map second x)))

(defn add-scaled
  "x <- x + scale * y
   Bottleneck: written to be efficient"
  [x scale y]
  (persistent!
   (reduce
    (fn [res elem]
      (let [k (first elem) v (second elem)]
        (assoc! res k (+ (get x k 0.0) (* scale v)))))
    (transient x)
    y)))
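
;; A minimal REPL sketch of the sparse-vector helpers above (the maps and
;; numbers are illustrative, not from the original source). Feature maps act
;; as sparse vectors keyed by feature name.
(comment
  (dot-product {"a" 1.0 "b" 2.0} {"a" 3.0 "c" 4.0}) ;=> 3.0
  (norm-sq {"a" 3.0 "b" 4.0})                       ;=> 25.0
  (add-scaled {"a" 1.0} 2.0 {"a" 3.0 "b" 1.0}))     ;=> {"a" 7.0, "b" 2.0}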

; Needed for averaged weight vector
(def +updates-left+ (atom nil))

; (cum)-label-weights: label -> (cum)-weights
(defrecord Mira [loss-fn label-weights cum-label-weights])

(defn new-mira
  [labels loss-fn]
  (let [empty-weights #(into {} (for [l labels] [l {}]))]
    (Mira. loss-fn (empty-weights) (empty-weights))))

(defn get-labels
  "return possible labels for task"
  [mira] (keys (:label-weights mira)))

(defn get-score-fn
  "return fn: label => model-score-of-label"
  [mira datum]
  (fn [label]
    (dot-product ((:label-weights mira) label) datum)))

(defn get-loss
  "get loss for predicting predict-label in place of gold-label"
  [mira gold-label predict-label]
  ((:loss-fn mira) gold-label predict-label))

(defn ppredict
  "When you have lots of classes, useful to parallelize prediction"
  [mira datum]
  (let [score-fn (get-score-fn mira datum)
        label-parts (partition-all 5 (get-labels mira))
        part-fn (fn [label-part]
                  (reduce
                   (fn [res label] (assoc res label (score-fn label)))
                   {} label-part))
        score-parts (pmap part-fn label-parts)
        scores (apply merge score-parts)]
    (first (apply max-key second scores))))

(defn predict
  "predict highest scoring class"
  [mira datum]
  (if (> (count (get-labels mira)) 5)
    (ppredict mira datum)
    (apply max-key (get-score-fn mira datum) (get-labels mira))))
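
;; An illustrative prediction sketch (hand-built weights, not from the original
;; source): each label owns a sparse weight map, and predict returns the label
;; whose weights score highest against the datum under dot-product.
(comment
  (let [m (Mira. (constantly 1)
                 {"pos" {"good" 1.0} "neg" {"bad" 1.0}}
                 {})]
    (predict m {"good" 2.0}))) ;=> "pos"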

(defn update-weights
  "returns new weights assuming error predict-label instead of gold-label.
   delta-vec is the direction and alpha the scaling constant"
  [label-weights delta-vec gold-label predict-label alpha]
  (-> label-weights
      (update-in [gold-label] add-scaled alpha delta-vec)
      (update-in [predict-label] add-scaled (- alpha) delta-vec)))

(defn update-mira
  "update mira for an example returning [new-mira error?]"
  [mira datum gold-label]
  (let [predict-label (ppredict mira datum)]
    (if (= predict-label gold-label)
      ; If we get it right do nothing
      [mira false]
      ; otherwise, update weights
      (let [score-fn (get-score-fn mira datum)
            loss (get-loss mira gold-label predict-label)
            gap (- (score-fn gold-label) (score-fn predict-label))
            alpha (/ (- loss gap) (* 2 (norm-sq datum)))
            avg-factor (* @+updates-left+ alpha)
            new-mira (-> mira
                         ; Update Current Weights
                         (update-in [:label-weights]
                                    update-weights datum gold-label predict-label alpha)
                         ; Update Average (cumulative) Weights
                         (update-in [:cum-label-weights]
                                    update-weights datum gold-label
                                    predict-label avg-factor))]
        [new-mira true]))))
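
;; A worked example of the step size above (numbers are illustrative, not from
;; the original source): with loss = 1, gap = score(gold) - score(predict) = -0.5,
;; and ||datum||^2 = 2, alpha = (1 - (-0.5)) / (2 * 2) = 0.375. The gold label's
;; weights then move by +0.375 * datum and the predicted label's by -0.375 * datum.
;; avg-factor scales the same update by the number of updates still to come, which
;; is what lets cum-label-weights accumulate an averaged weight vector.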

(defn train-iter
  "Training pass over data, returning [new-mira num-errors], where
   num-errors is the number of mistakes made on training pass"
  [mira labeled-data-fn]
  (reduce
   (fn [[cur-mira num-errors] [datum gold-label]]
     (let [[new-mira error?] (update-mira cur-mira datum gold-label)]
       (swap! +updates-left+ dec)
       [new-mira (if error? (inc num-errors) num-errors)]))
   [mira 0]
   (labeled-data-fn)))

(defn train
  "do num-iters iterations over labeled-data (yielded by labeled-data-fn)"
  [labeled-data-fn labels num-iters loss-fn]
  (loop [iter 0 mira (new-mira labels loss-fn)]
    (if (= iter num-iters)
      mira
      (let [[new-mira num-errors] (train-iter mira labeled-data-fn)]
        (println
         (format "[MIRA] On iter %s made %s training mistakes" iter num-errors))
        ; If we don't make mistakes, never will again
        (if (zero? num-errors)
          new-mira
          (recur (inc iter) new-mira))))))
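
;; A minimal in-memory training sketch (toy data and labels are illustrative,
;; not from the original source). Note that +updates-left+ must be initialised
;; to the total number of updates before training, as -main does below.
(comment
  (let [data [[{"good" 1.0} "pos"]
              [{"bad" 1.0} "neg"]]
        num-iters 5]
    (reset! +updates-left+ (* num-iters (count data)))
    (train (constantly data) #{"pos" "neg"} num-iters (constantly 1))))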

(defn feat-vec-from-line
  "format: feat1:val1 ... featn:valn. feat is a string and val a double"
  [^String line]
  (for [^String piece (.split line "\\s+")
        :let [split-index (.indexOf piece ":")
              feat (if (neg? split-index)
                     piece
                     (.substring piece 0 split-index))
              value (if (neg? split-index)
                      1
                      (-> piece
                          (.substring (inc split-index))
                          Double/parseDouble))]]
    [feat value]))
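
;; Parsing sketch (the input line is illustrative, not from the original
;; source): features without an explicit ":value" default to a value of 1.
(comment
  (feat-vec-from-line "color:1.5 tall")) ;=> (["color" 1.5] ["tall" 1])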

(defn load-labeled-data
  "format: label feat1:val1 ... featn:valn"
  [path]
  (for [line (line-seq (reader path))
        :let [pieces (.split ^String line "\\s+")
              label (first pieces)
              feat-vec (feat-vec-from-line (join " " (rest pieces)))]]
    [feat-vec label]))

(defn load-data
  "load data without label"
  [path] (map feat-vec-from-line (line-seq (reader path))))

(defn normalize-vec [x]
  (let [norm (Math/sqrt (norm-sq x))]
    (into {} (for [[k v] x] [k (/ v norm)]))))

(defn -main [& args]
  (case (first args)
    "train"
    (let [[data-path num-iters outfile] (rest args)
          labeled-data-fn #(load-labeled-data data-path)
          labels (into #{} (map second (labeled-data-fn)))
          num-iters (Integer/parseInt num-iters)]
      ; For Average Weight Calculation
      (compare-and-set! +updates-left+ nil (* num-iters (count (labeled-data-fn))))
      (let [mira (train labeled-data-fn labels num-iters (constantly 1))
            avg-weights (into {}
                              (for [[label sum-weights] (:cum-label-weights mira)]
                                [label (normalize-vec sum-weights)]))]
        (println "[MIRA] Done Training. Writing weights to " outfile)
        (spit outfile avg-weights)))
    "predict"
    (let [[weight-file data-file] (rest args)
          weights (read-string (slurp weight-file))
          mira (Mira. (constantly 1) weights weights)]
      (doseq [datum (load-data data-file)]
        (println (predict mira datum))))
    "test"
    (let [[weight-file data-file] (rest args)
          weights (read-string (slurp weight-file))
          mira (Mira. (constantly 1) weights weights)
          labeled-test (load-labeled-data data-file)
          gold-labels (map second labeled-test)
          predict-labels (map #(predict mira %) (map first labeled-test))
          num-errors (->> (map vector gold-labels predict-labels)
                          (sum (fn [[gold predict]] (if (not= gold predict) 1 0))))]
      (println "Error: " (double (/ num-errors (count gold-labels)))))))
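
;; Hypothetical command-line usage sketch (paths, iteration count, and the
;; classpath are placeholders, not from the original source). Data lines look
;; like "label feat1:val1 ... featn:valn"; AOT-compile the namespace so the
;; mira class exists.
;;
;;   java -cp <classpath> mira train train.txt 20 weights.clj
;;   java -cp <classpath> mira predict weights.clj unlabeled.txt
;;   java -cp <classpath> mira test weights.clj test.txt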