Advertisement
Guest User

Untitled

a guest
Feb 20th, 2019
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.07 KB | None | 0 0
// Function-as-a-differentiable-type rule:
// Tangent space: ((T...) -> U...)' = Any
// Cotangent space: ((T...) -> U...)'* = Any
// Why? Because when a function value is varying, what varies is its context
// (the values it captures).
// In general, this needs to be a constrained existential with an
// `AdditiveArithmetic` conformance (for `.zero` and `+`) and a `Differentiable`
// conformance (to be able to transpose between a differential and a pullback).

// New associated function type calculation rules:
// original: (T...) -> (U...)
// jvp: (T...) -> (value: (U...), differential: (Any, T...') -> (U...'))
// vjp: (T...) -> (value: (U...), pullback: (U...'*) -> (Any, T...'*))

  14. func curry<T: Differentiable, U: Differentiable, V: Differentiable>(
  15. _ 𝑓: @escaping @differentiable (T, U) -> V
  16. ) -> @differentiable (T) -> @differentiable (U) -> V {
  17. // Outer function.
  18. let f: @differentiable (T) -> @differentiable (U) -> (V) = makeDifferentiable { x in
  19. // Inner function.
  20. let g: @differentiable (U) -> V = makeDifferentiable { y in
  21. let (z, φ٭ᶻ) = valueWithPullback(at: x, y, in: 𝑓)
  22. let φ٭ᵍ: (V.CotangentVector) -> (Any, U.CotangentVector) = { z̅ in
  23. let (x̅, y̅) = φ٭ᶻ(z̅)
  24. return (x̅ as Any, y̅)
  25. }
  26. return (value: z, pullback: φ٭ᵍ)
  27. }
  28. let φ٭ᶠ: (Any) -> (Any, T.CotangentVector) = { g̅ in
  29. return ((), g̅ as! T.CotangentVector)
  30. }
  31. return (value: g, pullback: φ٭ᶠ)
  32. }
  33. return f
  34. }
  35.  
  36.  
  37. // Turns the VJP for a thick function into a `@differentiable` function.
  38. func makeDifferentiable<T: Differentiable, U: Differentiable>(
  39. from vjp: (T) -> (value: U, pullback: (U.CotangentVector) -> (Any, T.CotangentVector))
  40. ) -> @differentiable (T) -> U {
  41. fatalError()
  42. }
  43.  
  44. // Turns the VJP for a thick function whose result is a `@differentiable` function into a `@differentiable` function.
  45. func makeDifferentiable<T: Differentiable, U: Differentiable, V: Differentiable>(
  46. from vjp: (T) -> (value: @differentiable (U) -> V, pullback: (Any) -> (Any, T.CotangentVector))
  47. ) -> @differentiable (T) -> @differentiable (U) -> V {
  48. fatalError()
  49. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement