Advertisement
Guest User

Forward and Reverse Mode Automatic Differentiation in C#

a guest
May 25th, 2020
102
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 3.54 KB | None | 0 0
  1. using System;
  2.  
  3. namespace HigherLogics
  4. {
  5.     // Copyright 2020-05-24 Sandro Magi
  6.  
  7.     //---------------- Forward Mode Automatic Differentiation ---------------------
  8.  
  9.     public readonly struct Fwd
  10.     {
  11.         public readonly double Magnitude;
  12.         public readonly double Derivative;
  13.  
  14.         public Fwd(double mag, double deriv)
  15.         {
  16.             this.Magnitude = mag;
  17.             this.Derivative = deriv;
  18.         }
  19.  
  20.         public Fwd Pow(int k) =>
  21.             new Fwd(Math.Pow(Magnitude, k), k * Math.Pow(Magnitude, k - 1) * Derivative);
  22.  
  23.         public static Fwd operator +(Fwd lhs, Fwd rhs) =>
  24.             new Fwd(lhs.Magnitude + rhs.Magnitude, lhs.Derivative + rhs.Derivative);
  25.  
  26.         public static Fwd operator *(Fwd lhs, Fwd rhs) =>
  27.             new Fwd(lhs.Magnitude + rhs.Magnitude, lhs.Derivative * rhs.Magnitude + rhs.Derivative * lhs.Magnitude);
  28.  
  29.         public static Func<double, Fwd> Differentiate(Func<Fwd, Fwd> f) =>
  30.             x => f(new Fwd(x, 1));
  31.  
  32.         public static Func<double, double, Fwd> DifferentiateX0(Func<Fwd, Fwd, Fwd> f) =>
  33.             (x0, x1) => f(new Fwd(x0, 1), new Fwd(x1, 0));
  34.  
  35.         public static Func<double, double, Fwd> DifferentiateX1(Func<Fwd, Fwd, Fwd> f) =>
  36.             (x0, x1) => f(new Fwd(x0, 0), new Fwd(x1, 1));
  37.     }
  38.  
  39.  
  40.  
  41.     //----------------- Reverse Mode Automatic Differentiation -------------------
  42.  
  43.     public readonly struct Rev
  44.     {
  45.         public readonly double Magnitude;
  46.         readonly Action<double> Derivative;
  47.  
  48.         public Rev(double y, Action<double> dy)
  49.         {
  50.             this.Magnitude = y;
  51.             this.Derivative = dy;
  52.         }
  53.  
  54.         public Rev Pow(int e)
  55.         {
  56.             var x = Magnitude;
  57.             var k = Derivative;
  58.             return new Rev(Math.Pow(Magnitude, e), dx => k(e * Math.Pow(x, e - 1) * dx));
  59.         }
  60.  
  61.         public static Rev operator +(Rev lhs, Rev rhs) =>
  62.             new Rev(lhs.Magnitude + rhs.Magnitude, dx =>
  63.             {
  64.                 lhs.Derivative(dx);
  65.                 rhs.Derivative(dx);
  66.             });
  67.  
  68.         public static Rev operator *(Rev lhs, Rev rhs) =>
  69.             new Rev(lhs.Magnitude * rhs.Magnitude,
  70.                    dx =>
  71.                    {
  72.                        lhs.Derivative(dx * rhs.Magnitude);
  73.                        rhs.Derivative(dx * lhs.Magnitude);
  74.                    });
  75.  
  76.         public static Func<double, (double, double)> Differentiate(Func<Rev, Rev> f) =>
  77.             x =>
  78.             {
  79.                 double dx = 1;
  80.                 var y = f(new Rev(x, dy => dx = dy));
  81.                 y.Derivative(1);
  82.                 return (y.Magnitude, dx);
  83.             };
  84.  
  85.         public static Func<double, double, (double, double, double)> Differentiate(Func<Rev, Rev, Rev> f) =>
  86.             (x0, x1) =>
  87.             {
  88.                 double dx0 = 1, dx1 = 1;
  89.                 var y = f(new Rev(x0, dy => dx0 = dy), new Rev(x1, dy => dx1 = dy));
  90.                 y.Derivative(1);
  91.                 return (y.Magnitude, dx0, dx1);
  92.             };
  93.  
  94.         public static Func<double, double, double, (double, double, double, double)> Differentiate(Func<Rev, Rev, Rev, Rev> f) =>
  95.             (x0, x1, x2) =>
  96.             {
  97.                 double dx0 = -1, dx1 = -1, dx2 = -1;
  98.                 var y = f(new Rev(x0, dy => dx0 = dy),
  99.                           new Rev(x1, dy => dx1 = dy),
  100.                           new Rev(x2, dy => dx2 = dy));
  101.                 y.Derivative(1);
  102.                 return (y.Magnitude, dx0, dx1, dx2);
  103.             };
  104.     }
  105. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement