Guest User

tiny tracer concurrent

a guest
Jan 9th, 2013
147
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.65 KB | None | 0 0
  1. #tiny-tracer, now concurrent but 6s slower
  2. #also known as "I'm doing something wrong"
  3.  
  4. #original pastebin: http://pastebin.com/F8f5GHJZ
  5. #original thread:   http://www.reddit.com/r/tinycode/comments/169ri9/ray_tracer_in_140_sloc_of_python_with_picture/
  6.  
  7. from math import sqrt, pow, pi
  8. from multiprocessing import Pool
  9. import Image, sys
  10.  
  11. class Vector( object ):
  12.    
  13.     def __init__(self,x,y,z):
  14.         self.x = x
  15.         self.y = y
  16.         self.z = z
  17.    
  18.     def dot(self, b):
  19.         return self.x*b.x + self.y*b.y + self.z*b.z
  20.        
  21.     def cross(self, b):
  22.         return (self.y*b.z-self.z*b.y, self.z*b.x-self.x*b.z, self.x*b.y-self.y*b.x)
  23.        
  24.     def magnitude(self):
  25.         return sqrt(self.x**2+self.y**2+self.z**2)
  26.        
  27.     def normal(self):
  28.         mag = self.magnitude()
  29.         return Vector(self.x/mag,self.y/mag,self.z/mag)
  30.        
  31.     def __add__(self, b):
  32.         return Vector(self.x + b.x, self.y+b.y, self.z+b.z)
  33.    
  34.     def __sub__(self, b):
  35.         return Vector(self.x-b.x, self.y-b.y, self.z-b.z)
  36.        
  37.     def __mul__(self, b):
  38.         assert type(b) == float or type(b) == int
  39.         return Vector(self.x*b, self.y*b, self.z*b)    
  40.    
  41. class Sphere( object ):
  42.    
  43.     def __init__(self, center, radius, color):
  44.         self.c = center
  45.         self.r = radius
  46.         self.col = color
  47.        
  48.     def intersection(self, l):
  49.         q = l.d.dot(l.o - self.c)**2 - (l.o - self.c).dot(l.o - self.c) + self.r**2
  50.         if q < 0:
  51.             return Intersection( Vector(0,0,0), -1, Vector(0,0,0), self)
  52.         else:
  53.             d = -l.d.dot(l.o - self.c)
  54.             d1 = d - sqrt(q)
  55.             d2 = d + sqrt(q)
  56.             if 0 < d1 and ( d1 < d2 or d2 < 0):
  57.                 return Intersection(l.o+l.d*d1, d1, self.normal(l.o+l.d*d1), self)
  58.             elif 0 < d2 and ( d2 < d1 or d1 < 0):
  59.                 return Intersection(l.o+l.d*d2, d2, self.normal(l.o+l.d*d2), self)
  60.             else:
  61.                 return Intersection( Vector(0,0,0), -1, Vector(0,0,0), self)   
  62.            
  63.     def normal(self, b):
  64.         return (b - self.c).normal()
  65.        
  66. class Plane( object ):
  67.    
  68.     def __init__(self, point, normal, color):
  69.         self.n = normal
  70.         self.p = point
  71.         self.col = color
  72.        
  73.     def intersection(self, l):
  74.         d = l.d.dot(self.n)
  75.         if d == 0:
  76.             return Intersection( vector(0,0,0), -1, vector(0,0,0), self)
  77.         else:
  78.             d = (self.p - l.o).dot(self.n) / d
  79.             return Intersection(l.o+l.d*d, d, self.n, self)
  80.        
  81. class Ray( object ):
  82.    
  83.     def __init__(self, origin, direction):
  84.         self.o = origin
  85.         self.d = direction
  86.        
  87. class Intersection( object ):
  88.    
  89.     def __init__(self, point, distance, normal, obj):
  90.         self.p = point
  91.         self.d = distance
  92.         self.n = normal
  93.         self.obj = obj
  94.        
  95. def testRay(ray, objects, ignore=None):
  96.     intersect = Intersection( Vector(0,0,0), -1, Vector(0,0,0), None)
  97.    
  98.     for obj in objects:
  99.         if obj is not ignore:
  100.             currentIntersect = obj.intersection(ray)
  101.             if currentIntersect.d > 0 and intersect.d < 0:
  102.                 intersect = currentIntersect
  103.             elif 0 < currentIntersect.d < intersect.d:
  104.                 intersect = currentIntersect
  105.     return intersect
  106.    
  107. def trace(ray, objects, light, maxRecur):
  108.     if maxRecur < 0:
  109.         return (0,0,0)
  110.     intersect = testRay(ray, objects)      
  111.     if intersect.d == -1:
  112.         col = vector(AMBIENT,AMBIENT,AMBIENT)
  113.     elif intersect.n.dot(light - intersect.p) < 0:
  114.         col = intersect.obj.col * AMBIENT
  115.     else:
  116.         lightRay = Ray(intersect.p, (light-intersect.p).normal())
  117.         if testRay(lightRay, objects, intersect.obj).d == -1:
  118.             lightIntensity = 1000.0/(4*pi*(light-intersect.p).magnitude()**2)
  119.             col = intersect.obj.col * max(intersect.n.normal().dot((light - intersect.p).normal()*lightIntensity), AMBIENT)
  120.         else:
  121.             col = intersect.obj.col * AMBIENT
  122.     return col
  123.    
  124. def gammaCorrection(color,factor):
  125.     return (int(pow(color.x/255.0,factor)*255),
  126.             int(pow(color.y/255.0,factor)*255),
  127.             int(pow(color.z/255.0,factor)*255))
  128.  
  129. AMBIENT = 0.1
  130. GAMMA_CORRECTION = 1/2.2
  131. MAX_RECURSION = 10
  132.  
  133. objs = [] #these are left global so we don't have to pass them to apply_async each pixel
  134. objs.append(Sphere( Vector(-2,0,-10), 2, Vector(0,255,0)))
  135. objs.append(Sphere( Vector(2,0,-10), 3.5, Vector(255,0,0)))
  136. objs.append(Sphere( Vector(0,-4,-10), 3, Vector(0,0,255)))
  137. objs.append(Plane( Vector(0,0,-12), Vector(0,0,1), Vector(255,255,255)))
  138. lightSource = Vector(0,10,0)
  139. cameraPos = Vector(0,0,20)
  140.  
  141. def getPixel(x, y):
  142.     ray =  Ray( cameraPos, (Vector(x/50.0-5,y/50.0-5,0)-cameraPos).normal())
  143.     col = trace(ray, objs, lightSource, MAX_RECURSION)
  144.     return gammaCorrection(col,GAMMA_CORRECTION)
  145.  
  146. def main():
  147.     img = Image.new("RGB",(500,500))
  148.     pool = Pool(processes=2)
  149.     for x in range(500):
  150.         results = []
  151.         for y in range(500):
  152.             result = pool.apply_async(getPixel, [x,y])
  153.             results.append(result)
  154.         for y in range(500):
  155.             pix = results[y].get()
  156.             img.putpixel((x,499-y),pix)
  157.         print x
  158.         sys.stdout.flush()
  159.     img.save("trace.bmp","BMP")
  160.  
  161. if __name__ == '__main__':
  162.     main()
Add Comment
Please, Sign In to add comment