Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package example
- import cats.free._
- import cats.~>
- import cats.data.State
object computation {

  /** A value flowing through the computation graph.
    *
    * `Variable`s are the inputs gradients are taken with respect to. `Constant`s are inputs that
    * contribute no gradient. `Output`s are the results produced by operations.
    */
  sealed trait Edge
  case class Variable(name: String, value: Double) extends Edge
  case class Constant(name: String, value: Double) extends Edge
  case class Output(name: String, value: Double) extends Edge

  object Edge {

    /** Uniform access to the `name` and `value` carried by every [[Edge]] variant. */
    implicit class EdgeExtensions(val e: Edge) extends AnyVal {
      def name: String = e match {
        case Variable(n, _) => n
        case Constant(n, _) => n
        case Output(n, _)   => n
      }

      def value: Double = e match {
        case Variable(_, v) => v
        case Constant(_, v) => v
        case Output(_, v)   => v
      }
    }
  }

  /** Instruction set of the DSL: every operation combines two edges into a new `Edge`.
    *
    * The type parameter on the case classes is unused; it is kept only for source compatibility.
    */
  sealed trait Op[A]
  case class SumOp[A](x: Edge, y: Edge) extends Op[Edge]
  // NOTE: "Substract" is a misspelling of "Subtract"; kept so existing callers still compile.
  case class SubstractOp[A](x: Edge, y: Edge) extends Op[Edge]
  case class MultiplyOp[A](x: Edge, y: Edge) extends Op[Edge]

  type Operation = Op[Edge]
  type Computation[A] = Free[Op, A]

  /** Lifts `x + y` into a [[Computation]]. */
  def sum[A](x: Edge, y: Edge): Computation[Edge] = Free.liftF[Op, Edge](SumOp(x, y))

  /** Lifts `x - y` into a [[Computation]]. */
  def substract[A](x: Edge, y: Edge): Computation[Edge] = Free.liftF[Op, Edge](SubstractOp(x, y))

  /** Lifts `x * y` into a [[Computation]]. */
  def multiply[A](x: Edge, y: Edge): Computation[Edge] = Free.liftF[Op, Edge](MultiplyOp(x, y))

  /** Interprets a [[Computation]] in `State`.
    *
    * Records each operation's value and its derivative with respect to every [[Variable]] it
    * depends on. Derivatives are propagated forward as each operation executes.
    */
  object gradient {

    /** State threaded through the interpreter.
      *
      * @param values      numeric result of every executed operation
      * @param grads       for every operation, d(operation)/d(variable) for each variable it depends on;
      *                    a missing entry (inner or outer) means the derivative is zero
      * @param connections maps each produced [[Output]] edge back to the operation that produced it
      */
    case class ComputationData(values: Map[Operation, Double],
                               grads: Map[Operation, Map[Variable, Double]],
                               connections: Map[Edge, Operation])

    implicit class ComputationDataExtensions(val d: ComputationData) extends AnyVal {

      /** Records that `node` was produced by `op`. */
      def connect(op: Operation, node: Edge): ComputationData =
        d.copy(connections = d.connections.updated(node, op))

      /** Records the numeric result of `op`. */
      def setValue(op: Operation, value: Double): ComputationData =
        d.copy(values = d.values.updated(op, value))

      /** Overwrites d(op)/d(v) with `value`, leaving the other variables of `op` untouched. */
      def setGrad(op: Operation, v: Variable, value: Double): ComputationData = {
        val gradsMap = d.grads.getOrElse(op, Map.empty[Variable, Double]).updated(v, value)
        d.copy(grads = d.grads.updated(op, gradsMap))
      }

      /** Replaces all gradients of `toOp` with those of `fromOp`, each transformed by `f`.
        *
        * Gradients previously recorded for `toOp` are discarded; use `addGrads` to accumulate.
        */
      def chainGrad(fromOp: Operation, toOp: Operation, f: Double => Double): ComputationData = {
        // Strict `map` instead of `mapValues`: the latter is a lazy view that re-applies `f` on
        // every lookup in 2.12, and in 2.13 it is deprecated and returns a `MapView`, not a `Map`.
        val grads = d.grads
          .getOrElse(fromOp, Map.empty[Variable, Double])
          .map { case (v, g) => v -> f(g) }
        d.copy(grads = d.grads.updated(toOp, grads))
      }

      /** Drops every gradient recorded for `op`. */
      def clearGrads(op: Operation): ComputationData =
        d.copy(grads = d.grads - op)

      /** Adds `contribution` to the gradients already recorded for `op`, summing per variable.
        *
        * An empty contribution leaves the state unchanged (no empty map is stored).
        */
      def addGrads(op: Operation, contribution: Map[Variable, Double]): ComputationData =
        if (contribution.isEmpty) d
        else {
          val current = d.grads.getOrElse(op, Map.empty[Variable, Double])
          val merged = contribution.foldLeft(current) { case (acc, (v, g)) =>
            acc.updated(v, acc.getOrElse(v, 0.0) + g)
          }
          d.copy(grads = d.grads.updated(op, merged))
        }

      /** Derivatives of `edge` w.r.t. every variable it depends on, each multiplied by `factor`.
        *
        * A [[Variable]] depends only on itself (derivative 1).
        * An edge produced by a recorded operation inherits that operation's gradients.
        * Anything else (constants, unknown outputs) contributes nothing.
        */
      def edgeGrads(edge: Edge, factor: Double): Map[Variable, Double] = edge match {
        case v: Variable => Map(v -> factor)
        case other =>
          d.connections.get(other).flatMap(d.grads.get) match {
            case Some(upstream) => upstream.map { case (v, g) => v -> g * factor }
            case None           => Map.empty[Variable, Double]
          }
      }

      /** Accumulates d(x + ...)/d(x) = 1 times x's own gradients into `op`. */
      def sumGrad(op: Operation)(x: Edge): ComputationData =
        addGrads(op, edgeGrads(x, 1.0))

      /** Accumulates the contribution of operand `x` to a subtraction `op`.
        *
        * @param subtrahend true when `x` is the right-hand operand (derivative -1), false otherwise (+1)
        */
      def substractGrad(op: Operation)(x: Edge, subtrahend: Boolean): ComputationData =
        addGrads(op, edgeGrads(x, if (subtrahend) -1.0 else 1.0))

      /** Accumulates d(x * y)/d(x) = y times x's own gradients into `op`. */
      def multiplyGrad(op: Operation)(x: Edge, y: Edge): ComputationData =
        addGrads(op, edgeGrads(x, y.value))
    }

    type ComputationState[A] = State[ComputationData, A]

    // Each step clears the op's gradients before accumulating both operands. Because operations
    // are value-equal case classes, re-running an identical op must not double its gradients.
    private val interpreter: Op ~> ComputationState = new (Op ~> ComputationState) {
      def apply[A](op: Op[A]): ComputationState[A] = op match {
        case op @ SumOp(x, y) =>
          State { s =>
            val output = Output(s"(${x.name} + ${y.name})", x.value + y.value)
            val next = s
              .setValue(op, output.value)
              .connect(op, output)
              .clearGrads(op)
              .sumGrad(op)(x)
              .sumGrad(op)(y)
            next -> output
          }
        case op @ SubstractOp(x, y) =>
          State { s =>
            val output = Output(s"(${x.name} - ${y.name})", x.value - y.value)
            val next = s
              .setValue(op, output.value)
              .connect(op, output)
              .clearGrads(op)
              .substractGrad(op)(x, subtrahend = false)
              .substractGrad(op)(y, subtrahend = true)
            next -> output
          }
        case op @ MultiplyOp(x, y) =>
          State { s =>
            val output = Output(s"(${x.name} * ${y.name})", x.value * y.value)
            val next = s
              .setValue(op, output.value)
              .connect(op, output)
              .clearGrads(op)
              .multiplyGrad(op)(x, y)
              .multiplyGrad(op)(y, x)
            next -> output
          }
      }
    }

    /** Runs `comp` from an empty state, returning the final [[ComputationData]] and the result. */
    def apply[A](comp: Computation[A]): (ComputationData, A) =
      comp.foldMap(interpreter).run(
        ComputationData(values = Map.empty, grads = Map.empty, connections = Map.empty)).value
  }
}
object CompGraph {
  import computation._

  /** Demo entry point.
    *
    * Builds `((x1 - x2) + x3) * c1` with the DSL, runs the gradient interpreter over it, and
    * prints the resulting (state, output) pair.
    */
  def main(args: Array[String]): Unit = {
    def expression(a: Edge, b: Edge, c: Edge, k: Edge): Computation[Edge] =
      for {
        difference <- substract(a, b)
        total      <- sum(difference, c)
        product    <- multiply(total, k)
      } yield product

    val x1 = Variable("x1", 1)
    val x2 = Variable("x2", 2)
    val x3 = Variable("x3", 3)
    val c1 = Constant("c1", 2)

    val result = gradient(expression(x1, x2, x3, c1))
    println(s"grads = \n$result")
  }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement