Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import scala.annotation.tailrec
- import scala.util.control.TailCalls
- import scala.util.control.TailCalls._
- import scalaz.Applicative
- import scalaz.Free._
- import scalaz.Monad
- import scalaz.OptionT
- /*
- * A possible limitation of writing monadic code in Scala. Because of Scala's
- * limited tail recursion optimisation (as opposed to Haskell's more general
- * tail *call* optimisation), some recursive calls in monadic code result in
- * a growing stack.
- *
- * As a demonstration, try running MonadTest.subtract(Some(1000000), Some(1000000))
- * or running MonadTest.flatSubtract(Some(1000000), Some(1000000)).
- * Either call will result in a StackOverflowError.
- *
- * If we try adding the @tailrec annotation to subtract or flatSubtract the
- * code will not compile, indicating that Scala doesn't optimise these
- * functions.
- *
- * As we can see in the implementation of flatSubtract, the recursive call to
- * flatSubtract is within two nested flatMap calls, so the recursive call is not
- * in tail position, and flatSubtract is not tail recursive.
- */
object MonadTest {

  /**
   * Subtracts `mb` from `ma` one unit at a time, written as a `for`-comprehension over `Option`.
   *
   * Both operands are decremented together until the second one reaches zero. The recursive
   * call is made from inside the comprehension (i.e. inside a closure handed to `flatMap`),
   * so it is not a self tail call: enabling the `@tailrec` annotation below makes compilation
   * fail, and large operands exhaust the stack.
   *
   * @param ma the value to subtract from, or `None`
   * @param mb the amount to subtract, or `None`
   * @return `Some(a - b)` when both operands are defined and `b` is not negative, else `None`
   */
  // @tailrec
  def subtract(ma: Option[Int], mb: Option[Int]): Option[Int] =
    for {
      minuend <- ma
      subtrahend <- mb
      difference <- subtrahend match {
        case 0                        => Some(minuend)
        case negative if negative < 0 => None
        case _                        => subtract(Some(minuend - 1), Some(subtrahend - 1))
      }
    } yield difference

  /**
   * Same computation as [[subtract]], spelled out with explicit nested `flatMap` calls.
   *
   * The recursive call lives inside the function passed to the inner `flatMap`, which is why
   * the method is not tail recursive and `@tailrec` is left commented out.
   *
   * @param ma the value to subtract from, or `None`
   * @param mb the amount to subtract, or `None`
   * @return `Some(a - b)` when both operands are defined and `b` is not negative, else `None`
   */
  // @tailrec
  def flatSubtract(ma: Option[Int], mb: Option[Int]): Option[Int] =
    ma.flatMap { minuend =>
      mb.flatMap {
        case 0                        => Some(minuend)
        case negative if negative < 0 => None
        case subtrahend               => flatSubtract(Some(minuend - 1), Some(subtrahend - 1))
      }
    }

  /**
   * The step-by-step subtraction expressed in `OptionT[Trampoline, Int]`.
   *
   * Nothing is evaluated to a plain `Option` here; the caller has to run the returned value
   * (see [[freeSubtract]]).
   * NOTE(review): presumably the trampoline is what keeps the recursion off the call stack -
   * confirm against the scalaz version in use.
   *
   * @param ma the value to subtract from, lifted into `OptionT[Trampoline, Int]`
   * @param mb the amount to subtract, lifted into `OptionT[Trampoline, Int]`
   * @return the lifted difference, or the lifted `none` when `mb` holds a negative number
   */
  def trampolineSubtract(
      ma: OptionT[Trampoline, Int],
      mb: OptionT[Trampoline, Int]
  ): OptionT[Trampoline, Int] =
    for {
      minuend <- ma
      subtrahend <- mb
      difference <- subtrahend match {
        case 0 =>
          OptionT.some[Trampoline, Int](minuend)
        case negative if negative < 0 =>
          OptionT.none[Trampoline, Int]
        case _ =>
          trampolineSubtract(
            OptionT.some[Trampoline, Int](minuend - 1),
            OptionT.some[Trampoline, Int](subtrahend - 1)
          )
      }
    } yield difference

  /**
   * `Option`-based entry point for [[trampolineSubtract]].
   *
   * Lifts both operands into `OptionT[Trampoline, Int]`, performs the subtraction and runs the
   * result back down to a plain `Option`.
   *
   * @param ma the value to subtract from, or `None`
   * @param mb the amount to subtract, or `None`
   * @return the result of running [[trampolineSubtract]]; `None` if either operand is `None`
   */
  def freeSubtract(ma: Option[Int], mb: Option[Int]): Option[Int] =
    ma.flatMap { a =>
      mb.flatMap { b =>
        val lifted = trampolineSubtract(
          OptionT.some[Trampoline, Int](a),
          OptionT.some[Trampoline, Int](b)
        )
        // First `run` unwraps the OptionT, second `run` executes the trampoline.
        lifted.run.run
      }
    }

  /**
   * scalaz `Monad` instance for the standard library's `TailRec`, needed so that `TailRec` can
   * be used as the base of `OptionT`.
   */
  implicit val tailrecInstance: Monad[TailRec] =
    new Monad[TailRec] {
      override def point[A](a: => A): TailRec[A] =
        TailCalls.done(a)

      override def bind[A, B](ta: TailRec[A])(f: A => TailRec[B]): TailRec[B] =
        ta.flatMap(f)
    }

  /**
   * The step-by-step subtraction expressed in `OptionT[TailRec, Int]`, relying on
   * [[tailrecInstance]] for the monad operations.
   *
   * The returned value still has to be run by the caller (see [[tailcallSubtract]]).
   *
   * @param ma the value to subtract from, lifted into `OptionT[TailRec, Int]`
   * @param mb the amount to subtract, lifted into `OptionT[TailRec, Int]`
   * @return the lifted difference, or the lifted `none` when `mb` holds a negative number
   */
  def tailrecSubtract(
      ma: OptionT[TailRec, Int],
      mb: OptionT[TailRec, Int]
  ): OptionT[TailRec, Int] =
    for {
      minuend <- ma
      subtrahend <- mb
      difference <- subtrahend match {
        case 0 =>
          OptionT.some[TailRec, Int](minuend)
        case negative if negative < 0 =>
          OptionT.none[TailRec, Int]
        case _ =>
          tailrecSubtract(
            OptionT.some[TailRec, Int](minuend - 1),
            OptionT.some[TailRec, Int](subtrahend - 1)
          )
      }
    } yield difference

  /**
   * `Option`-based entry point for [[tailrecSubtract]].
   *
   * Lifts both operands into `OptionT[TailRec, Int]`, performs the subtraction and extracts the
   * final `Option` through `TailRec.result`.
   *
   * @param ma the value to subtract from, or `None`
   * @param mb the amount to subtract, or `None`
   * @return the result of running [[tailrecSubtract]]; `None` if either operand is `None`
   */
  def tailcallSubtract(ma: Option[Int], mb: Option[Int]): Option[Int] =
    ma.flatMap { a =>
      mb.flatMap { b =>
        val lifted = tailrecSubtract(
          OptionT.some[TailRec, Int](a),
          OptionT.some[TailRec, Int](b)
        )
        // `run` unwraps the OptionT, `result` evaluates the TailRec.
        lifted.run.result
      }
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement