Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package rs
- import scala.language.higherKinds
- import matryoshka.data._
- import matryoshka.implicits._
- import scalaz._, Scalaz._
/**
 * Pattern functor of a tiny arithmetic language.
 *
 * `A` is the type stored in each child position; recursive trees are built as `Fix[Expr]`.
 * Sealed so that every `match` on an `Expr` is checked for exhaustiveness by the compiler.
 */
sealed trait Expr[A]

/** Integer literal. It has no children, so `A` is only a phantom parameter here. */
final case class NumLit[A](value: Int) extends Expr[A]

/** Sum of the `left` and `right` children. */
final case class Add[A](left: A, right: A) extends Expr[A]

/** Division of the `num` child by the `denum` child. */
final case class Div[A](num: A, denum: A) extends Expr[A]

/** Evaluation error: the `Div` node whose (already evaluated) denominator was zero. */
final case class DivisionByZero(div: Div[Int])
/** Type-class instances for the [[Expr]] pattern functor. */
object Expr {

  /**
   * `Traverse` for `Expr`.
   *
   * Applies `f` to each child slot and rebuilds the same kind of node around the results.
   * A literal has no children, so it is lifted into `G` unchanged (only its phantom type
   * parameter is re-tagged from `A` to `B`).
   */
  implicit val traverse: Traverse[Expr] = new Traverse[Expr] {
    def traverseImpl[G[_] : Applicative, A, B](expr: Expr[A])(f: A => G[B]): G[Expr[B]] = {
      val ap = Applicative[G]
      expr match {
        case NumLit(value) =>
          ap.point[Expr[B]](NumLit(value))
        case Add(left, right) =>
          ap.apply2[B, B, Expr[B]](f(left), f(right))(Add(_, _))
        case Div(num, denum) =>
          ap.apply2[B, B, Expr[B]](f(num), f(denum))(Div(_, _))
      }
    }
  }
}
/** Demo of recursion schemes (cata / cataM / anaM) over `Fix[Expr]`. */
object Main {

  /**
   * Height of the expression tree: a literal counts 1, and each `Add`/`Div` node adds 1
   * to the larger height of its two children.
   */
  def complexity(expr: Fix[Expr]): Int = expr.cata[Int] {
    case NumLit(_)        => 1
    case Add(left, right) => 1 + math.max(left, right)
    case Div(num, denum)  => 1 + math.max(num, denum)
  }

  /** Copy of `expr` with every literal increased by one; the tree shape is unchanged. */
  def incr(expr: Fix[Expr]): Fix[Expr] = expr.cata[Fix[Expr]] {
    case NumLit(value) => Fix(NumLit(value + 1))
    case other         => Fix(other)
  }

  /** Evaluation result: a `DivisionByZero` on the left, or a value on the right. */
  type ErrorOr[A] = DivisionByZero \/ A

  /**
   * Evaluates `expr` bottom-up with plain `Int` arithmetic (which wraps on overflow).
   *
   * @return the value, or a `DivisionByZero` carrying the offending `Div` node if any
   *         denominator evaluates to zero.
   */
  def eval(expr: Fix[Expr]): ErrorOr[Int] = expr.cataM[ErrorOr, Int] {
    case NumLit(value)       => value.right
    case Add(left, right)    => (left + right).right
    case d @ Div(num, denum) =>
      if (denum == 0) DivisionByZero(d).left
      else (num / denum).right
  }

  /** All literal nodes of `a`, gathered by folding each node's children with the `List` monoid. */
  def collect(a: Fix[Expr]): List[NumLit[_]] = a.cata[List[NumLit[_]]] {
    case n @ NumLit(_) => List(n)
    case other         => other.fold
  }

  /**
   * Builds a full binary tree of `Add` nodes with `NumLit(1)` leaves whose [[complexity]]
   * equals the requested value.
   *
   * The tree has 2^(complexity - 1) leaves, so the output grows exponentially.
   *
   * @param complexity requested tree height
   * @return `None` when `complexity <= 0`, otherwise the generated tree
   */
  def gen(complexity: Int): Option[Fix[Expr]] =
    complexity.anaM[Fix[Expr]][Option, Expr] {
      // Non-positive seeds mean "no expression". Previously only 0 was handled, so a
      // negative seed kept producing Add(n - 1, n - 1) and recursed until stack overflow.
      case n if n <= 0 => None
      case 1           => Some(NumLit(1))
      case n           => Some(Add(n - 1, n - 1))
    }

  val expr: Fix[Expr] = Fix(Add(Fix(NumLit(5)), Fix(NumLit(10))))

  def main(args: Array[String]): Unit = {
    println(s"expr := $expr")
    println(s"complexity := ${complexity(expr)}")
    println(s"incr := ${incr(expr)}")
    println(s"eval 1+1 := ${eval(Fix(Add(Fix(NumLit(1)), Fix(NumLit(1)))))}")
    println(s"eval 6/3 := ${eval(Fix(Div(Fix(NumLit(6)), Fix(NumLit(3)))))}")
    println(s"eval 5/0 := ${eval(Fix(Div(Fix(NumLit(5)), Fix(NumLit(0)))))}")
    println(s"collect := ${collect(Fix(Add(Fix(NumLit(1)), Fix(Add(Fix(NumLit(2)), Fix(NumLit(3)))))))}")
    println(s"gen(0) := ${gen(0).map(complexity)}")
    println(s"gen(1) := ${gen(1).map(complexity)}")
    println(s"gen(2) := ${gen(2).map(complexity)}")
    println(s"gen(3) := ${gen(3).map(complexity)}")
  }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement