Share Pastebin
Guest
Public paste!

Untitled

By: a guest | Mar 20th, 2010 | Syntax: OCaml | Size: 6.81 KB | Hits: 61 | Expires: Never
Copy text to clipboard
  1. type trainData = float list list
  2. type lr = float
  3. type maxiterations = int
  4. type winit = float list
  5.  
  6. (*Initialize random number sequence *)
  7. let () = Random.self_init();;
  8.  
  9. (* get a random number from -0.5 to 0.5 *)
  10. let get_rand() = (Random.float 1.0) -. 0.5;;
  11.  
  12. (* gets # of columns *)
  13. let head_length trainData =     List.length(List.hd trainData)
  14.  
  15. (*initialize a new list *)
  16.  
  17. let create_winit trainData =
  18.         let rec loop w_initial = function 1 -> w_initial | x -> loop (get_rand() :: w_initial) (x - 1) in loop [] (head_length(trainData))
  19. ;;
  20.    
  21. (*save only the last element from the one row *)       
  22. let remove_head x = List.hd(List.rev x)
  23.  
  24. (*save only the last element from each row to create the vector class*)
  25. let create_vector_class matrix = List.map remove_head matrix
  26.  
  27. (*remove the last element from the one row *)  
  28. let remove_last x = List.rev (List.tl (List.rev x))
  29.  
  30. (*remove the last element from each row to create the input vector*)
  31. let create_input_vector matrix = List.map remove_last matrix
  32.  
  33. (* this will be inner.caml *)
  34. let rec inner = function [],[] -> 0.0 | i::is,w::ws -> i*.w +. inner (is,ws) | _ -> failwith "different length lists"
  35.  
  36. (*multiply a matrix and a vector together *)   
  37. let rec multiply_test = function  [],_ -> [] | m::ms, v -> inner(m,v) :: multiply_test(ms,v)
  38.  
  39. (* create the current class vector *)
  40. let cur_class_vect inData w = multiply_test(create_input_vector inData, w)
  41.  
  42. let matrix_multiply inData w = multiply_test(inData, w)
  43.  
  44. (*Compares 2 vectors: if correct = true then bool = "true" *)
  45. let compare_vc_current_class input1 input2 = List.for_all2 (fun x y -> x > 0.0 && y > 0.0 || x < 0.0 && y < 0.0) input1 input2
  46.  
  47. (*Add 2 vectors together *)
  48. let vector_add input1 input2 = List.map2 (+.) input1 input2
  49.  
  50. (*Mult 2 vectors together *)
  51. let vector_mult input1 input2 = List.map2 ( *. ) input1 input2
  52.  
  53. (*duplicate the learning rate to create a learning rate vector *)
  54. (* n = int  := number of times you want to duplicate*)
  55. (* x = float:= you want to duplicate*)
  56. let dup n x =
  57.         let rec f n accum =
  58.                 if n <= 0 then accum
  59.                 else f (n - 1) (x :: accum) in f n []
  60.  
  61. let signal_of_floats a b =
  62.   if a < 0.0 && b < 0.0 then
  63.     0.0
  64.   else if a > 0.0 && b > 0.0 then
  65.     0.0
  66.   else
  67.     a
  68.  
  69. let rec my_compare a b =
  70.   match a, b with
  71.     | [], [] -> []
  72.     | a::b, []
  73.     | [], a::b -> failwith "my_compare: the lists has different sizes"
  74.     | a::r, b::s -> (signal_of_floats a b)::(my_compare r s)
  75.    
  76. (*let dotproduct = List.fold_left2 (fun x a b -> x +. a *. b) 0.0;; *)
  77. let transpose m = List.map List.rev (List.fold_left (List.map2 (fun xs x -> x::xs)) (List.map (fun x -> [x]) (List.hd m)) (List.tl m));;
  78. (* let mul m1 m2 = List.map (fun v1 -> List.map (fun v2 -> dotproduct v1 v2) (transpose m2)) m1;; *)
  79.  
  80. let ( *- ) a b = List.fold_left2 (fun x a b -> x +. a *. b) 0.0 a b    
  81. let mult matrix vector = List.map (fun x -> x *- vector) matrix
  82.  
  83. let make_delayed_printer print_function =
  84.   let temporary = ref None in
  85.   let delayed_match value = match !temporary with
  86.      | None -> Some value
  87.      | Some stored -> print_function stored; Some value in
  88.   let delayed_printer value =
  89.     temporary := (delayed_match value) in
  90.   delayed_printer
  91.  
  92. let print_one_entry (count, w_current) =
  93.   Printf.printf "at At iter: %d current weights are: " count;
  94.   List.iter (Printf.printf "%f ") w_current;   
  95.   Printf.printf "\n"
  96.  
  97. let print_it_delayed =
  98.   make_delayed_printer print_one_entry
  99.  
  100. (* sde3 function *)
  101. let sde3(trainData, lr, maxiterations, winit)=
  102.        
  103.         let lr_size = (head_length(trainData)-1) in
  104.                 if lr_size = 0 then failwith "lr_size can't be 0"
  105.         ;
  106.  
  107.         let lr_vector = dup lr_size lr in
  108.                 if lr_vector = [] then failwith "lr vector can't be empty"
  109.         ;
  110.        
  111.         (*CREATE THE INPUT VECTOR*)
  112.         let input_vector  = create_input_vector trainData in
  113.                 if input_vector = [] then failwith "input_vector can't be empty"
  114.         ;      
  115.  
  116.         (*CREATE VECTOR CLASS*)
  117.         let vector_class  = create_vector_class trainData in
  118.                 if vector_class = [] then failwith "vector class cant be empty"
  119.         ;
  120.  
  121.         (*CREATE W_CURRENT*)
  122.         let w_current = winit in
  123.                 if w_current = [] then failwith "w_current cant be empty"
  124.         ;
  125.        
  126.         (*CREATE THE CURRENT CLASS*)
  127.         let current_class = cur_class_vect trainData winit in
  128.                 if current_class = [] then failwith "current class cant be empty"
  129.         ;              
  130.        
  131.         (*Printf.printf "Vector Class: [";
  132.         List.iter (Printf.printf "%f ") vector_class;  
  133.         Printf.printf "]\n";
  134.        
  135.         Printf.printf "Current Class: [";
  136.         List.iter (Printf.printf "%f ") current_class; 
  137.         Printf.printf "]\n\n\n";
  138.         *)
  139.        
  140.         let new_mult = my_compare vector_class current_class in
  141.                 if new_mult = [] then failwith "new_mult cant be empty"
  142.         ;
  143.  
  144.         (*Printf.printf "new_mult: ";
  145.         List.iter (Printf.printf "%f ") new_mult;      
  146.         Printf.printf "\n";
  147.         *)
  148.        
  149.         let new_vect = matrix_multiply (transpose(create_input_vector trainData)) new_mult in
  150.                 if new_vect = [] then failwith "new_vect cant be empty"
  151.         ;
  152.  
  153.         (*
  154.         Printf.printf "new_vect: ";
  155.         List.iter (Printf.printf "%f ") new_vect;      
  156.         Printf.printf "\n";
  157.         *)
  158.         let error_vect =        vector_mult lr_vector new_vect in
  159.                 if error_vect = [] then failwith "error_vect cant be empty"
  160.         ;      
  161.        
  162.         (*
  163.         Printf.printf "error_vect: ";
  164.         List.iter (Printf.printf "%f ") error_vect;    
  165.         Printf.printf "\n";
  166.         *)
  167.  
  168.         let rec finish w_current current_class new_mult new_vect error_vect count=
  169.                
  170.                 let is_finished =  compare_vc_current_class vector_class current_class in
  171.                
  172.                         print_it_delayed(count, w_current);
  173.                        
  174.                         if is_finished = true || count = maxiterations then  w_current
  175.                         else begin
  176.                        
  177.  
  178.  
  179.                                
  180.                                 (*
  181.                                 Printf.printf "at At iter: %d w_current: " !count;
  182.                                 List.iter (Printf.printf "%f ") w_current;     
  183.                                 Printf.printf "\n";
  184.  
  185.                                 Printf.printf "at At iter: %d current_class: " !count;
  186.                                 List.iter (Printf.printf "%f ") current_class; 
  187.                                 Printf.printf "\n";
  188.                                                                                                                                
  189.                                 Printf.printf "at At iter: %d new_mult: " !count;
  190.                                 List.iter (Printf.printf "%f ") new_mult;      
  191.                                 Printf.printf "\n";
  192.                        
  193.                                 Printf.printf "at At iter: %d new_vector is: " !count;
  194.                                 List.iter (Printf.printf "%f ") new_vect;      
  195.                                 Printf.printf "\n";    
  196.  
  197.                                 Printf.printf "at At iter: %d lr_vector is: " !count;
  198.                                 List.iter (Printf.printf "%f ") lr_vector;     
  199.                                 Printf.printf "\n";
  200.                                                        
  201.                                 Printf.printf "at At iter: %d error_vect: " !count;
  202.                                 List.iter (Printf.printf "%f ") error_vect;    
  203.                                 Printf.printf "\n";
  204.                                 *)                             
  205.  
  206.                                 let current_class = cur_class_vect trainData w_current in
  207.                                 let new_mult = my_compare vector_class current_class in
  208.                                 let new_vect = matrix_multiply (transpose(create_input_vector trainData)) new_mult in
  209.                                 let error_vect = vector_mult lr_vector new_vect in
  210.                                 let w_current = vector_add error_vect w_current in
  211.  
  212.                                 finish w_current current_class new_mult new_vect error_vect (count+1)
  213.                         end
  214.         in
  215.         finish w_current current_class new_mult new_vect error_vect 0