Guest User

Untitled

a guest
Feb 18th, 2019
76
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.17 KB | None | 0 0
  1. ;;; -*- coding:utf-8; mode:lisp -*-
  2.  
  3. (in-package :cl-random-forest)
  4.  
  5. ;;; Small dataset ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  6.  
  ;; Number of distinct class labels used by the toy dataset below.
  (defparameter *n-class* 4)

  ;; One fixnum class label per sample (11 samples in total).
  (defparameter *target*
    (let ((class-ids '(0 0 1 1 2 2 2 3 3 3 3)))
      (make-array (length class-ids)
                  :element-type 'fixnum
                  :initial-contents class-ids)))

  ;; 11 x 2 double-float matrix: one row per sample, one column per feature.
  (defparameter *datamatrix*
    (let ((rows '((-1.0d0 -2.0d0)
                  (-2.0d0 -1.0d0)
                  (1.0d0 -2.0d0)
                  (3.0d0 -1.5d0)
                  (-2.0d0 2.0d0)
                  (-3.0d0 1.0d0)
                  (-2.0d0 1.0d0)
                  (3.0d0 2.0d0)
                  (2.0d0 2.0d0)
                  (1.0d0 2.0d0)
                  (1.0d0 1.0d0))))
      (make-array (list (length rows) (length (first rows)))
                  :element-type 'double-float
                  :initial-contents rows)))
  27.  
  28. ;;; Decision tree ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  29.  
  ;; Grow a single decision tree on the toy dataset.
  (defparameter *dtree*
    (make-dtree *n-class* *datamatrix* *target*
                :n-trial 10
                :min-region-samples 1
                :max-depth 5))

  ;; Classify the datum at index 2 with both prediction routines.
  (predict-dtree *dtree* *datamatrix* 2)
  (predict-dtree-majority-vote *dtree* *datamatrix* 2)
  38.  
  (defun predict-node (node)
    "Return the index of the largest entry in NODE's class distribution.
  Only the first N-CLASS entries (N-CLASS taken from the node's dtree) are
  examined.  Ties go to the lowest index; 0 is returned when no entry is
  strictly greater than 0d0."
    (let* ((class-dist (node-class-distribution node))
           (class-count (dtree-n-class (node-dtree node)))
           (best-weight 0d0)
           (best-class 0))
      (dotimes (k class-count best-class)
        ;; Strict comparison keeps the earliest class on ties.
        (when (> (aref class-dist k) best-weight)
          (setf best-weight (aref class-dist k)
                best-class k)))))
  49.  
  (defun extract-node (node)
    "Translate the subtree rooted at NODE into a nested IF form.
  A node that has a test attribute becomes
    (if (>= (aref d i ATTRIBUTE) THRESHOLD) LEFT-FORM RIGHT-FORM)
  and any other node becomes the literal class index from PREDICT-NODE.
  The free variables D and I are bound by the lambda list that
  CONSTRUCT-DTREE-LAMBDA wraps around this form."
    (let ((attribute (and node (node-test-attribute node))))
      (if attribute
          (list 'if
                (list '>= (list 'aref 'd 'i attribute) (node-test-threshold node))
                (extract-node (node-left-node node))
                (extract-node (node-right-node node)))
          (predict-node node))))
  56.  
  (defun construct-dtree-lambda (dtree)
    "Build a lambda expression (as data) that classifies one datum with DTREE.
  The resulting function takes a double-float array D and a fixnum row index I
  and evaluates the nested IF form produced by EXTRACT-NODE for DTREE's root.
  The form is declared (safety 0), so callers must honour the declared types."
    (let ((decision-form (extract-node (dtree-root dtree))))
      `(lambda (d i)
         (declare (optimize (speed 3) (space 0) (safety 0) (debug 0) (compilation-speed 0))
                  (type (simple-array double-float) d)
                  (type fixnum i))
         ,decision-form)))
  63.  
  64. (construct-dtree-lambda *dtree*)
  65.  
  66. ;; Lambda expression generated by such a call, kept for reference. NOTE(review): the generator in this file names its parameters D and I, so this listing presumably comes from an earlier revision - confirm.
  67. (LAMBDA (DATAMATRIX DATUM-INDEX)
  68. (DECLARE
  69. (OPTIMIZE (SPEED 3) (SPACE 0) (SAFETY 0) (DEBUG 0) (COMPILATION-SPEED 0))
  70. (TYPE (SIMPLE-ARRAY DOUBLE-FLOAT) DATAMATRIX)
  71. (TYPE FIXNUM DATUM-INDEX))
  72. (IF (>= (AREF DATAMATRIX DATUM-INDEX 0) -0.7394168078526362d0)
  73. (IF (>= (AREF DATAMATRIX DATUM-INDEX 1) 0.8903535809681147d0)
  74. 3
  75. 1)
  76. (IF (>= (AREF DATAMATRIX DATUM-INDEX 1) 0.5876648784761986d0)
  77. 2
  78. 0)))
  79.  
  ;; Compile the generated lambda expression into a function object.
  (defparameter compiled-dtree
    (let ((lambda-form (construct-dtree-lambda *dtree*)))
      (compile nil lambda-form)))

  ;; Call the compiled predictor on the datum at index 0.
  (funcall compiled-dtree *datamatrix* 0)
  85.  
  86. ;;;;;
  87.  
  ;; Feature count passed to the MNIST reader below.
  (defparameter mnist-dim 784)
  ;; Number of classes passed to make-dtree / make-forest for MNIST.
  (defparameter mnist-n-class 10)
  90.  
  ;; Read the MNIST train/test files, shift their labels, and publish the
  ;; resulting matrices and label vectors as global variables.
  (let ((mnist-train (clol.utils:read-data "/home/wiz/datasets/mnist.scale" mnist-dim :multiclass-p t))
        (mnist-test (clol.utils:read-data "/home/wiz/datasets/mnist.scale.t" mnist-dim :multiclass-p t)))

    ;; Add 1 to labels in order to form class-labels beginning from 0
    (flet ((bump-labels (dataset)
             (dolist (datum dataset)
               (incf (car datum)))))
      (bump-labels mnist-train)
      (bump-labels mnist-test))

    ;; Training split.
    (multiple-value-bind (train-matrix train-labels)
        (clol-dataset->datamatrix/target mnist-train)
      (defparameter mnist-datamatrix train-matrix)
      (defparameter mnist-target train-labels))

    ;; Held-out test split.
    (multiple-value-bind (test-matrix test-labels)
        (clol-dataset->datamatrix/target mnist-test)
      (defparameter mnist-datamatrix-test test-matrix)
      (defparameter mnist-target-test test-labels)))
  107.  
  108. ;;; Make Decision Tree ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
  109.  
  ;; Single decision tree trained on the MNIST training split.
  (defparameter mnist-dtree
    (make-dtree mnist-n-class mnist-datamatrix mnist-target
                :min-region-samples 5
                :n-trial 28
                :max-depth 10))

  ;; Evaluate on the training split, then on the test split.
  (test-dtree mnist-dtree mnist-datamatrix mnist-target)
  (test-dtree mnist-dtree mnist-datamatrix-test mnist-target-test)
  116.  
  ;; Baseline: 100 passes over the training matrix through predict-dtree.
  (time
   (dotimes (pass 100)
     (dotimes (i (array-dimension mnist-datamatrix 0))
       (predict-dtree mnist-dtree mnist-datamatrix i))))

  ;; Time building and compiling the lambda, then building it alone.
  (time (defparameter dtree-predictor
          (compile nil (construct-dtree-lambda mnist-dtree))))
  (time (defparameter dtree-lambda
          (construct-dtree-lambda mnist-dtree))) ; this one is very fast

  ;; Same 100 passes, this time through the compiled predictor.
  (time
   (dotimes (pass 100)
     (dotimes (i (array-dimension mnist-datamatrix 0))
       (funcall dtree-predictor mnist-datamatrix i))))
  129.  
  ;; Random forest trained on the MNIST training split.
  (defparameter mnist-forest
    (make-forest mnist-n-class mnist-datamatrix mnist-target
                 :min-region-samples 5
                 :n-trial 28
                 :max-depth 5
                 :bagging-ratio 0.1
                 :n-tree 500))

  ;; Compile every tree of the forest into its own predictor function,
  ;; printing the tree index before each compilation as a progress indicator.
  (time
   (defparameter dtree-predictor-list
     (let ((tree-index -1))
       (mapcar (lambda (dtree)
                 (print (incf tree-index))
                 (compile nil (construct-dtree-lambda dtree)))
               (forest-dtree-list mnist-forest)))))
  141.  
  (defun argmax (arr)
    "Return the index of the largest element of the vector ARR.
  Ties are resolved in favour of the lowest index.  An empty vector yields 0.
  The running maximum is seeded from the first element rather than from 0, so
  vectors containing only negative values are handled correctly."
    (let ((len (length arr)))
      (if (zerop len)
          0
          (let ((max (aref arr 0))
                (max-i 0))
            (loop for i from 1 below len
                  for value = (aref arr i)
                  ;; Strict comparison keeps the earliest index on ties.
                  when (> value max)
                    do (setf max value
                             max-i i))
            max-i))))
  150.  
  (defun predict-dtree-predictor-list (dtree-predictor-list datamatrix index)
    "Majority vote over compiled tree predictors for one datum.
  Each function in DTREE-PREDICTOR-LIST is called with DATAMATRIX and INDEX and
  must return a non-negative integer class.  Returns the class with the most
  votes (lowest class on ties), or 0 when the predictor list is empty.

  The vote counter is sized from the largest class actually predicted and is
  zero-initialised.  It was previously sized by the feature count of
  DATAMATRIX and left uninitialised, which overflowed whenever a class index
  was >= the number of features and relied on unspecified initial contents."
    (let ((votes (loop for predictor in dtree-predictor-list
                       collect (funcall predictor datamatrix index))))
      (if (null votes)
          0
          (let ((cnt (make-array (1+ (reduce #'max votes))
                                 :element-type 'fixnum
                                 :initial-element 0)))
            (dolist (class votes)
              (incf (aref cnt class)))
            (argmax cnt)))))

  ;; Smoke test: ensemble prediction for the first training datum.
  (predict-dtree-predictor-list dtree-predictor-list mnist-datamatrix 0)
  158.  
  (defun test-dtree-predictor-list (dtree-predictor-list datamatrix target)
    "Count the rows of DATAMATRIX whose ensemble prediction equals TARGET.
  Returns the number of correct predictions as an integer."
    (let ((n-correct 0))
      (dotimes (row (array-dimension datamatrix 0) n-correct)
        (when (= (predict-dtree-predictor-list dtree-predictor-list datamatrix row)
                 (aref target row))
          (incf n-correct)))))
  163.  
  ;; Compiled-predictor ensemble on the test split.
  (time (test-dtree-predictor-list dtree-predictor-list
                                   mnist-datamatrix-test
                                   mnist-target-test))
  ;; Recorded result: 9385
  ;;   Evaluation took:
  ;;     0.286 seconds of real time
  ;;     0.284000 seconds of total run time (0.284000 user, 0.000000 system)
  ;;     99.30% CPU
  ;;     967,442,117 processor cycles
  ;;     62,876,912 bytes consed

  ;; Library forest evaluation on the same split, for comparison.
  (time (test-forest mnist-forest
                     mnist-datamatrix-test
                     mnist-target-test))
  ;; Recorded result: Accuracy: 94.33%, Correct: 9433, Total: 10000
  ;;   Evaluation took:
  ;;     2.659 seconds of real time
  ;;     2.660000 seconds of total run time (2.660000 user, 0.000000 system)
  ;;     100.04% CPU
  ;;     9,021,268,236 processor cycles
  ;;     1,216 bytes consed
  181.  
  ;; Approach: compute the predicted class of every leaf ahead of time
  ;; (majority vote).
  ;; predict-all-leaf

  ;; Leaf reached by the first training datum, and its most frequent class.
  (defparameter leaf1
    (let ((root (dtree-root mnist-dtree)))
      (find-leaf root mnist-datamatrix 0)))
  (argmax (node-class-distribution leaf1))
  187.  
  (defun set-leaf-prediction! (dtree)
    "Store on every leaf of DTREE the argmax of its class distribution.
  The value is written to the leaf's NODE-LEAF-PREDICTION slot."
    (flet ((cache-prediction (leaf)
             (setf (node-leaf-prediction leaf)
                   (argmax (node-class-distribution leaf)))))
      (do-leaf #'cache-prediction (dtree-root dtree))))
  193.  
  (defun set-leaf-prediction-forest! (forest)
    "Run SET-LEAF-PREDICTION! on every tree of FOREST.  Returns NIL."
    (loop for tree in (forest-dtree-list forest)
          do (set-leaf-prediction! tree)))
  197.  
  ;; Cache the leaf predictions for the single tree and for the whole forest.
  ;; The functions defined in this file carry a trailing `!`; the names without
  ;; it are not defined here.
  (time (set-leaf-prediction! mnist-dtree))
  (time (set-leaf-prediction-forest! mnist-forest))
  200.  
  (defun predict-dtree-majority-vote (dtree datamatrix datum-index)
    "Return the stored leaf prediction for row DATUM-INDEX of DATAMATRIX.
  The leaf is located with FIND-LEAF starting from DTREE's root; its
  NODE-LEAF-PREDICTION slot is returned as-is."
    (let* ((root (dtree-root dtree))
           (leaf (find-leaf root datamatrix datum-index)))
      (node-leaf-prediction leaf)))
  203.  
  (defun test-dtree-majority-vote (dtree datamatrix target &key quiet-p)
    "Score DTREE's majority-vote predictions against TARGET.
  Every index of TARGET is predicted with PREDICT-DTREE-MAJORITY-VOTE; the
  number of matches and the total are handed to CALC-ACCURACY (with QUIET-P
  passed through), whose result is returned."
    (declare (optimize (speed 3) (safety 0))
             (type dtree dtree)
             (type (simple-array double-float) datamatrix)
             (type (simple-array fixnum (*)) target))
    (let ((len (length target))
          (n-correct 0))
      (declare (type fixnum len n-correct))
      (dotimes (i len)
        (if (= (predict-dtree-majority-vote dtree datamatrix i)
               (aref target i))
            (incf n-correct)))
      (calc-accuracy n-correct len :quiet-p quiet-p)))
  217.  
  (defun predict-forest-majority-vote (forest datamatrix datum-index)
    "Majority vote of FOREST's trees for row DATUM-INDEX of DATAMATRIX.
  FOREST's class-count array is reset to 0d0 and reused as the vote tally, so
  its previous contents are overwritten.  Each tree adds 1.0d0 to the slot of
  its stored leaf prediction; the ARGMAX of the tally is returned."
    (let ((tally (forest-class-count-array forest)))
      ;; Clear the tally left over from any previous call.
      (fill tally 0d0)
      (loop for tree in (forest-dtree-list forest)
            for leaf = (find-leaf (dtree-root tree) datamatrix datum-index)
            do (incf (aref tally (node-leaf-prediction leaf)) 1.0d0))
      (argmax tally)))
  228.  
  (defun test-forest-majority-vote (forest datamatrix target &key quiet-p)
    "Score FOREST's majority-vote predictions against TARGET.
  Every index of TARGET is predicted with PREDICT-FOREST-MAJORITY-VOTE; the
  number of matches and the total are handed to CALC-ACCURACY (with QUIET-P
  passed through), whose result is returned.

  TARGET is declared as a one-dimensional fixnum vector, matching
  TEST-DTREE-MAJORITY-VOTE.  LENGTH and single-index AREF are applied to it
  under (safety 0), so the rank belongs in the declaration."
    (declare (optimize (speed 3) (safety 0))
             (type forest forest)
             (type (simple-array double-float) datamatrix)
             (type (simple-array fixnum (*)) target))
    (let ((n-correct 0)
          (len (length target)))
      (declare (type fixnum n-correct len))
      (loop for i fixnum from 0 below len do
        (when (= (predict-forest-majority-vote forest datamatrix i)
                 (aref target i))
          (incf n-correct)))
      (calc-accuracy n-correct len :quiet-p quiet-p)))
  242.  
  ;; Time the majority-vote forest test on the training split, then on the
  ;; test split.
  (dolist (split (list (cons mnist-datamatrix mnist-target)
                       (cons mnist-datamatrix-test mnist-target-test)))
    (time (test-forest-majority-vote mnist-forest (car split) (cdr split))))
  ;;
  (ql:quickload :wiz-util)
  247.  
  ;; Cache leaf predictions on the toy tree.  The function defined in this file
  ;; is SET-LEAF-PREDICTION! (with the trailing `!`).
  (time (set-leaf-prediction! *dtree*))

  ;; Print each leaf's class distribution next to its cached prediction.
  ;; Previously both values were read and silently discarded.
  (do-leaf (lambda (node)
             (format t "~A => ~A~%"
                     (node-class-distribution node)
                     (node-leaf-prediction node)))
           (dtree-root *dtree*))
  254.  
  (require :sb-sprof)

  ;; Profile one majority-vote evaluation of the forest on the training split
  ;; and print a flat report.  The unterminated `(sb-sprof:` fragment that
  ;; preceded this form was removed: it made the rest of the file unreadable.
  (sb-sprof:with-profiling (:max-samples 1000
                            :report :flat
                            :loop nil)
    (test-forest-majority-vote mnist-forest mnist-datamatrix mnist-target))
Add Comment
Please, Sign In to add comment