Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
// Function-as-a-differentiable-type rule:
//   Tangent space:   ((T...) -> U...)'  = Any
//   Cotangent space: ((T...) -> U...)'* = Any
// Why? Because when a function value varies, what varies is its context
// (the values it captures). In general, this needs to be a constrained
// existential with an `AdditiveArithmetic` conformance (for `.zero` and `+`)
// and a `Differentiable` conformance (to be able to transpose between a
// differential and a pullback).
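// New associated function type calculation rules:
//   original: (T...) -> (U...)
//   jvp:      (T...) -> (value: (U...), differential: (Any, T...') -> (U...'))
//   vjp:      (T...) -> (value: (U...), pullback: (U...'*) -> (Any, T...'*))
// where `'` denotes the tangent space, `'*` the cotangent space, and the
// leading `Any` is the tangent/cotangent of the function's own context.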
/// Curries a binary differentiable function into a differentiable function
/// that returns a differentiable function: `𝑓(x, y)` becomes `curry(𝑓)(x)(y)`.
///
/// Under the function-as-a-differentiable-type rule, a function value's
/// cotangent is `Any` and describes its captured context. The inner function
/// `g` captures `x`, so its pullback reports `x`'s cotangent boxed as `Any`.
/// The outer pullback unboxes that value back into `T.CotangentVector`.
///
/// - Parameter 𝑓: The binary function to curry.
/// - Returns: A function taking `x` that returns a function of `y`
///   computing `𝑓(x, y)`.
func curry<T: Differentiable, U: Differentiable, V: Differentiable>(
    _ 𝑓: @escaping @differentiable (T, U) -> V
) -> @differentiable (T) -> @differentiable (U) -> V {
    // Outer function: x ↦ g, where g = { y in 𝑓(x, y) }.
    // The explicit type annotation selects the function-result overload
    // of `makeDifferentiable`.
    let f: @differentiable (T) -> @differentiable (U) -> V = makeDifferentiable { x in
        // Inner function: y ↦ 𝑓(x, y), with `x` captured as its context.
        let g: @differentiable (U) -> V = makeDifferentiable { y in
            let (z, φ٭ᶻ) = valueWithPullback(at: x, y, in: 𝑓)
            let φ٭ᵍ: (V.CotangentVector) -> (Any, U.CotangentVector) = { z̅ in
                let (x̅, y̅) = φ٭ᶻ(z̅)
                // The cotangent of g's context (the captured `x`) is boxed as `Any`.
                return (x̅ as Any, y̅)
            }
            return (value: z, pullback: φ٭ᵍ)
        }
        // Outer pullback: maps the cotangent of g's context back onto `x`.
        let φ٭ᶠ: (Any) -> (Any, T.CotangentVector) = { g̅ in
            // `g̅` is supplied by the differentiation machinery, not by this
            // function, so check the cast instead of force-casting. A mismatch
            // then fails with a diagnosable message.
            guard let x̅ = g̅ as? T.CotangentVector else {
                fatalError("curry: expected the inner function's context cotangent to be \(T.CotangentVector.self), got \(type(of: g̅))")
            }
            // The outer closure's only capture is `𝑓`. `valueWithPullback`
            // exposes no cotangent for `𝑓`'s own context, so `()` stands in
            // for it.
            return ((), x̅)
        }
        return (value: g, pullback: φ٭ᶠ)
    }
    return f
}
/// Turns the VJP of a thick (context-capturing) function into a
/// `@differentiable` function.
///
/// The VJP's pullback returns two cotangents: the cotangent of the function's
/// captured context, boxed as `Any`, and the cotangent of its argument.
///
/// - Parameter vjp: Maps an input to the function's value and pullback.
///   Marked `@escaping` because the returned function must capture it.
///   A non-escaping closure cannot outlive this call.
/// - Returns: A `@differentiable` function whose value and pullback come
///   from `vjp`.
/// - Note: Not implemented yet; calling this traps.
func makeDifferentiable<T: Differentiable, U: Differentiable>(
    from vjp: @escaping (T) -> (value: U, pullback: (U.CotangentVector) -> (Any, T.CotangentVector))
) -> @differentiable (T) -> U {
    fatalError("makeDifferentiable(from:) is not implemented yet")
}
/// Turns the VJP of a thick (context-capturing) function whose result is
/// itself a `@differentiable` function into a `@differentiable` function.
///
/// Because the result is a function, its cotangent is `Any` (the cotangent of
/// the result function's context). The pullback therefore maps `Any` to the
/// cotangent of this function's own context (boxed as `Any`) and of its
/// argument.
///
/// - Parameter vjp: Maps an input to the resulting function and the pullback.
///   Marked `@escaping` because the returned function must capture it.
///   A non-escaping closure cannot outlive this call.
/// - Returns: A `@differentiable` function whose value and pullback come
///   from `vjp`.
/// - Note: Not implemented yet; calling this traps.
func makeDifferentiable<T: Differentiable, U: Differentiable, V: Differentiable>(
    from vjp: @escaping (T) -> (value: @differentiable (U) -> V, pullback: (Any) -> (Any, T.CotangentVector))
) -> @differentiable (T) -> @differentiable (U) -> V {
    fatalError("makeDifferentiable(from:) is not implemented yet")
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement