Advertisement
Guest User

Untitled

a guest
Apr 19th, 2015
185
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Scala 2.41 KB | None | 0 0
  1. package simpleMath.linealAlgebra.solvers
  2.  
  3. import simpleMath.SimpleProfiler
  4. import simpleMath.linealAlgebra.{MatrixD, VectorD}
  5.  
  6.  
  /** Companion of [[LUPDenseSolver]]: factory method plus a small console demo. */
  object LUPDenseSolver {

    /**
      * Builds a solver for the given matrix.
      *
      * @param m matrix to decompose; the solver constructor rejects non-square input
      * @return a new solver holding the decomposition of `m`
      */
    def apply(m: MatrixD): LUPDenseSolver = new LUPDenseSolver(m)

    /**
      * Demo entry point: fills a 6x6 matrix, decomposes it, solves `m * X = B` for a
      * known `X`, and prints the intermediate results.
      *
      * @param args command-line arguments (unused)
      */
    def main(args: Array[String]): Unit = {
      val m: MatrixD = MatrixD(6, 6)

      // Ones on the main diagonal, the row index on the anti-diagonal.
      for (i <- 0 until m.rows) {
        m(i, i) = 1
        m(i, m.rows - 1 - i) = i
      }

      val solver = LUPDenseSolver(m)
      // NOTE(review): presumably measures the time of building the decomposition — confirm
      // against SimpleProfiler.
      SimpleProfiler.nanoSecs(LUPDenseSolver(m))
      println(s"matrix m:$m")
      println(s"LUP decomposition of m is:\n${solver.C}")

      // Build a right-hand side from a known vector (0, 1, ..., rows - 1) so that the
      // printed solution can be compared against it.
      val X = VectorD((0 until m.rows).toArray)
      val B = m * X
      println(s"B:\n$B")
      println(s"\nA.solve:\n${solver.solve(B)}")
      SimpleProfiler.nanoSecs(solver.solve(B))
      println(s"\nP:")
      solver.P.foreach(println)
    }
  }
  30.  
  /**
    * Dense linear-system solver based on LU decomposition with partial (row) pivoting.
    *
    * The decomposition is computed once, in the constructor, on a copy of `m`; `m` itself is
    * only read. Afterwards [[solve]] can be called for any number of right-hand sides.
    *
    * @param m square coefficient matrix
    * @throws IllegalArgumentException if `m` is not square, or if an exactly zero pivot
    *                                  column is met during elimination (singular matrix)
    */
  class LUPDenseSolver(m: MatrixD) extends linearSolver {
    if (m.rows != m.cols)
      throw new IllegalArgumentException(
        s"for solver rows and columns size should be equal, but we have (${m.rows}, ${m.cols})")

    /**
      * Packed factors: after construction the part strictly below the diagonal holds the
      * elimination multipliers (L with an implied unit diagonal) and the diagonal and the
      * part above it hold U.
      */
    var C = m.clone()

    /** Order of the (square) matrix. */
    val N: Int = m.rows

    /** Row permutation: row `i` of the factorization is row `P(i)` of the original matrix. */
    val P: Array[Int] = (0 until N).toArray

    for (i <- 0 until N) {
      // Partial pivoting: find the row at or below i whose entry in column i has the
      // largest magnitude. The candidate is compared with the best magnitude seen so far.
      var pivot = i
      var pivotMagnitude = C(i, i).abs
      var row = i + 1
      while (row < N) {
        val candidate = C(row, i).abs
        if (candidate > pivotMagnitude) {
          pivotMagnitude = candidate
          pivot = row
        }
        row += 1
      }

      if (pivotMagnitude == 0.0)
        throw new IllegalArgumentException(s"Matrix is singular C($i, $i) == 0.0")

      // Swap the i-th row with the pivot row and record the swap in the permutation.
      if (pivot != i) {
        val previous = P(i)
        P(i) = P(pivot)
        P(pivot) = previous

        C.swapRows(pivot, i)
      }

      // Divide by the signed pivot (read after the swap), not by its magnitude, so that
      // the stored multipliers keep the correct sign.
      val pivotValue = C(i, i)

      var j = i + 1
      while (j < N) {
        val factor = C(j, i) / pivotValue
        C(j, i) = factor
        var k = i + 1
        while (k < N) {
          C(j, k) -= factor * C(i, k)
          k += 1
        }
        j += 1
      }
    }

    /**
      * Solves `m * x = B` using the stored decomposition.
      *
      * The entries of `B` are copied, in permuted order, into a separate vector that the
      * substitution steps then overwrite; `B` itself is only read.
      *
      * @param B right-hand side; indices `0 until N` are read
      * @return the solution vector
      */
    override def solve(B: VectorD): VectorD = {
      val X = VectorD(N)
      for (i <- 0 until N)
        X(i) = B(P(i))
      solveUpperTriangular(solveLowTriangular(X))
    }

    /** Forward substitution with the unit-diagonal lower factor; overwrites and returns `B`. */
    private def solveLowTriangular(B: VectorD): VectorD = {
      for (i <- 1 until N; j <- 0 until i)
        B(i) -= B(j) * C(i, j)
      B
    }

    /** Back substitution with the upper factor; overwrites and returns `B`. */
    private def solveUpperTriangular(B: VectorD): VectorD = {
      var i = N - 1
      while (i >= 0) {
        var j = N - 1
        while (j > i) {
          B(i) -= B(j) * C(i, j)
          j -= 1
        }
        // All terms to the right of the diagonal are eliminated; divide by the diagonal.
        B(i) /= C(i, i)
        i -= 1
      }
      B
    }

  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement