Advertisement
Guest User

Untitled

a guest
Jul 31st, 2015
211
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.33 KB | None | 0 0
  /** Expression tree for integer polynomials. */
  sealed trait Expr

  /** Integer constant `n`. */
  final case class N(n: Int) extends Expr

  /** Monomial `a * x^n`: `x` is the variable name, `a` the coefficient, `n` the exponent. */
  final case class Var(x: String, a: Int, n: Int) extends Expr

  /** Sum of the operands `n` (zero or more; operands may themselves be sums). */
  final case class Add(n: Expr*) extends Expr

  /** Product of the operands `n` (zero or more; operands may themselves be products). */
  final case class Mul(n: Expr*) extends Expr
  6.  
  /** Shorthand for the monomial `a * x^n` in the variable named "x". */
  def x(a: Int, n: Int): Var = Var("x", a, n)
  10.  
  /** Evaluates `e` to an `Int`.
   *
   * Constants evaluate to themselves. `Add` yields the sum of its evaluated
   * operands (0 when empty) and `Mul` their product (1 when empty). A monomial
   * `Var(name, a, n)` evaluates to `a * v^n`, where `v = env(name)`.
   * Arithmetic is plain `Int` and wraps silently on overflow.
   *
   * @param e   expression to evaluate
   * @param env values for the variables occurring in `e`; the empty default is
   *            sufficient for variable-free expressions
   * @return the integer value of `e`
   * @throws IllegalArgumentException if `e` contains a variable missing from
   *         `env`, or a monomial with a negative exponent
   */
  def eval(e: Expr, env: Map[String, Int] = Map.empty): Int = e match {
    case N(value) => value
    case Var(name, a, n) =>
      // Integer evaluation cannot represent x^-k, so reject it explicitly.
      require(n >= 0, s"cannot evaluate negative exponent $n of '$name' over Int")
      val v = env.getOrElse(name, throw new IllegalArgumentException(s"no value bound for variable '$name'"))
      a * Iterator.fill(n)(v).product
    case Add(terms @ _*) => terms.map(eval(_, env)).sum
    case Mul(factors @ _*) => factors.map(eval(_, env)).product
  }
  16.  
  /** Renders `e` as text.
   *
   *  - Constants use `toString`.
   *  - A monomial is written `[coeff]name[^n]`. The coefficient is omitted
   *    for 1 and shown as just `-` for -1; `^n` is omitted when `n == 1`.
   *  - Sum operands are joined with `+`, except that no `+` is inserted before
   *    an operand `isneg` reports as negative, since its own text starts with `-`.
   *  - Product operands are joined with `*`.
   *  - An `Add` operand of a sum, and an `Add` or `Mul` operand of a product,
   *    is wrapped in parentheses.
   *  - Empty sums and products render as "".
   */
  def str(e: Expr): String = {
    def inParens(t: Expr): String = "(" ++ str(t) ++ ")"

    e match {
      case N(value) => value.toString
      case Var(name, a, n) =>
        val coeff = a match {
          case 1  => ""
          case -1 => "-"
          case _  => a.toString
        }
        val power = if (n == 1) "" else "^" ++ n.toString
        coeff ++ name ++ power
      case Add(terms @ _*) =>
        terms.zipWithIndex.map { case (t, i) =>
          val sep = if (i == 0 || isneg(t)) "" else "+"
          val body = t match {
            case _: Add => inParens(t)
            case _      => str(t)
          }
          sep ++ body
        }.mkString
      case Mul(factors @ _*) =>
        factors.map {
          case t @ (_: Add | _: Mul) => inParens(t)
          case t                     => str(t)
        }.mkString("*")
    }
  }
  35.  
  /** True for a negative constant or a monomial with a negative coefficient;
   *  false for every other expression, including sums and products.
   */
  def isneg(e: Expr): Boolean = e match {
    case N(value)         => value < 0
    case Var(_, coeff, _) => coeff < 0
    case _                => false
  }
  41.  
  /** Ordering predicate used by `xsort`: a term `x` is placed after `y` when this holds.
   *
   *  - Two monomials in "x" compare by exponent (`x` lower than `y` yields true).
   *  - A monomial in "x" is never placed after anything else.
   *  - Every other term, including monomials in other variables, always yields true.
   */
  def xlt(x: Expr, y: Expr): Boolean = x match {
    case Var("x", _, n1) =>
      y match {
        case Var("x", _, n2) => n1 < n2
        case _               => false
      }
    case _ => true
  }
  48.  
  /** Recursively orders the operands of every `Add` in the tree.
   *
   * Monomials in "x" come first, in descending order of exponent, followed by
   * all other terms. The order is defined by `xlt`: an operand `t` is placed
   * after the pivot `p` when `xlt(t, p)` holds. `Mul` nodes keep their factor
   * order but have each factor sorted recursively. Leaves are returned unchanged.
   */
  def xsort(xs: Expr): Expr = xs match {
    case Add(terms @ _*) =>
      // Quicksort using the head as pivot; `partition` keeps relative order
      // within each side.
      def qsort(ts: List[Expr]): List[Expr] = ts match {
        case Nil => Nil
        case pivot :: rest =>
          val (before, after) = rest.partition(t => !xlt(t, pivot))
          qsort(before) ++ (pivot :: qsort(after))
      }
      // Sort every child exactly once up front, so each element (including
      // whichever one becomes a pivot) is itself recursively sorted.
      Add(qsort(terms.toList.map(xsort)): _*)
    case Mul(factors @ _*) => Mul(factors.map(xsort): _*)
    case other => other
  }
  64.  
  /** Splices the operands of every (possibly nested) `Add` in `xs` into one list,
   *  preserving left-to-right order. Non-`Add` elements, including `Mul` nodes,
   *  are kept as they are, without inspecting their contents.
   */
  def flatten(xs: List[Expr]): List[Expr] = xs.flatMap {
    case Add(inner @ _*) => flatten(inner.toList)
    case other           => List(other)
  }
  70.  
  /** Builds a sum from `xs`: `N(0)` when empty, the sole element when there is
   *  one, otherwise `Add` over all elements.
   */
  def add(xs: List[Expr]): Expr = xs match {
    case Nil          => N(0)
    case only :: Nil  => only
    case _            => Add(xs: _*)
  }
  76.  
  /** Simplifies an expression by combining like terms in sums.
   *
   * For an `Add`, the steps are:
   *  1. flatten nested sums and order the terms with `xsort`;
   *  2. drop zero constants and zero-coefficient monomials;
   *  3. merge adjacent constants, and adjacent monomials in "x" with equal exponents;
   *  4. simplify every remaining term.
   *
   * Only adjacent pairs are merged. The result goes through `add`, so an empty
   * sum becomes `N(0)` and a single surviving term is returned unwrapped.
   * For a `Mul`, each factor is simplified. Any other expression is returned as is.
   */
  def xsimplify(xs: Expr): Expr = xs match {
    case Add(terms @ _*) =>
      // xsort of an Add always yields an Add, hence the unchecked match.
      val ordered: List[Expr] = (xsort(Add(flatten(terms.toList): _*)): @unchecked) match {
        case Add(sorted @ _*) => sorted.toList
      }
      // Case order matters: zero terms are dropped before any merge is attempted.
      def combine(pending: List[Expr], done: List[Expr]): List[Expr] = pending match {
        case Nil                  => done.reverse
        case N(0) :: rest         => combine(rest, done)
        case Var(_, 0, _) :: rest => combine(rest, done)
        case N(p) :: N(q) :: rest => combine(N(p + q) :: rest, done)
        case Var("x", c1, e1) :: Var("x", c2, e2) :: rest if e1 == e2 =>
          combine(x(c1 + c2, e1) :: rest, done)
        case head :: rest         => combine(rest, xsimplify(head) :: done)
      }
      add(combine(ordered, Nil))
    case Mul(factors @ _*) => Mul(factors.map(xsimplify): _*)
    case other => other
  }
  96.  
  /** Multiplies two expressions symbolically, without simplifying the result.
   *
   *  - Constants and monomials combine directly. Monomials in the same
   *    variable multiply coefficients and add exponents; monomials in
   *    different variables become a two-factor `Mul`.
   *  - A sum on either side is distributed term by term (sums take priority over products).
   *  - Otherwise the factors are concatenated into a single flat `Mul`.
   */
  def multiply(xs1: Expr, xs2: Expr): Expr = (xs1, xs2) match {
    case (N(p), N(q))         => N(p * q)
    case (N(k), Var(v, c, e)) => Var(v, k * c, e)
    case (Var(v, c, e), N(k)) => Var(v, c * k, e)
    case (Var(v, c1, e1), Var(w, c2, e2)) =>
      if (v == w) Var(v, c1 * c2, e1 + e2) else Mul(xs1, xs2)
    case (Add(ls @ _*), Add(rs @ _*)) => Add(ls.flatMap(l => rs.map(r => multiply(l, r))): _*)
    case (Add(ls @ _*), r)            => Add(ls.map(l => multiply(l, r)): _*)
    case (l, Add(rs @ _*))            => Add(rs.map(r => multiply(l, r)): _*)
    case (Mul(ls @ _*), Mul(rs @ _*)) => Mul(ls.toList ++ rs.toList: _*)
    case (Mul(ls @ _*), r)            => Mul(ls.toList :+ r: _*)
    case (l, Mul(rs @ _*))            => Mul(l :: rs.toList: _*)
  }
  110.  
  /** Builds a product from `xs`: `N(1)` when empty, the sole element when there
   *  is one, otherwise `Mul` over all elements.
   */
  def mul(xs: List[Expr]): Expr = xs match {
    case Nil          => N(1)
    case only :: Nil  => only
    case _            => Mul(xs: _*)
  }
  116.  
  /** Performs one expansion step on `xs`. `expandAll` repeats it until nothing changes.
   *
   *  - `Mul` with two or more factors: the first two are combined with
   *    `multiply`. If no factors remain after that, the product is returned
   *    directly rather than wrapped in a one-factor `Mul`. Such a wrapper would
   *    survive `xsimplify`, which only flattens nested `Add` nodes. Returning
   *    the product directly also gives a fixpoint when `multiply` itself
   *    returns a `Mul` (e.g. monomials in two different variables).
   *  - `Mul` with one factor: that factor is expanded.
   *  - `Add`: only the first term that `expand` changes is replaced.
   *  - Anything else is returned unchanged.
   */
  def expand(xs: Expr): Expr = xs match {
    case Mul(factors @ _*) =>
      factors.toList match {
        case Nil => Mul()
        case only :: Nil => Mul(expand(only))
        case first :: second :: rest =>
          val product = multiply(first, second)
          if (rest.isEmpty) product else Mul(product :: rest: _*)
      }
    case Add(terms @ _*) =>
      // Evaluate expand once per term; recomputing it in guards multiplies
      // the work at every nesting level.
      def step(ts: List[Expr]): List[Expr] = ts match {
        case Nil => Nil
        case t :: rest =>
          val expanded = expand(t)
          if (expanded != t) expanded :: rest else t :: step(rest)
      }
      Add(step(terms.toList): _*)
    case other => other
  }
  136.  
  /** Applies `expand` repeatedly until the expression stops changing and
   *  returns that fixpoint.
   *
   *  Termination relies on `expand` eventually returning its input unchanged.
   *  Each step's result is computed once and reused for both the comparison
   *  and the next iteration.
   */
  def expandAll(x: Expr): Expr = {
    val next = expand(x)
    if (next == x) x else expandAll(next)
  }
  141.  
  // Self-checks: each line prints whether the computed value matches the expectation.

  // --- eval on variable-free expressions ---
  println(eval(Add(N(1), N(2))) == 1 + 2)
  println(eval(Add(N(2), N(3))) == 2 + 3)
  println(eval(Add(N(5), N(-3))) == 5 - 3)
  println(eval(Mul(N(3), N(4))) == 3 * 4)
  println(eval(Add(N(1), Mul(N(2), N(3)))) == 1 + 2 * 3)
  println(eval(Mul(Add(N(1), N(2)), N(3))) == (1 + 2) * 3)

  // --- str: separators and parenthesisation ---
  println(str(Add(N(1), N(2), N(3))) == "1+2+3")
  println(str(Add(N(1), N(-2), N(-3))) == "1-2-3")
  println(str(Mul(N(1), N(2), N(3))) == "1*2*3")
  println(str(Add(N(1), Mul(N(2), N(3)))) == "1+2*3")
  println(str(Mul(N(1), N(2), N(3))) == "1*2*3")
  println(str(Add(Add(N(1), N(2)), N(3))) == "(1+2)+3")
  println(str(Mul(Add(N(1), N(2)), N(3))) == "(1+2)*3")
  println(str(Mul(Mul(N(1), N(2)), N(3))) == "(1*2)*3")
  println(Add(N(1), N(2)) == Add(N(1), N(2)))

  // --- str on polynomials in x ---
  val xPlus1 = Add(x(1, 1), N(1))
  println(str(xPlus1) == "x+1")
  println(str(Add(x(1, 3), x(-1, 2), x(-2, 1), N(1))) == "x^3-x^2-2x+1")

  // --- xsort ---
  val f = Mul(Add(N(5), x(2, 1)), Add(x(1, 2), x(1, 1), N(1), x(3, 3)))
  println(str(f) == "(5+2x)*(x^2+x+1+3x^3)")
  println(str(xsort(f)) == "(2x+5)*(3x^3+x^2+x+1)")

  // --- xsimplify ---
  val g1 = Add(x(2, 1), N(3), x(4, 2), x(1, 1), N(1), x(1, 2))
  println(str(g1) == "2x+3+4x^2+x+1+x^2")
  println(str(xsimplify(g1)) == "5x^2+3x+4")
  val g2 = Mul(Add(x(1, 1), N(0), x(2, 1)), Add(x(1, 2), Add(N(1), x(2, 2)), N(2)))
  println(str(g2) == "(x+0+2x)*(x^2+(1+2x^2)+2)")
  println(str(xsimplify(g2)) == "3x*(3x^2+3)")
  val g3 = Add(x(1, 1), N(1), x(0, 2), x(1, 1), N(1), x(-2, 1), N(-2))
  println(str(g3) == "x+1+0x^2+x+1-2x-2")
  println(str(xsimplify(g3)) == "0")

  // --- multiply ---
  println(str(N(2)) == "2")
  println(str(N(3)) == "3")
  println(str(multiply(N(2), N(3))) == "6")
  println(str(N(2)) == "2")
  println(str(x(3, 2)) == "3x^2")
  println(str(multiply(N(2), x(3, 2))) == "6x^2")
  println(str(x(2, 3)) == "2x^3")
  println(str(x(3, 4)) == "3x^4")
  println(str(multiply(x(2, 3), x(3, 4))) == "6x^7")
  println(str(N(2)) == "2")
  val quadratic = Add(x(1, 1), x(2, 2), N(3))
  println(str(quadratic) == "x+2x^2+3")
  println(str(multiply(N(2), quadratic)) == "2x+4x^2+6")
  val twoXPlus3 = Add(x(2, 1), N(3))
  println(str(xPlus1) == "x+1")
  println(str(twoXPlus3) == "2x+3")
  println(str(multiply(xPlus1, twoXPlus3)) == "2x^2+3x+2x+3")
  println(str(xsimplify(multiply(xPlus1, twoXPlus3))) == "2x^2+5x+3")

  // --- expand / expandAll on (x+1)(x+2)(x+3) ---
  val cubic = Mul(xPlus1, Add(x(1, 1), N(2)), Add(x(1, 1), N(3)))
  val onePlusCubic = Add(N(1), cubic)
  println(str(cubic) == "(x+1)*(x+2)*(x+3)")
  println(str(expand(cubic)) == "(x^2+2x+x+2)*(x+3)")
  println(str(onePlusCubic) == "1+(x+1)*(x+2)*(x+3)")
  println(str(expand(onePlusCubic)) == "1+(x^2+2x+x+2)*(x+3)")
  println(str(onePlusCubic) == "1+(x+1)*(x+2)*(x+3)")
  println(str(expandAll(onePlusCubic)) == "1+(x^3+3x^2+2x^2+6x+x^2+3x+2x+6)")
  // NOTE(review): the next check holds only if expandAll leaves no one-factor Mul
  // around the expanded product; xsimplify flattens nested Add nodes, not Mul.
  println(str(xsimplify(expandAll(onePlusCubic))) == "x^3+6x^2+11x+7")
  println(str(xsimplify(Add(N(1), Add(x(1, 3), x(3, 2), x(2, 2), x(6, 1), x(1, 2), x(3, 1), x(2, 1), N(6))))) == "x^3+6x^2+11x+7")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement