Advertisement
Guest User

Untitled

a guest
Mar 20th, 2019
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.82 KB | None | 0 0
  1. module FwdAD(T: real): {
  2. type r = T.t
  3. type t = (r, r)
  4. val inject: r -> t
  5. val set_deriv: t -> r -> t
  6. val get_deriv: t -> r
  7. val make_dual: r -> r -> t
  8.  
  9. include from_prim with t = (r,r)
  10. include numeric with t = (r,r)
  11. include real with t = (r,r)
  12.  
  13. } = {
  14. type r = T.t
  15. type t = (r, r)
  16. let inject x = (x, T.i32 0)
  17. let i8 (x : i8) = inject (T.i8 x)
  18. let i16 (x : i16) = inject (T.i16 x)
  19. let i32 (x : i32) = inject (T.i32 x)
  20. let i64 (x : i64) = inject (T.i64 x)
  21. let f32 (x : f32) = inject (T.f32 x)
  22. let f64 (x : f64) : t = inject (T.f64 x)
  23. let u8 (x : u8) = inject (T.u8 x)
  24. let u16 x = inject (T.u16 x)
  25. let u32 x = inject (T.u32 x)
  26. let u64 x = inject (T.u64 x)
  27. let bool x = inject (T.bool x)
  28.  
  29. let (x,x') + (y,y') = T.( (x + y, x' + y') )
  30. let (x,x') - (y,y') = T.( (x - y, x' - y') )
  31. let (x,x') * (y,y') = T.( (x * y, x' * y + x * y') )
  32.  
  33. let (x,x') / (y,y') = T.( (x / y, (x' * y - x * y') / y ** (i32 2)) )
  34.  
  35. let (x,x') ** (y,y') = T.( (x / y, (x' * y - x * y') / y ** (i32 2)) )
  36.  
  37. let (x,_) == (y,_) = T.( x == y )
  38. let (x,_) < (y,_) = T.( x < y )
  39. let (x,_) > (y,_) = T.( x > y )
  40. let (x,_) <= (y,_) = T.( x <= y )
  41. let (x,_) >= (y,_) = T.( x >= y )
  42. let (x,_) != (y,_) = T.( x != y )
  43. let negate (x,x') = T.( (negate x, negate x') )
  44. let max x y = if x >= y then x else y
  45. let min x y = if x <= y then x else y
  46. let abs (x,x') = (T.abs x, x')
  47. let sgn (x,x') = (T.sgn x, x')
  48. let highest = inject T.highest
  49. let lowest = inject T.lowest
  50. -- | Returns zero on empty input.
  51. let sum = reduce (+) (inject (T.i32 0))
  52. -- | Returns one on empty input.
  53. let product = reduce (*) (inject (T.i32 1))
  54. -- | Returns `lowest` on empty input.
  55. let maximum = reduce min highest
  56. -- | Returns `highest` on empty input.
  57. let minimum = reduce max lowest
  58.  
  59.  
  60. -- val from_fraction: i32 -> i32 -> t
  61. let from_fraction x y = inject (T.from_fraction x y)
  62. -- val to_i32: t -> i32
  63. let to_i32 (x,_) = T.to_i32 x
  64. let to_i64 (x,_) = T.to_i64 x
  65. let to_f64 (x,_) = T.to_f64 x
  66.  
  67.  
  68. -- val sqrt: t -> t
  69. let sqrt (x,x') = T.( (sqrt x, x' / (i32 2 * sqrt x)) )
  70. -- val exp: t -> t
  71. let exp (x,x') = T.( (exp x, x' * exp x) )
  72. -- val cos: t -> t
  73. let cos (x, x') = T.( (cos x, negate x' * sin x) )
  74. -- val sin: t -> t
  75. let sin (x, x') = T.( (sin x, x' * cos x) )
  76. let tan x = sin x / cos x
  77. -- val asin: t -> t
  78. let asin (x, x') = T.( (asin x, x' / (sqrt (i32 1 - x ** i32 2))) )
  79. -- val acos: t -> t1
  80. let acos (x, x') = T.( (acos x, negate x' / (sqrt (i32 1 - x ** i32 2))) )
  81. -- val atan: t -> t
  82. let atan (x, x') = T.( (atan x, x' / (i32 1 + x ** i32 2)) )
  83. -- val atan2: t -> t -> t
  84. -- I know this isn't right but can't figure it out now
  85. let atan2 (x,_) (y,_) = inject (T.atan2 x y)
  86.  
  87. -- val log: t -> t
  88. let log (x, x') = T.( (log x, x' / x) )
  89. let log2 (x, x') = T.( (log10 x, i32 1 / (x' * log2 x)) )
  90. let log10 (x, x') = T.( (log10 x, i32 1 / (x' * log10 x)) )
  91.  
  92. -- val ceil : t -> t
  93. let ceil (x, x') = (T.ceil x, x')
  94. -- val floor : t -> t
  95. let floor (x, x') = (T.floor x, x')
  96. -- val trunc : t -> t
  97. let trunc (x, x') = (T.trunc x, x')
  98. -- val round : t -> t
  99. let round (x, x') = (T.round x, x')
  100.  
  101. -- val isinf: t -> bool
  102. let isinf (x,_) = T.isinf x
  103. -- val isnan: t -> bool
  104. let isnan (x,_) = T.isnan x
  105.  
  106. -- val inf: t
  107. let inf = inject T.inf
  108. -- val nan: t
  109. let nan = inject T.nan
  110.  
  111. -- val pi: t
  112. let pi = inject T.pi
  113. -- val e: t
  114. let e = inject T.e
  115.  
  116. let get_deriv (_,x') = x'
  117. let set_deriv (x,_) x'= (x,x')
  118. let make_dual x x' = (x,x')
  119. }
  120.  
  121. import "lib/github.com/diku-dk/linalg/linalg"
  122.  
-- Dual numbers (value, tangent) over f64.
module d = FwdAD f64
-- The imported linalg package's `mk_linalg` applied to the dual-number
-- module `d`, so its routines operate on dual numbers.
module l = mk_linalg d
  125.  
  126. entry paper_dotprod [n] (v1: [n]f64) (v2: [n]f64): []f64 =
  127. tabulate n
  128. (\i ->
  129. let v1' = map2 d.make_dual v1 (tabulate n (\j -> if i == j then 1 else 0))
  130. let v2' = map d.inject v2
  131. in (l.dotprod v1' v2'))
  132. |> map (.2)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement