Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- //> using scala 3.5.2
- //> using options -Wall
- import java.util.concurrent.atomic.AtomicInteger
- import java.util.concurrent.Executors
- import java.util.{Timer, TimerTask}
- import scala.concurrent.duration._
- import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Promise}
- import scala.util.chaining._
- import scala.util.{Failure, Success, Try}
/** Runtime services used by the [[IO]] interpreter.
  *
  * @param timer      triggers delayed work; the delayed task itself is handed to `threadPool`
  * @param threadPool executes the submitted tasks
  */
class Context(private val timer: Timer, private val threadPool: ExecutionContext):

  /** Runs `task` on the thread pool once `fd` has elapsed.
    *
    * Negative durations are treated as zero: `java.util.Timer.schedule` rejects a negative delay
    * with an `IllegalArgumentException`, which would otherwise escape to the caller instead of
    * the task ever running.
    *
    * @param task the work to run (evaluated lazily, on the pool)
    * @param fd   delay before the task is submitted; sub-millisecond precision is truncated
    */
  def schedule(task: => Unit, fd: FiniteDuration): Unit =
    val timerTask = new TimerTask:
      // The timer thread only triggers the work; the task body runs on the pool.
      def run(): Unit = sendToPool(task)
    timer.schedule(timerTask, math.max(0L, fd.toMillis))

  /** Submits `task` to the thread pool, discarding whatever it returns.
    *
    * @param task the work to run (evaluated lazily, on the pool)
    */
  def sendToPool[A](task: => A): Unit =
    threadPool.execute { () =>
      val _ = task
    }
/** A lazy description of a computation yielding an `A`.
  *
  * The combinators below only build a tree of nodes; evaluation starts with [[runSync]] or
  * [[runAsync]].
  */
sealed trait IO[A] extends Product with Serializable:

  /** Evaluates this program and blocks the calling thread until it has finished.
    *
    * Implemented in terms of [[runAsync]].
    *
    * @return the outcome of the program: `Success` with its value, or `Failure` with its error
    */
  def runSync()(implicit context: Context): Try[A] =
    val outcome = Promise[A]()
    runAsync(outcome.complete)
    // `Await.ready` only returns once the future is completed, so `value` is defined here.
    val finished = Await.ready(outcome.future, Duration.Inf)
    finished.value.get

  /** Starts evaluating this program without blocking the caller.
    *
    * @param callback receives the outcome of the program once it is available
    */
  def runAsync(callback: Try[A] => Unit)(implicit context: Context): Unit =
    // The interpreter works on untyped nodes, so the callback is widened before handing it over.
    val untypedCallback = callback.asInstanceOf[Try[?] => Unit]
    IO.runLoop(this)(untypedCallback)

  /** Transforms the result of this program with `f`. */
  def map[B](f: A => B): IO[B] = IO.Map(this, f)

  /** Replaces the result of this program with `b`. */
  def as[B](b: B): IO[B] = map(_ => b)

  /** Continues with the program produced by `f` from this program's result. */
  def flatMap[B](f: A => IO[B]): IO[B] = IO.FlatMap(this, f)

  /** Runs `f` after this program, ignoring this program's result. */
  def >>[B](f: IO[B]): IO[B] = flatMap(_ => f)

  /** Replaces errors matched by `f` with the value `f` produces. */
  def recover(f: PartialFunction[Throwable, A]): IO[A] = IO.Recover(this, f)

  /** Replaces errors matched by `f` with the program `f` produces. */
  def recoverWith(f: PartialFunction[Throwable, IO[A]]): IO[A] = IO.RecoverWith(this, f)

  /** Describes starting this program concurrently; the resulting fiber can be joined later. */
  def fork(): IO[IO.Fiber[A]] = IO.Fork(this)
object IO:

  /** Suspends a side effect; `thunk` runs each time the resulting program is evaluated. */
  def apply[A](thunk: => A): IO[A] = IO.Delay(() => thunk)

  /** Lifts an already computed value. */
  def pure[A](a: A): IO[A] = IO.Pure(a)

  /** Wraps a callback-based computation; `cb` is handed the completion callback. */
  def async[A](cb: (Try[A] => Unit) => Unit): IO[A] = IO.Async(cb)

  /** Completes with `()` after `duration`; the wait is delegated to the context's timer. */
  def sleep(duration: FiniteDuration): IO[Unit] = IO.Sleep(duration)

  /** A program that fails with `e`. */
  def raiseError(e: Throwable): IO[Unit] = IO.Error(e)

  /** Prints `s`, prefixed with the name of the thread evaluating it. */
  def putStrLn(s: String): IO[Unit] = IO(println(s"${Thread.currentThread().getName}: $s"))

  /** Resubmits the rest of the program to the thread pool, giving other queued tasks a chance to run. */
  def shift: IO[Unit] = IO.Shift

  final private case class Pure[A](value: A) extends IO[A]
  final private case class Delay[A](thunk: () => A) extends IO[A]
  final private case class Async[A](callback: (Try[A] => Unit) => Unit) extends IO[A]
  final private case class FlatMap[A, B](prev: IO[A], f: A => IO[B]) extends IO[B]
  final private case class Map[A, B](prev: IO[A], f: A => B) extends IO[B]
  final private case class Recover[A](prev: IO[A], f: PartialFunction[Throwable, A]) extends IO[A]
  final private case class RecoverWith[A](prev: IO[A], f: PartialFunction[Throwable, IO[A]]) extends IO[A]
  final private case class Error(e: Throwable) extends IO[Unit]
  // Always completes with `()`, so it is typed IO[Unit] rather than an arbitrary IO[A].
  final private case class Sleep(duration: FiniteDuration) extends IO[Unit]
  final private case class Fork[A](io: IO[A]) extends IO[Fiber[A]]
  final private case class Join[A](fiber: Fiber[A]) extends IO[A]
  private case object Shift extends IO[Unit]

  /** Handle to a program started with `fork()`; `join()` describes waiting for its result. */
  class Fiber[A]:
    // Both fields below are guarded by `this`.
    // Callbacks waiting for the result, most recently registered first.
    private var callbacks = List.empty[Try[A] => Unit]
    // Set once, by the first call to `finish`.
    private var result = Option.empty[Try[A]]

    def join(): IO[A] = IO.Join(this)

    /** Calls `cb` with the result: right away if it is already known, otherwise from `finish`. */
    private[IO] def register(cb: Try[A] => Unit): Unit =
      val known = synchronized:
        if result.isEmpty then callbacks = cb :: callbacks
        result
      // Invoked outside the monitor: `cb` runs the joiner's continuation, which must not hold this lock.
      known.foreach(cb)

    /** Stores `res` and notifies each waiting callback exactly once; later calls are ignored. */
    private[IO] def finish(res: Try[A]): Unit =
      val waiting: List[Try[A] => Unit] = synchronized:
        if result.isDefined then Nil
        else
          result = Some(res)
          val pending = callbacks.reverse
          callbacks = Nil
          pending
      // Notified outside the monitor for the same reason as in `register`.
      waiting.foreach(_(res))

  /** Optimised for maximum throughput, fairness must be ensured by the end developer by using [[IO.shift]].
    *
    * NOTE(review): steps are chained through nested callbacks, so a long run of synchronous steps
    * keeps deepening the stack; very long chains may overflow it.
    */
  private def runLoop(
    io: IO[?]
  )(
    done: Try[?] => Unit
  )(
    implicit context: Context
  ): Unit =
    // Evaluation should run in an intended thread pool since the beginning.
    // Otherwise, the first computations would run in a default 'main' JVM thread.
    context.sendToPool(eval(io)(done))

  /** Interprets one node of the program tree and reports its outcome to `done`. */
  private def eval(
    io: IO[?]
  )(
    done: Try[?] => Unit
  )(
    implicit context: Context
  ): Unit =
    io match
      case IO.Pure(value) => done(Success(value))
      // A throwing thunk becomes a Failure instead of escaping and leaving `done` uncalled.
      case IO.Delay(thunk) => done(Try(thunk()))
      // NOTE(review): if the definition throws synchronously, the exception escapes and `done` is never called.
      case IO.Async(asyncTaskDefinition) => asyncTaskDefinition(done)
      case IO.FlatMap(prev, f) =>
        eval(prev):
          case Success(value) => continueWith(f.asInstanceOf[Any => IO[?]](value))(done)
          case failure => done(failure)
      // Try.map already turns an exception thrown by `f` into a Failure.
      case IO.Map(prev, f) => eval(prev)(res => done(res.map(f.asInstanceOf[Any => Any])))
      case IO.Recover(prev, f) => eval(prev)(res => done(res.recover(f)))
      case IO.RecoverWith(prev, f) =>
        eval(prev):
          case Failure(e) if f.isDefinedAt(e) => continueWith(f(e))(done)
          case other => done(other)
      case IO.Error(e) => done(Failure(e))
      case IO.Sleep(duration) => context.schedule(done(Success(())), duration)
      case forked: IO.Fork[?] =>
        val fiber = new Fiber[Any]
        context.sendToPool(eval(forked.io)(fiber.finish))
        done(Success(fiber))
      case IO.Join(fiber) => fiber.register(done)
      // Resume on a fresh pool task; re-evaluating `io` (Shift itself) here would resubmit forever.
      case IO.Shift => context.sendToPool(done(Success(())))

  /** Builds the next program with `next` and evaluates it.
    *
    * An exception thrown while building the program is reported to `done` as a Failure. Only the
    * construction is guarded; evaluation runs outside the `Try`, so an exception raised from
    * within `done` cannot cause a second call to it.
    */
  private def continueWith(next: => IO[?])(done: Try[?] => Unit)(implicit context: Context): Unit =
    Try(next) match
      case Success(program) => eval(program)(done)
      case Failure(e)       => done(Failure(e))
object Main:

  /** Demo entry point.
    *
    * Forks a fiber that sleeps and prints, joins it twice, then raises an error that
    * `recoverWith` turns into a printed message; finally prints the outcome of the whole program.
    * The executor and timer backing the context are released once the program has finished.
    */
  def main(args: Array[String]): Unit =
    val program: IO[Unit] = (for {
      fiber <- (IO.sleep(1.second) >> IO.putStrLn("1") >> IO.sleep(3.seconds) >> IO.putStrLn("2") >> IO.pure(
        42
      )).fork()
      _ <- IO.putStrLn("3")
      value <- fiber.join()
      // The second join is served from the result the fiber already stored.
      value2 <- fiber.join()
      _ <- IO.raiseError(new RuntimeException(s"Boom! $value $value2"))
    } yield ()).recoverWith { case e => IO.putStrLn(e.getMessage) }

    val timer: Timer = new Timer("Pet IO Timer", true)
    val counter = new AtomicInteger(0)
    val nThreads = 1
    // Kept as a named value so it can be shut down once the program is done.
    val executor = Executors.newFixedThreadPool(
      nThreads,
      (r: Runnable) =>
        new Thread(r).tap(_.setDaemon(true)).tap(_.setName(s"Pet IO ${counter.getAndIncrement()}"))
    )
    val pool: ExecutionContextExecutor = ExecutionContext.fromExecutor(executor)
    implicit val context: Context = new Context(timer, pool)

    try println(program.runSync())
    finally
      executor.shutdown()
      timer.cancel()
Advertisement
Add Comment
Please, Sign In to add comment