package example

import cats.free._
import cats.~>
import cats.data.State

object computation {

  // A node of the computation graph: an input variable, a constant, or the
  // output produced by an operation. Every node carries a name and a value.
  sealed trait Edge

  case class Variable(name: String, value: Double) extends Edge
  case class Constant(name: String, value: Double) extends Edge
  case class Output(name: String, value: Double) extends Edge

  object Edge {

    // Uniform access to the name and value of any Edge without adding
    // abstract members to the trait itself.
    implicit class EdgeExtensions(val e: Edge) extends AnyVal {

      def name: String = e match {
        case Variable(n, _) => n
        case Constant(n, _) => n
        case Output(n, _) => n
      }

      def value: Double = e match {
        case Variable(_, v) => v
        case Constant(_, v) => v
        case Output(_, v) => v
      }
    }
  }

  // The algebra of operations: each instruction takes two input edges and
  // yields an output Edge.
  sealed trait Op[A]

  case class SumOp(x: Edge, y: Edge) extends Op[Edge]
  case class SubstractOp(x: Edge, y: Edge) extends Op[Edge]
  case class MultiplyOp(x: Edge, y: Edge) extends Op[Edge]

  type Operation = Op[Edge]

  // A program over the Op algebra, built with the free monad.
  type Computation[A] = Free[Op, A]

  // Smart constructors lifting each instruction into Free.
  def sum(x: Edge, y: Edge): Computation[Edge] = Free.liftF[Op, Edge](SumOp(x, y))

  def substract(x: Edge, y: Edge): Computation[Edge] = Free.liftF[Op, Edge](SubstractOp(x, y))

  def multiply(x: Edge, y: Edge): Computation[Edge] = Free.liftF[Op, Edge](MultiplyOp(x, y))

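  // Illustrative sketch: composing the constructors only *describes* the graph;
  // nothing is evaluated until an interpreter is supplied (the names "a", "b",
  // "c" are just examples):
  //
  //   val expr: Computation[Edge] =
  //     for {
  //       s <- sum(Variable("a", 1.0), Variable("b", 2.0))
  //       p <- multiply(s, Constant("c", 3.0))
  //     } yield p
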
  object gradient {

    // State accumulated while interpreting a Computation:
    //   values      - the forward value computed for each operation
    //   grads       - for each operation, the partial derivatives of its output
    //                 with respect to the variables that feed into it
    //   connections - which operation produced a given Output edge
    case class ComputationData(values: Map[Operation, Double],
                               grads: Map[Operation, Map[Variable, Double]],
                               connections: Map[Edge, Operation])

    implicit class ComputationDataExtensions(val d: ComputationData) extends AnyVal {

      // Record that `op` produced the given output node.
      def connect(op: Operation, node: Edge): ComputationData =
        d.copy(connections = d.connections.updated(node, op))

      // Record the forward value computed by `op`.
      def setValue(op: Operation, value: Double): ComputationData =
        d.copy(values = d.values.updated(op, value))

      // Record the partial derivative of `op`'s output with respect to `v`.
      def setGrad(op: Operation, v: Variable, value: Double): ComputationData = {
        val gradsMap = d.grads.getOrElse(op, Map.empty).updated(v, value)
        d.copy(grads = d.grads.updated(op, gradsMap))
      }

      // Chain rule: take the gradients already recorded for `fromOp`, transform
      // them with `f`, and merge them into the gradients recorded for `toOp`
      // (merging rather than replacing, so gradients from other inputs survive).
      def chainGrad(fromOp: Operation, toOp: Operation, f: Double => Double): ComputationData = {
        val chained = d.grads.getOrElse(fromOp, Map.empty[Variable, Double]).map { case (v, g) => v -> f(g) }
        d.copy(grads = d.grads.updated(toOp, d.grads.getOrElse(toOp, Map.empty[Variable, Double]) ++ chained))
      }

      // Gradient contribution of input `x` to a sum: d(x + y)/dx = 1. If `x` is
      // a variable its gradient is 1; if it was produced by another operation,
      // propagate that operation's gradients unchanged.
      def sumGrad(op: Operation)(x: Edge): ComputationData =
        (x, d.connections.get(x)) match {
          case (v: Variable, _) => d.setGrad(op, v, 1.0)
          case (_, Some(fromOp)) => d.chainGrad(fromOp, op, identity)
          case _ => d
        }

      // Gradient contribution of input `x` to a subtraction: the factor is -1
      // for the subtrahend and 1 for the minuend.
      def substractGrad(op: Operation)(x: Edge, subtrahend: Boolean): ComputationData = {
        val factor = if (subtrahend) -1.0 else 1.0

        (x, d.connections.get(x)) match {
          case (vx: Variable, _) => d.setGrad(op, vx, factor)
          case (_, Some(fromOp)) => d.chainGrad(fromOp, op, _ * factor)
          case _ => d
        }
      }

      // Gradient contribution of input `x` to a product: d(x * y)/dx = y. If `x`
      // is a variable its gradient is y's value; if it was produced by another
      // operation, scale that operation's gradients by y's value.
      def multiplyGrad(op: Operation)(x: Edge, y: Edge): ComputationData =
        (x, d.connections.get(x), y) match {
          case (vx: Variable, _, vy) => d.setGrad(op, vx, vy.value)
          case (_, Some(fromOp), vy) => d.chainGrad(fromOp, op, _ * vy.value)
          case _ => d
        }
    }

    // Interpretation target: a state transition over the accumulated ComputationData.
    type ComputationState[A] = State[ComputationData, A]

    // Natural transformation from the Op algebra into State: each instruction
    // computes its forward value, creates the Output edge, connects it to the
    // operation, and records the gradients with respect to its inputs.
    private val interpreter: Op ~> ComputationState = new (Op ~> ComputationState) {

      def apply[A](op: Op[A]): ComputationState[A] = op match {

        case op @ SumOp(x, y) =>
          State { s =>
            val output = Output(s"(${x.name} + ${y.name})", x.value + y.value)
            s.setValue(op, x.value + y.value)
              .connect(op, output)
              .sumGrad(op)(x)
              .sumGrad(op)(y) -> output
          }

        case op @ SubstractOp(x, y) =>
          State { s =>
            val output = Output(s"(${x.name} - ${y.name})", x.value - y.value)
            s.setValue(op, x.value - y.value)
              .connect(op, output)
              .substractGrad(op)(x, subtrahend = false)
              .substractGrad(op)(y, subtrahend = true) -> output
          }

        case op @ MultiplyOp(x, y) =>
          State { s =>
            val output = Output(s"(${x.name} * ${y.name})", x.value * y.value)
            s.setValue(op, x.value * y.value)
              .connect(op, output)
              .multiplyGrad(op)(x, y)
              .multiplyGrad(op)(y, x) -> output
          }
      }
    }

    // Run a Computation against the interpreter, starting from empty data.
    // Returns the accumulated ComputationData together with the final result.
    def apply[A](comp: Computation[A]): (ComputationData, A) =
      comp.foldMap(interpreter)
        .run(ComputationData(values = Map.empty, grads = Map.empty, connections = Map.empty))
        .value
  }
}

object CompGraph {

  import computation._

  def main(args: Array[String]): Unit = {

    // Describe the expression (x1 - x2 + x3) * const1 as a Computation.
    def computation(x1: Edge, x2: Edge, x3: Edge, const1: Edge) = for {
      diff1 <- substract(x1, x2)
      sum1 <- sum(diff1, x3)
      mul1 <- multiply(sum1, const1)
    } yield mul1

    val comp = computation(Variable("x1", 1), Variable("x2", 2), Variable("x3", 3), Constant("c1", 2))
    val res = gradient(comp)

    println(s"grads = \n$res")
  }
}
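
// Illustrative readback, a sketch assuming the Output edge returned by
// `gradient` is the one recorded in `connections` for the final operation
// (which is what the interpreter's `connect` call does): the gradient of the
// result with respect to a variable can then be looked up from the returned
// ComputationData, e.g.
//
//   val (data, out) = gradient(comp)
//   val dOutDx1: Option[Double] =
//     data.connections.get(out)
//       .flatMap(finalOp => data.grads.get(finalOp))
//       .flatMap(_.collectFirst { case (v, g) if v.name == "x1" => g })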