Advertisement
Terrah

matrix.lua

Jul 22nd, 2015
476
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 32.89 KB | None | 0 0
  1. --[[
  2.  
  3. LUA MODULE
  4.  
  5.   matrix v$(_VERSION) - matrix functions implemented with Lua tables
  6.    
  7. SYNOPSIS
  8.  
  9.   local matrix = require 'matrix'
  10.   m1 = matrix{{8,4,1},{6,8,3}}
  11.   m2 = matrix{{-8,1,3},{5,2,1}}
  12.   assert(m1 + m2 == matrix{{0,5,4},{11,10,4}})
  13.  
  14. DESCRIPTION
  15.  
  16.   With simple matrices this script is quite useful, though for more
  17.   exact calculations, one would probably use a program like Matlab instead.
  18.   Matrices of size 100x100 can still be handled very well.
  19.   The error for the determinant and the inverted matrix is around 10^-9
  20.   with a 100x100 matrix and an element range from -100 to 100.
  21.    
  22.    Characteristics:
  23.    
  24.     - functions called via matrix.<function> should be able to handle
  25.       any table matrix of structure t[i][j] = value
  26.     - can handle a type of complex matrix
  27.     - can handle symbolic matrices. (Symbolic matrices cannot be
  28.       used with complex matrices.)
  29.     - arithmetic functions do not change the matrix itself
  30.       but build and return a new matrix
  31.     - functions are intended to be light on checks
  32.       since one gets a Lua error on incorrect use anyways
  33.     - uses mainly Gauss-Jordan elimination
  34.     - for Lua tables optimised determinant calculation (fast)
  35.       but not invoking any checks for special types of matrices
  36.     - vectors can be set up via vec1 = matrix{{ 1,2,3 }}^'T' or matrix{1,2,3}
  37.     - vectors can be multiplied to a scalar via num = vec1^'T' * vec2
  38.       where num will be a matrix with the result in mtx[1][1],
  39.       or use num = vec1:scalar( vec2 ), where num is a number
  40.  
  41. API
  42.    
  43.     matrix function list:
  44.  
  45.     matrix.add
  46.     matrix.columns
  47.     matrix.concath
  48.     matrix.concatv
  49.     matrix.copy
  50.     matrix.cross
  51.     matrix.det
  52.     matrix.div
  53.     matrix.divnum
  54.     matrix.dogauss
  55.     matrix.elementstostring
  56.     matrix.getelement
  57.     matrix.gsub
  58.     matrix.invert
  59.     matrix.ipairs
  60.     matrix.latex
  61.     matrix.len
  62.     matrix.mul
  63.     matrix.mulnum
  64.     matrix:new
  65.     matrix.normf
  66.     matrix.normmax
  67.     matrix.pow
  68.     matrix.print
  69.     matrix.random
  70.     matrix.replace
  71.     matrix.root
  72.     matrix.rotl
  73.     matrix.rotr
  74.     matrix.round
  75.     matrix.rows
  76.     matrix.scalar
  77.     matrix.setelement
  78.     matrix.size
  79.     matrix.solve
  80.     matrix.sqrt
  81.     matrix.sub
  82.     matrix.subm
  83.     matrix.tostring
  84.     matrix.transpose
  85.     matrix.type
  86.    
  87.     See code and test_matrix.lua.
  88.  
  89. DEPENDENCIES
  90.  
  91.   None (other than Lua 5.1 or 5.2).  May be used with complex.lua.
  92.  
  93. HOME PAGE
  94.  
  95.   http://luamatrix.luaforge.net
  96.   http://lua-users.org/wiki/LuaMatrix
  97.  
  98. DOWNLOAD/INSTALL
  99.  
  100.   ./util.mk
  101.   cd tmp/*
  102.   luarocks make
  103.  
  104. LICENSE
  105.  
  106.   Licensed under the same terms as Lua itself.
  107.    
  108.   Developers:
  109.     Michael Lutz (chillcode) - original author
  110.     David Manura http://lua-users.org/wiki/DavidManura
  111. --]]
  112.  
  113. --////////////
  114. --// matrix //
  115. --////////////
  116.  
  117. local matrix = {_TYPE='module', _NAME='matrix', _VERSION='0.2.11.20120416'}
  118.  
  119. -- access to the metatable we set at the end of the file
  120. local matrix_meta = {}
  121.  
  122. --/////////////////////////////
  123. --// Get 'new' matrix object //
  124. --/////////////////////////////
  125.  
  126. --// matrix:new ( rows [, columns [, value]] )
  127. -- if rows is a table then sets rows as matrix
  128. -- if rows is a table of structure {1,2,3} then it sets it as a vector matrix
  129. -- if rows and columns are given and are numbers, returns a matrix with size rowsxcolumns
  130. -- if num is given then returns a matrix with given size and all values set to num
  131. -- if rows is given as number and columns is "I", will return an identity matrix of size rowsxrows
  132. function matrix:new( rows, columns, value )
  133.     -- check for given matrix
  134.     if type( rows ) == "table" then
  135.         -- check for vector
  136.         if type(rows[1]) ~= "table" then -- expect a vector
  137.             return setmetatable( {{rows[1]},{rows[2]},{rows[3]}},matrix_meta )
  138.         end
  139.         return setmetatable( rows,matrix_meta )
  140.     end
  141.     -- get matrix table
  142.     local mtx = {}
  143.     local value = value or 0
  144.     -- build identity matrix of given rows
  145.     if columns == "I" then
  146.         for i = 1,rows do
  147.             mtx[i] = {}
  148.             for j = 1,rows do
  149.                 if i == j then
  150.                     mtx[i][j] = 1
  151.                 else
  152.                     mtx[i][j] = 0
  153.                 end
  154.             end
  155.         end
  156.     -- build new matrix
  157.     else
  158.         for i = 1,rows do
  159.             mtx[i] = {}
  160.             for j = 1,columns do
  161.                 mtx[i][j] = value
  162.             end
  163.         end
  164.     end
  165.     -- return matrix with shared metatable
  166.     return setmetatable( mtx,matrix_meta )
  167. end
  168.  
  169. --// matrix ( rows [, comlumns [, value]] )
  170. -- set __call behaviour of matrix
  171. -- for matrix( ... ) as matrix.new( ... )
  172. setmetatable( matrix, { __call = function( ... ) return matrix.new( ... ) end } )
  173.  
  174.  
  175. -- functions are designed to be light on checks
  176. -- so we get Lua errors instead on wrong input
  177. -- matrix.<functions> should handle any table of structure t[i][j] = value
-- we always return a matrix with this script's metatable,
-- because that is faster than setmetatable( mtx, getmetatable( input matrix ) )
  180.  
  181. --///////////////////////////////
  182. --// matrix 'matrix' functions //
  183. --///////////////////////////////
  184.  
  185. --// for real, complex and symbolic matrices //--
  186.  
  187. -- note: real and complex matrices may be added, subtracted, etc.
  188. --      real and symbolic matrices may also be added, subtracted, etc.
  189. --      but one should avoid using symbolic matrices with complex ones
  190. --      since it is not clear which metatable then is used
  191.  
  192. --// matrix.add ( m1, m2 )
  193. -- Add two matrices; m2 may be of bigger size than m1
  194. function matrix.add( m1, m2 )
  195.     local mtx = {}
  196.     for i = 1,#m1 do
  197.         local m3i = {}
  198.         mtx[i] = m3i
  199.         for j = 1,#m1[1] do
  200.             m3i[j] = m1[i][j] + m2[i][j]
  201.         end
  202.     end
  203.     return setmetatable( mtx, matrix_meta )
  204. end
  205.  
  206. --// matrix.sub ( m1 ,m2 )
  207. -- Subtract two matrices; m2 may be of bigger size than m1
  208. function matrix.sub( m1, m2 )
  209.     local mtx = {}
  210.     for i = 1,#m1 do
  211.         local m3i = {}
  212.         mtx[i] = m3i
  213.         for j = 1,#m1[1] do
  214.             m3i[j] = m1[i][j] - m2[i][j]
  215.         end
  216.     end
  217.     return setmetatable( mtx, matrix_meta )
  218. end
  219.  
  220. --// matrix.mul ( m1, m2 )
  221. -- Multiply two matrices; m1 columns must be equal to m2 rows
  222. -- e.g. #m1[1] == #m2
  223. function matrix.mul( m1, m2 )
  224.     -- multiply rows with columns
  225.     local mtx = {}
  226.     for i = 1,#m1 do
  227.         mtx[i] = {}
  228.         for j = 1,#m2[1] do
  229.             local num = m1[i][1] * m2[1][j]
  230.             for n = 2,#m1[1] do
  231.                 num = num + m1[i][n] * m2[n][j]
  232.             end
  233.             mtx[i][j] = num
  234.         end
  235.     end
  236.     return setmetatable( mtx, matrix_meta )
  237. end
  238.  
  239. --//  matrix.div ( m1, m2 )
  240. -- Divide two matrices; m1 columns must be equal to m2 rows
  241. -- m2 must be square, to be inverted,
  242. -- if that fails returns the rank of m2 as second argument
  243. -- e.g. #m1[1] == #m2; #m2 == #m2[1]
  244. function matrix.div( m1, m2 )
  245.     local rank; m2,rank = matrix.invert( m2 )
  246.     if not m2 then return m2, rank end -- singular
  247.     return matrix.mul( m1, m2 )
  248. end
  249.  
  250. --// matrix.mulnum ( m1, num )
  251. -- Multiply matrix with a number
  252. -- num may be of type 'number' or 'complex number'
  253. -- strings get converted to complex number, if that fails then to symbol
  254. function matrix.mulnum( m1, num )
  255.     local mtx = {}
  256.     -- multiply elements with number
  257.     for i = 1,#m1 do
  258.         mtx[i] = {}
  259.         for j = 1,#m1[1] do
  260.             mtx[i][j] = m1[i][j] * num
  261.         end
  262.     end
  263.     return setmetatable( mtx, matrix_meta )
  264. end
  265.  
  266. --// matrix.divnum ( m1, num )
  267. -- Divide matrix by a number
  268. -- num may be of type 'number' or 'complex number'
  269. -- strings get converted to complex number, if that fails then to symbol
  270. function matrix.divnum( m1, num )
  271.     local mtx = {}
  272.     -- divide elements by number
  273.     for i = 1,#m1 do
  274.         local mtxi = {}
  275.         mtx[i] = mtxi
  276.         for j = 1,#m1[1] do
  277.             mtxi[j] = m1[i][j] / num
  278.         end
  279.     end
  280.     return setmetatable( mtx, matrix_meta )
  281. end
  282.  
  283.  
  284. --// for real and complex matrices only //--
  285.  
  286. --// matrix.pow ( m1, num )
  287. -- Power of matrix; mtx^(num)
  288. -- num is an integer and may be negative
  289. -- m1 has to be square
  290. -- if num is negative and inverting m1 fails
  291. -- returns the rank of matrix m1 as second argument
  292. function matrix.pow( m1, num )
  293.     assert(num == math.floor(num), "exponent not an integer")
  294.     if num == 0 then
  295.         return matrix:new( #m1,"I" )
  296.     end
  297.     if num < 0 then
  298.         local rank; m1,rank = matrix.invert( m1 )
  299.       if not m1 then return m1, rank end -- singular
  300.         num = -num
  301.     end
  302.     local mtx = matrix.copy( m1 )
  303.     for i = 2,num   do
  304.         mtx = matrix.mul( mtx,m1 )
  305.     end
  306.     return mtx
  307. end
  308.  
  309. local function number_norm2(x)
  310.   return x * x
  311. end
  312.  
  313. --// matrix.det ( m1 )
  314. -- Calculate the determinant of a matrix
  315. -- m1 needs to be square
  316. -- Can calc the det for symbolic matrices up to 3x3 too
  317. -- The function to calculate matrices bigger 3x3
  318. -- is quite fast and for matrices of medium size ~(100x100)
  319. -- and average values quite accurate
  320. -- here we try to get the nearest element to |1|, (smallest pivot element)
  321. -- os that usually we have |mtx[i][j]/subdet| > 1 or mtx[i][j];
  322. -- with complex matrices we use the complex.abs function to check if it is bigger or smaller
  323. function matrix.det( m1 )
  324.  
  325.     -- check if matrix is quadratic
  326.     assert(#m1 == #m1[1], "matrix not square")
  327.    
  328.     local size = #m1
  329.    
  330.     if size == 1 then
  331.         return m1[1][1]
  332.     end
  333.    
  334.     if size == 2 then
  335.         return m1[1][1]*m1[2][2] - m1[2][1]*m1[1][2]
  336.     end
  337.    
  338.     if size == 3 then
  339.         return ( m1[1][1]*m1[2][2]*m1[3][3] + m1[1][2]*m1[2][3]*m1[3][1] + m1[1][3]*m1[2][1]*m1[3][2]
  340.             - m1[1][3]*m1[2][2]*m1[3][1] - m1[1][1]*m1[2][3]*m1[3][2] - m1[1][2]*m1[2][1]*m1[3][3] )
  341.     end
  342.    
  343.     --// no symbolic matrix supported below here
  344.     local e = m1[1][1]
  345.     local zero  = type(e) == "table" and e.zero or 0
  346.     local norm2 = type(e) == "table" and e.norm2 or number_norm2
  347.  
  348.     --// matrix is bigger than 3x3
  349.     -- get determinant
  350.     -- using Gauss elimination and Laplace
  351.     -- start eliminating from below better for removals
  352.     -- get copy of matrix, set initial determinant
  353.     local mtx = matrix.copy( m1 )
  354.     local det = 1
  355.     -- get det up to the last element
  356.     for j = 1,#mtx[1] do
  357.         -- get smallest element so that |factor| > 1
  358.         -- and set it as last element
  359.         local rows = #mtx
  360.         local subdet,xrow
  361.         for i = 1,rows do
  362.             -- get element
  363.             local e = mtx[i][j]
  364.             -- if no subdet has been found
  365.             if not subdet then
  366.                 -- check if element it is not zero
  367.                 if e ~= zero then
  368.                     -- use element as new subdet
  369.                     subdet,xrow = e,i
  370.                 end
  371.             -- check for elements nearest to 1 or -1
  372.             elseif e ~= zero and math.abs(norm2(e)-1) < math.abs(norm2(subdet)-1) then
  373.                 subdet,xrow = e,i
  374.             end
  375.         end
  376.         -- only cary on if subdet is found
  377.         if subdet then
  378.             -- check if xrow is the last row,
  379.             -- else switch lines and multiply det by -1
  380.             if xrow ~= rows then
  381.                 mtx[rows],mtx[xrow] = mtx[xrow],mtx[rows]
  382.                 det = -det
  383.             end
  384.             -- traverse all fields setting element to zero
  385.             -- we don't set to zero cause we don't use that column anymore then anyways
  386.             for i = 1,rows-1 do
  387.                 -- factor is the dividor of the first element
  388.                 -- if element is not already zero
  389.                 if mtx[i][j] ~= zero then
  390.                     local factor = mtx[i][j]/subdet
  391.                     -- update all remaining fields of the matrix, with value from xrow
  392.                     for n = j+1,#mtx[1] do
  393.                         mtx[i][n] = mtx[i][n] - factor * mtx[rows][n]
  394.                     end
  395.                 end
  396.             end
  397.             -- update determinant and remove row
  398.             if math.fmod( rows,2 ) == 0 then
  399.                 det = -det
  400.             end
  401.             det = det * subdet
  402.             table.remove( mtx )
  403.         else
  404.             -- break here table det is 0
  405.             return det * 0
  406.         end
  407.     end
  408.     -- det ready to return
  409.     return det
  410. end
  411.  
  412. --// matrix.dogauss ( mtx )
  413. -- Gauss elimination, Gauss-Jordan Method
  414. -- this function changes the matrix itself
  415. -- returns on success: true,
  416. -- returns on failure: false,'rank of matrix'
  417.  
  418. -- locals
  419. -- checking here for the element nearest but not equal to zero (smallest pivot element).
  420. -- This way the `factor` in `dogauss` will be >= 1, which
  421. -- can give better results.
  422. local pivotOk = function( mtx,i,j,norm2 )
  423.     -- find min value
  424.     local iMin
  425.     local normMin = math.huge
  426.     for _i = i,#mtx do
  427.         local e = mtx[_i][j]
  428.         local norm = math.abs(norm2(e))
  429.         if norm > 0 and norm < normMin then
  430.             iMin = _i
  431.             normMin = norm
  432.             end
  433.         end
  434.     if iMin then
  435.         -- switch lines if not in position.
  436.         if iMin ~= i then
  437.             mtx[i],mtx[iMin] = mtx[iMin],mtx[i]
  438.         end
  439.         return true
  440.         end
  441.     return false
  442. end
  443.  
  444. local function copy(x)
  445.     return type(x) == "table" and x.copy(x) or x
  446. end
  447.  
  448. -- note: in --// ... //-- we have a way that does no divison,
  449. -- however with big number and matrices we get problems since we do no reducing
  450. function matrix.dogauss( mtx )
  451.     local e = mtx[1][1]
  452.     local zero = type(e) == "table" and e.zero or 0
  453.     local one  = type(e) == "table" and e.one  or 1
  454.     local norm2 = type(e) == "table" and e.norm2 or number_norm2
  455.  
  456.     local rows,columns = #mtx,#mtx[1]
  457.     -- stairs left -> right
  458.     for j = 1,rows do
  459.         -- check if element can be setted to one
  460.         if pivotOk( mtx,j,j,norm2 ) then
  461.             -- start parsing rows
  462.             for i = j+1,rows do
  463.                 -- check if element is not already zero
  464.                 if mtx[i][j] ~= zero then
  465.                     -- we may add x*otherline row, to set element to zero
  466.                     -- tozero - x*mtx[j][j] = 0; x = tozero/mtx[j][j]
  467.                     local factor = mtx[i][j]/mtx[j][j]
  468.                     --// this should not be used although it does no division,
  469.                     -- yet with big matrices (since we do no reducing and other things)
  470.                     -- we get too big numbers
  471.                     --local factor1,factor2 = mtx[i][j],mtx[j][j] //--
  472.                     mtx[i][j] = copy(zero)
  473.                     for _j = j+1,columns do
  474.                         --// mtx[i][_j] = mtx[i][_j] * factor2 - factor1 * mtx[j][_j] //--
  475.                         mtx[i][_j] = mtx[i][_j] - factor * mtx[j][_j]
  476.                     end
  477.                 end
  478.             end
  479.         else
  480.             -- return false and the rank of the matrix
  481.             return false,j-1
  482.         end
  483.     end
  484.     -- stairs right <- left
  485.     for j = rows,1,-1 do
  486.         -- set element to one
  487.         -- do division here
  488.         local div = mtx[j][j]
  489.         for _j = j+1,columns do
  490.             mtx[j][_j] = mtx[j][_j] / div
  491.         end
  492.         -- start parsing rows
  493.         for i = j-1,1,-1 do
  494.             -- check if element is not already zero        
  495.             if mtx[i][j] ~= zero then
  496.                 local factor = mtx[i][j]
  497.                 for _j = j+1,columns do
  498.                     mtx[i][_j] = mtx[i][_j] - factor * mtx[j][_j]
  499.                 end
  500.                 mtx[i][j] = copy(zero)
  501.             end
  502.         end
  503.         mtx[j][j] = copy(one)
  504.     end
  505.     return true
  506. end
  507.  
  508. --// matrix.invert ( m1 )
  509. -- Get the inverted matrix or m1
  510. -- matrix must be square and not singular
  511. -- on success: returns inverted matrix
  512. -- on failure: returns nil,'rank of matrix'
  513. function matrix.invert( m1 )
  514.     assert(#m1 == #m1[1], "matrix not square")
  515.     local mtx = matrix.copy( m1 )
  516.     local ident = setmetatable( {},matrix_meta )
  517.     local e = m1[1][1]
  518.     local zero = type(e) == "table" and e.zero or 0
  519.     local one  = type(e) == "table" and e.one  or 1
  520.     for i = 1,#m1 do
  521.         local identi = {}
  522.         ident[i] = identi
  523.         for j = 1,#m1 do
  524.             identi[j] = copy((i == j) and one or zero)
  525.         end
  526.     end
  527.     mtx = matrix.concath( mtx,ident )
  528.     local done,rank = matrix.dogauss( mtx )
  529.     if done then
  530.         return matrix.subm( mtx, 1,(#mtx[1]/2)+1,#mtx,#mtx[1] )
  531.     else
  532.         return nil,rank
  533.     end
  534. end
  535.  
--// matrix.sqrt ( m1 [,iters] )
-- Calculate the square root of a matrix using the Denman-Beavers square root iteration.
-- Conditions: the matrix must be square, invertible, and have a square root.
-- If called without iters, the function iterates until the change between
-- successive iterates stops decreasing, and returns the square root it has
-- approached; other square roots may exist as well.
-- If called with iters, the function performs exactly that many iterations.
-- Returns:
--      1st value: matrix^.5
--      2nd value: matrix^-.5
--      3rd value: the average error between (matrix^.5)^2 and the input matrix
-- You have to determine for yourself whether the result is accurate enough.
  547. -- local average error
  548. local function get_abs_avg( m1, m2 )
  549.     local dist = 0
  550.     local e = m1[1][1]
  551.     local abs = type(e) == "table" and e.abs or math.abs
  552.     for i=1,#m1 do
  553.         for j=1,#m1[1] do
  554.             dist = dist + abs(m1[i][j]-m2[i][j])
  555.         end
  556.     end
  557.     -- norm by numbers of entries
  558.     return dist/(#m1*2)
  559. end
  560. -- square root function
  561. function matrix.sqrt( m1, iters )
  562.     assert(#m1 == #m1[1], "matrix not square")
  563.     local iters = iters or math.huge
  564.     local y = matrix.copy( m1 )
  565.     local z = matrix(#y, 'I')
  566.     local dist = math.huge
  567.     -- iterate, and get the average error
  568.     for n=1,iters do
  569.         local lasty,lastz = y,z
  570.         -- calc square root
  571.         -- y, z = (1/2)*(y + z^-1), (1/2)*(z + y^-1)
  572.         y, z = matrix.divnum((matrix.add(y,matrix.invert(z))),2),
  573.                 matrix.divnum((matrix.add(z,matrix.invert(y))),2)
  574.         local dist1 = get_abs_avg(y,lasty)
  575.         if iters == math.huge then
  576.             if dist1 >= dist then
  577.                 return lasty,lastz,get_abs_avg(matrix.mul(lasty,lasty),m1)
  578.             end
  579.         end
  580.         dist = dist1
  581.     end
  582.     return y,z,get_abs_avg(matrix.mul(y,y),m1)
  583. end
  584.  
  585. --// matrix.root ( m1, root [,iters] )
  586. -- calculate any root of a matrix
  587. -- source: http://www.dm.unipi.it/~cortona04/slides/bruno.pdf
  588. -- m1 and root have to be given;(m1 = matrix, root = number)
  589. -- conditions same as matrix.sqrt
  590. -- returns same values as matrix.sqrt
  591. function matrix.root( m1, root, iters )
  592.     assert(#m1 == #m1[1], "matrix not square")
  593.     local iters = iters or math.huge
  594.     local mx = matrix.copy( m1 )
  595.     local my = matrix.mul(mx:invert(),mx:pow(root-1))
  596.     local dist = math.huge
  597.     -- iterate, and get the average error
  598.     for n=1,iters do
  599.         local lastx,lasty = mx,my
  600.         -- calc root of matrix
  601.         --mx,my = ((p-1)*mx + my^-1)/p,
  602.         --  ((((p-1)*my + mx^-1)/p)*my^-1)^(p-2) *
  603.         --  ((p-1)*my + mx^-1)/p
  604.         mx,my = mx:mulnum(root-1):add(my:invert()):divnum(root),
  605.             my:mulnum(root-1):add(mx:invert()):divnum(root)
  606.                 :mul(my:invert():pow(root-2)):mul(my:mulnum(root-1)
  607.                 :add(mx:invert())):divnum(root)
  608.         local dist1 = get_abs_avg(mx,lastx)
  609.         if iters == math.huge then
  610.             if dist1 >= dist then
  611.                 return lastx,lasty,get_abs_avg(matrix.pow(lastx,root),m1)
  612.             end
  613.         end
  614.         dist = dist1
  615.     end
  616.     return mx,my,get_abs_avg(matrix.pow(mx,root),m1)
  617. end
  618.  
  619.  
  620. --// Norm functions //--
  621.  
  622. --// matrix.normf ( mtx )
  623. -- calculates the Frobenius norm of the matrix.
  624. --   ||mtx||_F = sqrt(SUM_{i,j} |a_{i,j}|^2)
  625. -- http://en.wikipedia.org/wiki/Frobenius_norm#Frobenius_norm
  626. function matrix.normf(mtx)
  627.     local mtype = matrix.type(mtx)
  628.     local result = 0
  629.     for i = 1,#mtx do
  630.     for j = 1,#mtx[1] do
  631.         local e = mtx[i][j]
  632.         if mtype ~= "number" then e = e:abs() end
  633.         result = result + e^2
  634.     end
  635.     end
  636.     local sqrt = (type(result) == "number") and math.sqrt or result.sqrt
  637.     return sqrt(result)
  638. end
  639.  
  640. --// matrix.normmax ( mtx )
  641. -- calculates the max norm of the matrix.
  642. --   ||mtx||_{max} = max{|a_{i,j}|}
  643. -- Does not work with symbolic matrices
  644. -- http://en.wikipedia.org/wiki/Frobenius_norm#Max_norm
  645. function matrix.normmax(mtx)
  646.     local abs = (matrix.type(mtx) == "number") and math.abs or mtx[1][1].abs
  647.     local result = 0
  648.     for i = 1,#mtx do
  649.     for j = 1,#mtx[1] do
  650.         local e = abs(mtx[i][j])
  651.         if e > result then result = e end
  652.     end
  653.     end
  654.     return result
  655. end
  656.  
  657.  
  658. --// only for number and complex type //--
  659. -- Functions changing the matrix itself
  660.  
  661. --// matrix.round ( mtx [, idp] )
  662. -- perform round on elements
  663. local numround = function( num,mult )
  664.     return math.floor( num * mult + 0.5 ) / mult
  665. end
  666. local tround = function( t,mult )
  667.     for i,v in ipairs(t) do
  668.         t[i] = math.floor( v * mult + 0.5 ) / mult
  669.     end
  670.     return t
  671. end
  672. function matrix.round( mtx, idp )
  673.     local mult = 10^( idp or 0 )
  674.     local fround = matrix.type( mtx ) == "number" and numround or tround
  675.     for i = 1,#mtx do
  676.         for j = 1,#mtx[1] do
  677.             mtx[i][j] = fround(mtx[i][j],mult)
  678.         end
  679.     end
  680.     return mtx
  681. end
  682.  
--// matrix.random( mtx [,start] [, stop] [, idp] )
-- fill matrix with random values
  685. local numfill = function( _,start,stop,idp )
  686.     return math.random( start,stop ) / idp
  687. end
  688. local tfill = function( t,start,stop,idp )
  689.     for i in ipairs(t) do
  690.         t[i] = math.random( start,stop ) / idp
  691.     end
  692.     return t
  693. end
  694. function matrix.random( mtx,start,stop,idp )
  695.     local start,stop,idp = start or -10,stop or 10,idp or 1
  696.     local ffill = matrix.type( mtx ) == "number" and numfill or tfill
  697.     for i = 1,#mtx do
  698.         for j = 1,#mtx[1] do
  699.             mtx[i][j] = ffill( mtx[i][j], start, stop, idp )
  700.         end
  701.     end
  702.     return mtx
  703. end
  704.  
  705.  
  706. --//////////////////////////////
  707. --// Object Utility Functions //
  708. --//////////////////////////////
  709.  
  710. --// for all types and matrices //--
  711.  
  712. --// matrix.type ( mtx )
  713. -- get type of matrix, normal/complex/symbol or tensor
  714. function matrix.type( mtx )
  715.     local e = mtx[1][1]
  716.     if type(e) == "table" then
  717.         if e.type then
  718.             return e:type()
  719.         end
  720.         return "tensor"
  721.     end
  722.     return "number"
  723. end
  724.    
  725. -- local functions to copy matrix values
  726. local num_copy = function( num )
  727.     return num
  728. end
  729. local t_copy = function( t )
  730.     local newt = setmetatable( {}, getmetatable( t ) )
  731.     for i,v in ipairs( t ) do
  732.         newt[i] = v
  733.     end
  734.     return newt
  735. end
  736.  
  737. --// matrix.copy ( m1 )
  738. -- Copy a matrix
  739. -- simple copy, one can write other functions oneself
  740. function matrix.copy( m1 )
  741.     local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  742.     local mtx = {}
  743.     for i = 1,#m1[1] do
  744.         mtx[i] = {}
  745.         for j = 1,#m1 do
  746.             mtx[i][j] = docopy( m1[i][j] )
  747.         end
  748.     end
  749.     return setmetatable( mtx, matrix_meta )
  750. end
  751.  
  752. --// matrix.transpose ( m1 )
  753. -- Transpose a matrix
  754. -- switch rows and columns
  755. function matrix.transpose( m1 )
  756.     local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  757.     local mtx = {}
  758.     for i = 1,#m1[1] do
  759.         mtx[i] = {}
  760.         for j = 1,#m1 do
  761.             mtx[i][j] = docopy( m1[j][i] )
  762.         end
  763.     end
  764.     return setmetatable( mtx, matrix_meta )
  765. end
  766.  
  767. --// matrix.subm ( m1, i1, j1, i2, j2 )
  768. -- Submatrix out of a matrix
  769. -- input: i1,j1,i2,j2
  770. -- i1,j1 are the start element
  771. -- i2,j2 are the end element
  772. -- condition: i1,j1,i2,j2 are elements of the matrix
  773. function matrix.subm( m1,i1,j1,i2,j2 )
  774.     local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  775.     local mtx = {}
  776.     for i = i1,i2 do
  777.         local _i = i-i1+1
  778.         mtx[_i] = {}
  779.         for j = j1,j2 do
  780.             local _j = j-j1+1
  781.             mtx[_i][_j] = docopy( m1[i][j] )
  782.         end
  783.     end
  784.     return setmetatable( mtx, matrix_meta )
  785. end
  786.  
  787. --// matrix.concath( m1, m2 )
  788. -- Concatenate two matrices, horizontal
  789. -- will return m1m2; rows have to be the same
  790. -- e.g.: #m1 == #m2
  791. function matrix.concath( m1,m2 )
  792.     assert(#m1 == #m2, "matrix size mismatch")
  793.     local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  794.     local mtx = {}
  795.     local offset = #m1[1]
  796.     for i = 1,#m1 do
  797.         mtx[i] = {}
  798.         for j = 1,offset do
  799.             mtx[i][j] = docopy( m1[i][j] )
  800.         end
  801.         for j = 1,#m2[1] do
  802.             mtx[i][j+offset] = docopy( m2[i][j] )
  803.         end
  804.     end
  805.     return setmetatable( mtx, matrix_meta )
  806. end
  807.  
  808. --// matrix.concatv ( m1, m2 )
  809. -- Concatenate two matrices, vertical
  810. -- will return  m1
  811. --                  m2
  812. -- columns have to be the same; e.g.: #m1[1] == #m2[1]
  813. function matrix.concatv( m1,m2 )
  814.     assert(#m1[1] == #m2[1], "matrix size mismatch")
  815.     local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  816.     local mtx = {}
  817.     for i = 1,#m1 do
  818.         mtx[i] = {}
  819.         for j = 1,#m1[1] do
  820.             mtx[i][j] = docopy( m1[i][j] )
  821.         end
  822.     end
  823.     local offset = #mtx
  824.     for i = 1,#m2 do
  825.         local _i = i + offset
  826.         mtx[_i] = {}
  827.         for j = 1,#m2[1] do
  828.             mtx[_i][j] = docopy( m2[i][j] )
  829.         end
  830.     end
  831.     return setmetatable( mtx, matrix_meta )
  832. end
  833.  
  834. --// matrix.rotl ( m1 )
  835. -- Rotate Left, 90 degrees
  836. function matrix.rotl( m1 )
  837.     local mtx = matrix:new( #m1[1],#m1 )
  838.     local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  839.     for i = 1,#m1 do
  840.         for j = 1,#m1[1] do
  841.             mtx[#m1[1]-j+1][i] = docopy( m1[i][j] )
  842.         end
  843.     end
  844.     return mtx
  845. end
  846.  
  847. --// matrix.rotr ( m1 )
  848. -- Rotate Right, 90 degrees
  849. function matrix.rotr( m1 )
  850.     local mtx = matrix:new( #m1[1],#m1 )
  851.     local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  852.     for i = 1,#m1 do
  853.         for j = 1,#m1[1] do
  854.             mtx[j][#m1-i+1] = docopy( m1[i][j] )
  855.         end
  856.     end
  857.     return mtx
  858. end
  859.  
  860. local function tensor_tostring( t,fstr )
  861.     if not fstr then return "["..table.concat(t,",").."]" end
  862.     local tval = {}
  863.     for i,v in ipairs( t ) do
  864.         tval[i] = string.format( fstr,v )
  865.     end
  866.     return "["..table.concat(tval,",").."]"
  867. end
  868. local function number_tostring( e,fstr )
  869.     return fstr and string.format( fstr,e ) or e
  870. end
  871.  
  872. --// matrix.tostring ( mtx, formatstr )
  873. -- tostring function
  874. function matrix.tostring( mtx, formatstr )
  875.     local ts = {}
  876.     local mtype = matrix.type( mtx )
  877.     local e = mtx[1][1]
  878.     local tostring = mtype == "tensor" and tensor_tostring or
  879.           type(e) == "table" and e.tostring or number_tostring
  880.     for i = 1,#mtx do
  881.         local tstr = {}
  882.         for j = 1,#mtx[1] do
  883.             tstr[j] = tostring(mtx[i][j],formatstr)
  884.         end
  885.         ts[i] = table.concat(tstr, " ")
  886.     end
  887.     return table.concat(ts, " | ")
  888. end
  889.  
  890. --// matrix.print ( mtx [, formatstr] )
  891. -- print out the matrix, just calls tostring
  892. function matrix.print( ... )
  893.     print( matrix.tostring( ... ) )
  894. end
  895.  
  896. --// matrix.latex ( mtx [, align] )
  897. -- LaTeX output
  898. function matrix.latex( mtx, align )
  899.     -- align : option to align the elements
  900.     --      c = center; l = left; r = right
  901.     --      \usepackage{dcolumn}; D{.}{,}{-1}; aligns number by . replaces it with ,
  902.     local align = align or "c"
  903.     local str = "$\\left( \\begin{array}{"..string.rep( align, #mtx[1] ).."}\n"
  904.     local getstr = matrix.type( mtx ) == "tensor" and tensor_tostring or number_tostring
  905.     for i = 1,#mtx do
  906.         str = str.."\t"..getstr(mtx[i][1])
  907.         for j = 2,#mtx[1] do
  908.             str = str.." & "..getstr(mtx[i][j])
  909.         end
  910.         -- close line
  911.         if i == #mtx then
  912.             str = str.."\n"
  913.         else
  914.             str = str.." \\\\\n"
  915.         end
  916.     end
  917.     return str.."\\end{array} \\right)$"
  918. end
  919.  
  920.  
  921. --// Functions not changing the matrix
  922.  
  923. --// matrix.rows ( mtx )
  924. -- return number of rows
  925. function matrix.rows( mtx )
  926.     return #mtx
  927. end
  928.  
  929. --// matrix.columns ( mtx )
  930. -- return number of columns
  931. function matrix.columns( mtx )
  932.     return #mtx[1]
  933. end
  934.  
  935. --//  matrix.size ( mtx )
  936. -- get matrix size as string rows,columns
  937. function matrix.size( mtx )
  938.     if matrix.type( mtx ) == "tensor" then
  939.         return #mtx,#mtx[1],#mtx[1][1]
  940.     end
  941.     return #mtx,#mtx[1]
  942. end
  943.  
  944. --// matrix.getelement ( mtx, i, j )
  945. -- return specific element ( row,column )
  946. -- returns element on success and nil on failure
  947. function matrix.getelement( mtx,i,j )
  948.     if mtx[i] and mtx[i][j] then
  949.         return mtx[i][j]
  950.     end
  951. end
  952.  
  953. --// matrix.setelement( mtx, i, j, value )
  954. -- set an element ( i, j, value )
  955. -- returns 1 on success and nil on failure
  956. function matrix.setelement( mtx,i,j,value )
  957.     if matrix.getelement( mtx,i,j ) then
  958.         -- check if value type is number
  959.         mtx[i][j] = value
  960.         return 1
  961.     end
  962. end
  963.  
  964. --// matrix.ipairs ( mtx )
  965. -- iteration, same for complex
  966. function matrix.ipairs( mtx )
  967.     local i,j,rows,columns = 1,0,#mtx,#mtx[1]
  968.     local function iter()
  969.         j = j + 1
  970.         if j > columns then -- return first element from next row
  971.             i,j = i + 1,1
  972.         end
  973.         if i <= rows then
  974.             return i,j
  975.         end
  976.     end
  977.     return iter
  978. end
  979.  
  980. --///////////////////////////////
  981. --// matrix 'vector' functions //
  982. --///////////////////////////////
  983.  
  984. -- a vector is defined as a 3x1 matrix
  985. -- get a vector; vec = matrix{{ 1,2,3 }}^'T'
  986.  
  987. --// matrix.scalar ( m1, m2 )
  988. -- returns the Scalar Product of two 3x1 matrices (vectors)
  989. function matrix.scalar( m1, m2 )
  990.     return m1[1][1]*m2[1][1] + m1[2][1]*m2[2][1] +  m1[3][1]*m2[3][1]
  991. end
  992.  
  993. --// matrix.cross ( m1, m2 )
  994. -- returns the Cross Product of two 3x1 matrices (vectors)
  995. function matrix.cross( m1, m2 )
  996.     local mtx = {}
  997.     mtx[1] = { m1[2][1]*m2[3][1] - m1[3][1]*m2[2][1] }
  998.     mtx[2] = { m1[3][1]*m2[1][1] - m1[1][1]*m2[3][1] }
  999.     mtx[3] = { m1[1][1]*m2[2][1] - m1[2][1]*m2[1][1] }
  1000.     return setmetatable( mtx, matrix_meta )
  1001. end
  1002.  
  1003. --// matrix.len ( m1 )
  1004. -- returns the Length of a 3x1 matrix (vector)
  1005. function matrix.len( m1 )
  1006.     return math.sqrt( m1[1][1]^2 + m1[2][1]^2 + m1[3][1]^2 )
  1007. end
  1008.  
  1009.  
  1010. --// matrix.replace (mtx, func, ...)
  1011. -- for each element e in the matrix mtx, replace it with func(mtx, ...).
  1012. function matrix.replace( m1, func, ... )
  1013.     local mtx = {}
  1014.     for i = 1,#m1 do
  1015.         local m1i = m1[i]
  1016.         local mtxi = {}
  1017.         for j = 1,#m1i do
  1018.             mtxi[j] = func( m1i[j], ... )
  1019.         end
  1020.         mtx[i] = mtxi
  1021.     end
  1022.     return setmetatable( mtx, matrix_meta )
  1023. end
  1024.  
  1025. --// matrix.remcomplex ( mtx )
  1026. -- set the matrix elements to strings
  1027. -- IMPROVE: tostring v.s. tostringelements confusing
  1028. function matrix.elementstostrings( mtx )
  1029.     local e = mtx[1][1]
  1030.     local tostring = type(e) == "table" and e.tostring or tostring
  1031.     return matrix.replace(mtx, tostring)
  1032. end
  1033.  
  1034. --// matrix.solve ( m1 )
  1035. -- solve; tries to solve a symbolic matrix to a number
  1036. function matrix.solve( m1 )
  1037.     assert( matrix.type( m1 ) == "symbol", "matrix not of type 'symbol'" )
  1038.     local mtx = {}
  1039.     for i = 1,#m1 do
  1040.         mtx[i] = {}
  1041.         for j = 1,#m1[1] do
  1042.             mtx[i][j] = tonumber( loadstring( "return "..m1[i][j][1] )() )
  1043.         end
  1044.     end
  1045.     return setmetatable( mtx, matrix_meta )
  1046. end
  1047.  
  1048. --////////////////////////--
  1049. --// METATABLE HANDLING //--
  1050. --////////////////////////--
  1051.  
  1052. --// MetaTable
-- as declared at the top of the file
  1054. -- local/shared metatable
  1055. -- matrix_meta
  1056.  
  1057. -- note '...' is always faster than 'arg1,arg2,...' if it can be used
  1058.  
  1059. -- Set add "+" behaviour
  1060. matrix_meta.__add = function( ... )
  1061.     return matrix.add( ... )
  1062. end
  1063.  
  1064. -- Set subtract "-" behaviour
  1065. matrix_meta.__sub = function( ... )
  1066.     return matrix.sub( ... )
  1067. end
  1068.  
  1069. -- Set multiply "*" behaviour
  1070. matrix_meta.__mul = function( m1,m2 )
  1071.     if getmetatable( m1 ) ~= matrix_meta then
  1072.         return matrix.mulnum( m2,m1 )
  1073.     elseif getmetatable( m2 ) ~= matrix_meta then
  1074.         return matrix.mulnum( m1,m2 )
  1075.     end
  1076.     return matrix.mul( m1,m2 )
  1077. end
  1078.  
  1079. -- Set division "/" behaviour
  1080. matrix_meta.__div = function( m1,m2 )
  1081.     if getmetatable( m1 ) ~= matrix_meta then
  1082.         return matrix.mulnum( matrix.invert(m2),m1 )
  1083.     elseif getmetatable( m2 ) ~= matrix_meta then
  1084.         return matrix.divnum( m1,m2 )
  1085.     end
  1086.     return matrix.div( m1,m2 )
  1087. end
  1088.  
  1089. -- Set unary minus "-" behavior
  1090. matrix_meta.__unm = function( mtx )
  1091.     return matrix.mulnum( mtx,-1 )
  1092. end
  1093.  
  1094. -- Set power "^" behaviour
  1095. -- if opt is any integer number will do mtx^opt
  1096. --   (returning nil if answer doesn't exist)
  1097. -- if opt is 'T' then it will return the transpose matrix
  1098. -- only for complex:
  1099. --    if opt is '*' then it returns the complex conjugate matrix
  1100.     local option = {
  1101.         -- only for complex
  1102.         ["*"] = function( m1 ) return matrix.conjugate( m1 ) end,
  1103.         -- for both
  1104.         ["T"] = function( m1 ) return matrix.transpose( m1 ) end,
  1105.     }
  1106. matrix_meta.__pow = function( m1, opt )
  1107.     return option[opt] and option[opt]( m1 ) or matrix.pow( m1,opt )
  1108. end
  1109.  
  1110. -- Set equal "==" behaviour
  1111. matrix_meta.__eq = function( m1, m2 )
  1112.     -- check same type
  1113.     if matrix.type( m1 ) ~= matrix.type( m2 ) then
  1114.         return false
  1115.     end
  1116.     -- check same size
  1117.     if #m1 ~= #m2 or #m1[1] ~= #m2[1] then
  1118.         return false
  1119.     end
  1120.     -- check elements equal
  1121.     for i = 1,#m1 do
  1122.         for j = 1,#m1[1] do
  1123.             if m1[i][j] ~= m2[i][j] then
  1124.                 return false
  1125.             end
  1126.         end
  1127.     end
  1128.     return true
  1129. end
  1130.  
  1131. -- Set tostring "tostring( mtx )" behaviour
  1132. matrix_meta.__tostring = function( ... )
  1133.     return matrix.tostring( ... )
  1134. end
  1135.  
  1136. -- set __call "mtx( [formatstr] )" behaviour, mtx [, formatstr]
  1137. matrix_meta.__call = function( ... )
  1138.     matrix.print( ... )
  1139. end
  1140.  
  1141. --// __index handling
  1142. matrix_meta.__index = {}
  1143. for k,v in pairs( matrix ) do
  1144.     matrix_meta.__index[k] = v
  1145. end
  1146.  
  1147.  
  1148. --/////////////////////////////////
  1149. --// symbol class implementation
  1150. --/////////////////////////////////
  1151.  
  1152. -- access to the symbolic metatable
  1153. local symbol_meta = {}; symbol_meta.__index = symbol_meta
  1154. local symbol = symbol_meta
  1155.  
  1156. function symbol_meta.new(o)
  1157.     return setmetatable({tostring(o)}, symbol_meta)
  1158. end
  1159. symbol_meta.to = symbol_meta.new
  1160.  
  1161. -- symbol( arg )
  1162. -- same as symbol.to( arg )
  1163. -- set __call behaviour of symbol
  1164. setmetatable( symbol_meta, { __call = function( _,s ) return symbol_meta.to( s ) end } )
  1165.  
  1166.  
  1167. -- Converts object to string, optionally with formatting.
  1168. function symbol_meta.tostring( e,fstr )
  1169.     return string.format( fstr,e[1] )
  1170. end
  1171.  
  1172. -- Returns "symbol" if object is a symbol type, else nothing.
  1173. function symbol_meta:type()
  1174.     if getmetatable(self) == symbol_meta then
  1175.         return "symbol"
  1176.     end
  1177. end
  1178.  
  1179. -- Performs string.gsub on symbol.
  1180. -- for use in matrix.replace
  1181. function symbol_meta:gsub(from, to)
  1182.     return symbol.to( string.gsub( self[1],from,to ) )
  1183. end
  1184.  
  1185. -- creates function that replaces one letter by something else
  1186. -- makereplacer( "a",4,"b",7, ... )(x)
  1187. -- will replace a with 4 and b with 7 in symbol x.
  1188. -- for use in matrix.replace
  1189. function symbol_meta.makereplacer( ... )
  1190.     local tosub = {}
  1191.     local args = {...}
  1192.     for i = 1,#args,2 do
  1193.         tosub[args[i]] = args[i+1]
  1194.     end
  1195.     local function func( a ) return tosub[a] or a end
  1196.     return function(sym)
  1197.         return symbol.to( string.gsub( sym[1], "%a", func ) )
  1198.     end
  1199. end
  1200.  
  1201. -- applies abs function to symbol
  1202. function symbol_meta.abs(a)
  1203.     return symbol.to("(" .. a[1] .. "):abs()")
  1204. end
  1205.  
  1206. -- applies sqrt function to symbol
  1207. function symbol_meta.sqrt(a)
  1208.     return symbol.to("(" .. a[1] .. "):sqrt()")
  1209. end
  1210.  
  1211. function symbol_meta.__add(a,b)
  1212.     return symbol.to(a .. "+" .. b)
  1213. end
  1214.  
  1215. function symbol_meta.__sub(a,b)
  1216.     return symbol.to(a .. "-" .. b)
  1217. end
  1218.  
  1219. function symbol_meta.__mul(a,b)
  1220.     return symbol.to("(" .. a .. ")*(" .. b .. ")")
  1221. end
  1222.  
  1223. function symbol_meta.__div(a,b)
  1224.     return symbol.to("(" .. a .. ")/(" .. b .. ")")
  1225. end
  1226.  
  1227. function symbol_meta.__pow(a,b)
  1228.     return symbol.to("(" .. a .. ")^(" .. b .. ")")
  1229. end
  1230.  
  1231. function symbol_meta.__eq(a,b)
  1232.     return a[1] == b[1]
  1233. end
  1234.  
  1235. function symbol_meta.__tostring(a)
  1236.     return a[1]
  1237. end
  1238.  
  1239. function symbol_meta.__concat(a,b)
  1240.     return tostring(a) .. tostring(b)
  1241. end
  1242.  
  1243. matrix.symbol = symbol
  1244.  
  1245.  
  1246. -- return matrix
  1247. return matrix
  1248.  
  1249. --///////////////--
  1250. --// chillcode //--
  1251. --///////////////--
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement