Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package org.bykn.list
- import cats.Applicative
- import cats.implicits._
/**
 * Implementation of "Purely Functional Random Access Lists" by Chris Okasaki.
 * This gives O(1) cons and uncons, and 2 log_2 N lookup.
 *
 * @tparam A the element type (covariant)
 */
sealed abstract class TreeList[+A] {

  /** Splits off the first element; `None` when the list is empty. */
  def uncons: Option[(A, TreeList[A])]

  /** Returns a new list with `a1` placed in front of this one. */
  def cons[A1 >: A](a1: A1): TreeList[A1]

  /** Looks up the element at the 0-based position `idx`; `None` when out of range. */
  def get(idx: Long): Option[A]

  /** Number of elements in the list. */
  def size: Long

  /** Folds the elements front to back, starting from `init`. */
  def foldLeft[B](init: B)(fn: (B, A) => B): B

  /** Folds the elements back to front, starting from `fin`. */
  def foldRight[B](fin: B)(fn: (A, B) => B): B

  /** Applies `fn` to every element, keeping the shape of the list. */
  def map[B](fn: A => B): TreeList[B]

  /** Removes the first `n` elements and returns what remains. */
  def drop(n: Long): TreeList[A]

  /**
   * Split the list roughly in half
   */
  def split: (TreeList[A], TreeList[A])

  /** Right-associative alias for [[cons]], so `a :: list` prepends `a`. */
  def ::[A1 >: A](a1: A1): TreeList[A1] = cons(a1)

  /**
   * Renders the list as `TreeList(e0, e1, ...)`, walking it with [[uncons]].
   * A `null` element is printed as `null` rather than raising a
   * NullPointerException.
   */
  override def toString: String = {
    val strb = new java.lang.StringBuilder
    strb.append("TreeList(")

    @annotation.tailrec
    def loop(first: Boolean, l: TreeList[A]): Unit =
      l.uncons match {
        case None => ()
        case Some((h, t)) =>
          if (!first) strb.append(", ")
          // String.valueOf tolerates null, unlike calling h.toString directly
          strb.append(String.valueOf(h: Any))
          loop(false, t)
      }

    loop(true, this)
    strb.append(")")
    strb.toString
  }
}
- object TreeList {
/**
 * Natural numbers that exist both as a type and as a runtime value.
 * The only inhabitants are [[Nat.Zero]] and [[Nat.Succ]].
 */
sealed trait Nat {
  /** The integer this natural number stands for. */
  def value: Int
}

object Nat {

  /** The natural number zero. */
  case object Zero extends Nat {
    def value: Int = 0
  }

  /** The successor of `prev`; its value is computed once at construction. */
  case class Succ[P <: Nat](prev: P) extends Nat {
    val value: Int = prev.value + 1
  }

  /**
   * Compares two naturals at runtime and, when their values agree, hands
   * back a witness that their types are equal.
   *
   * The witness is produced with a cast: the compiler cannot derive the type
   * equality from the integer comparison, so it is asserted here.
   */
  def maybeEq[N1 <: Nat, N2 <: Nat](n1: N1, n2: N2): Option[NatEq[N1, N2]] =
    if (n1.value != n2.value) None
    else Some(NatEq.refl[N1].asInstanceOf[NatEq[N1, N2]])
}

/**
 * Evidence that the naturals `A` and `B` are the same type.
 * Holding an instance lets any `F[A]` be re-typed as `F[B]`.
 */
sealed abstract class NatEq[A <: Nat, B <: Nat] {
  /** Re-types `f` from `F[A]` to `F[B]`. */
  def subst[F[_ <: Nat]](f: F[A]): F[B]
}

object NatEq {
  /** Every natural is equal to itself; substitution is the identity. */
  implicit def refl[A <: Nat]: NatEq[A, A] =
    new NatEq[A, A] {
      def subst[F[_ <: Nat]](f: F[A]): F[A] = f
    }
}
/**
 * A binary tree node whose depth is carried in the type parameter `N`.
 *
 * @tparam N type-level depth of the tree
 * @tparam A element type
 */
sealed abstract class Tree[+N <: Nat, +A] {

  /** Depth of this tree as a [[Nat]]. */
  def depth: N

  /** The element stored at this node. */
  def value: A

  /** Number of elements in the tree; this is 2^(depth + 1) - 1. */
  def size: Long

  /** Element at the 0-based position `idx` within this tree, if any. */
  def get(idx: Long): Option[A]

  /** Folds the elements of the tree from last to first, starting from `fin`. */
  def foldRight[B](fin: B)(fn: (A, B) => B): B

  /** Applies `fn` to every element while keeping the same depth. */
  def map[B](fn: A => B): Tree[N, B]
}
/** A tree of depth zero: a single element with no children. */
case class Root[A](value: A) extends Tree[Nat.Zero.type, A] {

  def depth: Nat.Zero.type = Nat.Zero

  def size: Long = 1L

  /** Only position 0 exists in a single-element tree. */
  def get(idx: Long): Option[A] =
    idx match {
      case 0L => Some(value)
      case _ => None
    }

  def map[B](fn: A => B): Root[B] = Root(fn(value))

  def foldRight[B](fin: B)(fn: (A, B) => B): B = fn(value, fin)
}
/**
 * An inner node holding one element and two subtrees of the same depth `N`.
 * Elements are ordered node first, then the left subtree, then the right one.
 */
case class Balanced[N <: Nat, A](value: A, left: Tree[N, A], right: Tree[N, A])
    extends Tree[Nat.Succ[N], A] {

  /** One deeper than the children. */
  val depth: Nat.Succ[N] = Nat.Succ(left.depth)

  /** This node plus everything below it. */
  val size: Long = 1L + left.size + right.size

  /** Position 0 is this node; 1..left.size is the left subtree; the rest is the right. */
  def get(idx: Long): Option[A] = {
    val leftSize = left.size
    if (idx == 0L) Some(value)
    else if (idx > leftSize) right.get(idx - leftSize - 1L)
    else left.get(idx - 1L)
  }

  /** Maps the node first, then the left subtree, then the right subtree. */
  def map[B](fn: A => B): Balanced[N, B] = {
    val mappedValue = fn(value)
    val mappedLeft = left.map(fn)
    val mappedRight = right.map(fn)
    Balanced[N, B](mappedValue, mappedLeft, mappedRight)
  }

  /** Folds the right subtree, then the left subtree, and finally this node. */
  def foldRight[B](fin: B)(fn: (A, B) => B): B =
    fn(value, left.foldRight(right.foldRight(fin)(fn))(fn))
}
/**
 * Runs the effectful function `fn` over every element of `ta` and rebuilds
 * a tree of the same depth inside `F`.
 *
 * For an inner node the effects are combined in the order node, left
 * subtree, right subtree.
 *
 * @param ta the tree to traverse
 * @param fn effectful transformation applied to each element
 * @return the transformed tree wrapped in `F`
 */
def traverseTree[F[_]: Applicative, A, B, N <: Nat](ta: Tree[N, A], fn: A => F[B]): F[Tree[N, B]] =
  ta match {
    case Balanced(v, l, r) =>
      (fn(v), traverseTree(l, fn), traverseTree(r, fn)).mapN { (newValue, newLeft, newRight) =>
        Balanced(newValue, newLeft, newRight)
      }
    case Root(v) =>
      fn(v).map(single => Root(single))
  }
/**
 * The concrete [[TreeList]]: a list of trees whose elements, read tree by
 * tree (node, left subtree, right subtree), form the sequence.
 *
 * @param treeList the trees, front of the sequence first
 */
private case class Trees[A](treeList: List[Tree[Nat, A]]) extends TreeList[A] {

  /**
   * Prepends `a1`. When the two leading trees have the same depth they are
   * joined under a new node holding `a1`; otherwise `a1` becomes a new
   * single-element tree at the front.
   */
  def cons[A1 >: A](a1: A1): TreeList[A1] =
    treeList match {
      case h1 :: h2 :: rest =>
        // The helper introduces type parameters for the two depths so that a
        // runtime depth comparison can be turned into a type equality.
        def go[N1 <: Nat, N2 <: Nat, A2 <: A](t1: Tree[N1, A2], t2: Tree[N2, A2]): TreeList[A1] =
          Nat.maybeEq[N1, N2](t1.depth, t2.depth) match {
            case Some(eqv) =>
              type T[N <: Nat] = Tree[N, A2]
              Trees(Balanced[N2, A1](a1, eqv.subst[T](t1), t2) :: rest)
            case None =>
              Trees(Root(a1) :: treeList)
          }
        go(h1, h2)
      case lessThan2 => Trees(Root(a1) :: lessThan2)
    }

  /**
   * Removes the first element. Taking the node off an inner tree leaves its
   * two subtrees at the front of the list.
   */
  def uncons: Option[(A, TreeList[A])] =
    treeList match {
      case Nil => None
      case Root(a) :: rest => Some((a, Trees(rest)))
      case Balanced(a, l, r) :: rest => Some((a, Trees(l :: r :: rest)))
    }

  /**
   * Element at 0-based position `idx`, or `None` when `idx` is negative or
   * at least `size`. Whole trees are skipped until the one containing `idx`.
   */
  def get(idx: Long): Option[A] = {
    @annotation.tailrec
    def loop(idx: Long, treeList: List[Tree[Nat, A]]): Option[A] =
      if (idx < 0L) None
      else
        treeList match {
          case Nil => None
          case h :: tail =>
            if (h.size <= idx) loop(idx - h.size, tail)
            else h.get(idx)
        }
    loop(idx, treeList)
  }

  /** Total number of elements: the sum of the sizes of all trees. */
  def size: Long = {
    @annotation.tailrec
    def loop(treeList: List[Tree[Nat, A]], acc: Long): Long =
      treeList match {
        case Nil => acc
        case h :: tail => loop(tail, acc + h.size)
      }
    loop(treeList, 0L)
  }

  /**
   * Folds front to back. An inner tree is consumed by taking its node and
   * pushing both subtrees back onto the work list, which keeps the loop
   * tail recursive.
   */
  def foldLeft[B](init: B)(fn: (B, A) => B): B = {
    @annotation.tailrec
    def loop(init: B, rest: List[Tree[Nat, A]]): B =
      rest match {
        case Nil => init
        case Root(a) :: tail => loop(fn(init, a), tail)
        case Balanced(a, l, r) :: tail => loop(fn(init, a), l :: r :: tail)
      }
    loop(init, treeList)
  }

  /** Folds back to front: the last tree is folded first. */
  def foldRight[B](fin: B)(fn: (A, B) => B): B =
    treeList.reverse.foldLeft(fin) { (b, treea) =>
      treea.foldRight(b)(fn)
    }

  /** Maps every tree in place, preserving the structure. */
  def map[B](fn: A => B) = Trees(treeList.map(_.map(fn)))

  /**
   * Removes the first `n` elements.
   *
   * A non-positive `n` returns this list unchanged, and an `n` of at least
   * `size` returns the empty list. Whole trees are discarded while they fit
   * within `n`; the tree that straddles the cut is opened up into its
   * subtrees.
   */
  def drop(n: Long): TreeList[A] = {
    @annotation.tailrec
    def loop(n: Long, treeList: List[Tree[Nat, A]]): TreeList[A] =
      treeList match {
        case Nil => empty
        case _ if n == 0L => Trees(treeList)
        case h :: tail =>
          if (h.size <= n) loop(n - h.size, tail)
          else {
            h match {
              case Root(_) =>
                loop(n - 1, tail)
              case Balanced(_, l, r) =>
                // drop the node and the whole left subtree, continue in the right one
                if (n > l.size + 1L) loop(n - l.size - 1L, r :: tail)
                // drop the node, continue with both subtrees
                else if (n > 1L) loop(n - 1L, l :: r :: tail)
                // n == 1: only the node goes away
                else Trees(l :: r :: tail)
            }
          }
      }
    // Without this guard a negative n would still remove elements.
    if (n <= 0L) this else loop(n, treeList)
  }

  /**
   * Splits into a front part and a back part. A single inner tree is cut
   * into (node + left subtree, right subtree); with several trees the last
   * tree becomes the back part.
   */
  def split: (TreeList[A], TreeList[A]) =
    treeList match {
      case Nil => (empty, empty)
      case Root(_) :: Nil => (this, empty)
      case Balanced(a, l, r) :: Nil => (Trees(Root(a) :: l :: Nil), Trees(r :: Nil))
      case moreThanOne => (Trees(moreThanOne.init), Trees(moreThanOne.last :: Nil))
    }
}
/**
 * Extension methods for [[TreeList]] that are defined outside the class.
 * NOTE(review): presumably named "Invariant" because `A` is fixed here
 * rather than covariant — confirm.
 */
implicit class InvariantTreeList[A](val treeList: TreeList[A]) extends AnyVal {

  /**
   * Applies the effectful function `fn` to every element, front to back,
   * and collects the results into a new list inside `F`.
   */
  def traverse[F[_]: Applicative, B](fn: A => F[B]): F[TreeList[B]] =
    treeList match {
      case Trees(trees) =>
        val traversed = trees.traverse(traverseTree(_, fn))
        traversed.map(newTrees => Trees(newTrees))
    }
}
/** The list with no elements; usable at any element type thanks to covariance. */
val empty: TreeList[Nothing] = Trees[Nothing](List.empty)
/**
 * Builds a [[TreeList]] holding the elements of `list` in the same order.
 * The input is walked from its last element to its first, prepending each
 * one to the accumulated result.
 */
def fromList[A](list: List[A]): TreeList[A] =
  list.reverse.foldLeft(empty: TreeList[A]) { (acc, elem) => acc.cons(elem) }
- }
Add Comment
Please, Sign In to add comment