Guest

otavio

By: a guest on Dec 11th, 2008  |  syntax: Python  |  size: 3.86 KB  |  hits: 135  |  expires: Never
download  |  raw  |  embed  |  report abuse
Copied
  1. import sys
  2.  
  3. #############
  4. # Buggy Gauss Elimination
  5. # (Zero division is not checked)
  6. #
  7.  
  8. def det(rows):
  9.     v = None
  10.    
  11.     if len(rows) == 2:
  12.         r1 = rows[0]
  13.         r2 = rows[1]
  14.         v = r1[0] * r2[1] - r1[1] * r2[0]
  15.     else:
  16.         firstRow = rows[0]
  17.         aboveRows = rows[1:]        
  18.         subDets = []
  19.  
  20.         # At time I din't know the existence of enumerate
  21.         for c in range(0, len(firstRow)):
  22.             subMatrix = []
  23.             for ar in aboveRows:
  24.                 subRow = []
  25.                 for c2 in range(0, len(ar)):
  26.                     if c != c2:
  27.                         subRow.append(ar[c2])
  28.                 subMatrix.append(subRow)
  29.             subDets.append(det(subMatrix) * firstRow[c])
  30.            
  31.         evens = [subDets[e] for e in range(0, len(subDets), 2)]
  32.         odds = [subDets[e] for e in range(1, len(subDets), 2)]
  33.  
  34.         v = reduce(lambda x, y: x+y, evens) - reduce(lambda x, y: x+y, odds)        
  35.  
  36.     return v
  37.            
  38. #non-recursive                
  39. def solveSystem(rows):
  40.     if det(rows) == 0:
  41.         return None
  42.    
  43.     zerorColumnNth = 0
  44.     solvedSystem = rows
  45.    
  46.     for workRowNth in range(0, len(rows)-1):
  47.         for prodRowNth in range(workRowNth+1, len(rows)):
  48.             workRow = solvedSystem[workRowNth]
  49.             prodRow = solvedSystem[prodRowNth]
  50.            
  51.             mul = -prodRow[zerorColumnNth] / workRow[zerorColumnNth]
  52.             newProdRow = map(lambda a, b: a + b, map(lambda x: x*mul, workRow), prodRow)
  53.             solvedSystem[prodRowNth] = newProdRow
  54.            
  55.         zerorColumnNth += 1
  56.        
  57.     return solvedSystem
  58.  
  59. #recursive  
  60. def solveSystem2(rows):    
  61.     def solveLoop(rows, workNth, prodNth, zeroColumn, nrows):            
  62.         if prodNth < nrows:            
  63.             workRow = rows[workNth]
  64.             prodRow = rows[prodNth]
  65.    
  66.             mul = -prodRow[zeroColumn] / workRow[zeroColumn]
  67.             rows[prodNth] = map(lambda a, b: a + b, map(lambda x: x*mul, workRow), prodRow)
  68.             return solveLoop(rows, workNth, prodNth+1, zeroColumn, nrows)
  69.         else:
  70.             if workNth < nrows - 1:
  71.                 return solveLoop(rows, workNth+1, workNth+2, zeroColumn + 1, nrows)
  72.             else:
  73.                 return rows
  74.        
  75.     return solveLoop(rows, 0, 1, 0, len(rows))
  76.        
  77. def retroSub(rows):    
  78.     def retroLoop(rs, nth, vals):
  79.  
  80.         if nth >= 0:
  81.             row = rs[nth]
  82.             vt = [vals[v]*row[v] for v in range(nth+1, len(row)-1)]                
  83.             if len(vt) > 0:
  84.                 coff = -reduce(lambda a, b: a+b,  vt)
  85.             else:
  86.                 coff = 0
  87.                
  88.             vals[nth] = (coff + row[len(row)-1])/row[nth]
  89.                
  90.             return retroLoop(rs, nth-1, vals)
  91.         return vals
  92.    
  93.     nrows = len(rows)
  94.     return retroLoop(rows, nrows-1, [0 for i in range(0, nrows)])
  95.  
  96. if __name__ == '__main__':
  97.     try:        
  98.         f = open(sys.argv[1], 'r')
  99.     except IOError, msg:
  100.         print(msg)
  101.         sys.exit(1)
  102.        
  103.     rows = []
  104.    
  105.     for line in f:
  106.         if line != '\n':
  107.             atoms = line.split(' ')
  108.             r = []
  109.             for a in atoms:
  110.                 r.append(float(a))
  111.             rows.append(r)
  112.  
  113.     f.close();
  114.    
  115.     maxColumn = 0
  116.     for r in rows:
  117.         cs = len(r)
  118.         if cs > maxColumn:
  119.             maxColumn = cs
  120.    
  121.     for r in rows:
  122.         cs = len(r)
  123.         if cs != maxColumn:
  124.             print("Matrix nao quad")
  125.             sys.exit(1)
  126.  
  127.     ss = solveSystem2(rows)
  128.     r = retroSub(ss)
  129.    
  130.     print(r)
  131.  
  132. ##############
  133. # Sample input file:
  134. #
  135. # 10 2 3 4 5
  136. # 6 17 8 9 10
  137. # 11 12 23 14 15
  138. # 16 17 18 29 20
  139.  
  140. # Should output:
  141. # [0.28248587570621464, 0.24858757062146891, 0.24293785310734467, 0.23728813559322035]