PYCUDA ndarray kernel musings
EelcoHoogendoorn
Aug 27th, 2012

"""
Some conceptual code to illustrate planned kernel syntax stuff.
This is just a small stub of code, for ease of mental access for others who want to look at it or play with it.
It is currently not compatible with my other messy code-base, but that will come.


The intent of this code is to extend elementwise-type pycuda functionality.
Specifically, I aim for a kernel factory that is ndarray-aware. This is taken to mean several things:
    runtime type and shape checking of input arrays
    macros to obtain shapes and strides of arrays within the kernel; both the compile-time and runtime constants
    macros to obtain the current index being processed; an nd generalization of the 'i' in elementwise
    or even better: arbitrary dereferencing of arrays a la weave.blitz:
        arrayidentifier(x,y,z) -> transformed into an appropriate inner product over strides


This is all fairly easy to do for a fixed number of dimensions (what my main code base is built on);
but the obvious next step is to emit code that works for any dimension.

In order to do so, we should have some clear logic to decide which axes will
be looped over (serial), and which axes will be handled by the thread launcher (parallel).

We should always launch at least a 1d kernel, to have a minimum of parallelism;
furthermore, the stride==1 axis should always be parallel, for coalescing.
But exactly how many axes we should do in parallel is highly problem- (and somewhat hardware-) dependent.
For instance, 2d deconvolution is typically best done by reading coalesced rows into shared mem,
and then looping over the other axis;
a 3d kernel on older hardware must be forced to loop over the biggest stride, and so on.

Anyway, having decided upon a partition of parallel and serial axes:
    parallel axes should emit indexing code of the form:
        unsigned i0 = blockDim.x*blockIdx.x + threadIdx.x;
    and serial axes should (recursively) emit an outer loop of the form:
        unsigned in; for (in = 0; in < kernel_shape_n; in++) { $body }

The most basic implementation that always works (though not optimally)
simply always launches 1d threadblocks, and loops over the other axes.

We should probably allow for a manual override of axis parallelism, considering the problem-dependence.


Another low-priority thing, but one which would be cool to have:
    can we make a general shared mem mechanism? Upon marking an input arg with a keyword, emit code
    to defer reads of that array to a block of shared mem, with correct padding as decided by the supplied stencil.
    Shared mem only really comes into play for stencil operations. A simple nd-aware operation like an outer product
    does not stand to gain much from shared mem; it is still 'elementwise', and needs NxM writes anyway, so there is not much
    point trying to optimize the technically unnecessary NxM reads down to N+M.
"""
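
# A minimal stand-alone sketch of the emission scheme described above (a
# hypothetical helper, not part of the code below): given a boolean partition
# of axes into parallel and serial, it wraps a kernel body in thread-index
# assignments and outer loops.
def emit_indexing(axes, body):
    """axes[i] is True for parallel axes (mapped to CUDA thread dims x/y/z),
    False for serial axes (wrapped in an outer loop)."""
    parallel = [i for i, p in enumerate(axes) if p]
    serial = [i for i, p in enumerate(axes) if not p]
    # serial axes: recursively wrap the body in loops
    for n in reversed(serial):
        body = ('for (unsigned i%d = 0; i%d < kernel_shape_%d; i%d++) {\n' % (n, n, n, n)
                + body + '\n}')
    # parallel axes: one thread-grid dimension each
    lines = ['unsigned i%d = blockDim.%s*blockIdx.%s + threadIdx.%s;' % (n, a, a, a)
             for a, n in zip('xyz', parallel)]
    return '\n'.join(lines + [body])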



import numpy as np

from pyparsing import *

import pycuda.compyte.dtypes
pycuda.compyte.dtypes._fill_dtype_registry(True)
dtype_to_ctype = {k:v for k,v in pycuda.compyte.dtypes.DTYPE_TO_NAME.iteritems() if isinstance(k, str)}

def arg_parse(arguments):
    """
    take a raw argument string of the specified syntax, and transform it
    to an OrderedDict of identifier: (dtype, shape), for further processing
    """
    np_types = ' '.join(dtype_to_ctype.keys())
    dtype = oneOf(np_types).setResultsName('dtype')

    identifier = Word(alphas+'_', alphanums+'_').setResultsName('identifier')   #first character may not be a numeral

    positive_integer = Word(nums)
    dimension = Or([positive_integer, Literal(':')])
    shape = nestedExpr('[', ']', delimitedList(dimension)).setResultsName('shape')

    term = Group(dtype + Optional(shape) + identifier)
    grammar = delimitedList(term)

    #convert parsing result to typeinfo object, for reasonably efficient runtime checks, and clean downstream code
    def shape_parse(shape):     #apply parse actions to shape args
        if shape == '': return None     #scalar argument
        shape = tuple(None if size == ':' else int(size) for size in shape[0])
        if len(shape)==0: raise Exception('arrays must have at least one dimension')
        return shape

    from collections import OrderedDict
    arg_info = OrderedDict()
    for argument in grammar.parseString(arguments):
        arg_info[argument.identifier] = argument.dtype, shape_parse(argument.shape)

    return arg_info

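To make the argument grammar concrete without pulling in pyparsing or pycuda, here is a regex-based sketch (a stand-alone illustration, not the parser above) of the same mapping: `dtype[dims] name` becomes `{name: (dtype, shape)}`, where a `:` dimension becomes None (size known only at runtime), and an absent `[dims]` means a scalar.

```python
import re

def arg_parse_sketch(arguments):
    # one term: dtype, optional bracketed dimension list, identifier
    term = re.compile(r'\s*(\w+)\s*(?:\[([\d:,\s]+)\])?\s*(\w+)\s*')
    arg_info = {}
    for dtype, dims, name in term.findall(arguments):
        if dims:
            shape = tuple(None if d.strip() == ':' else int(d) for d in dims.split(','))
        else:
            shape = None    #scalar argument
        arg_info[name] = (dtype, shape)
    return arg_info

# an abridged version of the example argument string from further below:
info = arg_parse_sketch("float64[100,100] result, uint32[:,100,100] foo, float32 scalar")
# info['foo'] == ('uint32', (None, 100, 100)); info['scalar'] == ('float32', None)
```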

def build_argument_strings(arg_info):
    """
    build C code to represent all the compile-time and run-time info
    that we seek to pass into the kernel
    """
    #add restrict keyword to ptrs by default; my preferred syntax would be to add an ALIAS keyword to suppress this
    c_arguments = ', '.join(
        '{restrict}{type}{ptr} {identifier}'.format(
            type=dtype_to_ctype[dtype],
            ptr='' if shape is None else '*',
            restrict='' if shape is None else '__restrict__ ',
            identifier=identifier)
            for identifier,(dtype, shape) in arg_info.iteritems()
    )

    variable_shape_arguments = ', '.join(
        'constant unsigned {identifier}_shape_{dimension}'.format(identifier=identifier, dimension=dimension)
            for identifier,(dtype, shape) in arg_info.iteritems() if shape is not None
                for dimension, size in enumerate(shape) if size is None
    )

    constant_shape_arguments = '\n'.join(
        'constant unsigned {identifier}_shape_{dimension} = {size};'.format(identifier=identifier, dimension=dimension, size=size)
            for identifier,(dtype, shape) in arg_info.iteritems() if shape is not None
                for dimension, size in enumerate(shape) if size is not None
    )

    return c_arguments, variable_shape_arguments, constant_shape_arguments

def compile_kernel_shape(shape):
    """
    emit code that determines the shape of the virtual ndarray over which we evaluate the kernel
    this function hardcodes the shape into the kernel
    """
    return '\n'.join(
        'constant unsigned kernel_shape_{dimension} = {size};'.format(dimension=dimension, size=size)
            for dimension,size in enumerate(shape)
        )

def runtime_kernel_shape(shape):
    """
    emit code that determines the shape of the virtual ndarray over which we evaluate the kernel
    emits extra arguments to be supplied upon calling of the kernel
    """
    raise NotImplementedError()


def compute_strides(arg_info):
    """
    generate code to compute the strides of each array.
    I am assuming in all this that the CUDA compiler has a clue about optimizing constant expressions;
    otherwise, we might have to use C macros for this kind of thing.
    I do hope unused constants get eliminated, and so on.
    """
    def strides(identifier, shape):
        """
        inner generator, yielding a sequence of stride definitions for each array.
        I haven't actually settled on a convention for what is the first and last axis;
        not having such a convention, it isn't applied either. In other words, this code
        is for illustrative purposes only.
        """
        stride_template = '{identifier}_stride_{dimension}'
        shape_template = '{identifier}_shape_{dimension}'
        prev = stride_template.format(identifier=identifier, dimension=len(shape)-1)
        yield 'constant unsigned {identifier} = {stride};'.format(identifier=prev, stride=1)
        for i, size in reversed(list(enumerate(shape[:-1]))):
            this = stride_template.format(identifier=identifier, dimension=i)
            size = shape_template.format(identifier=identifier, dimension=i+1)
            yield 'constant unsigned {this} = {prev} * {size};'.format(this=this, prev=prev, size=size)
            prev = this
        #add total element count as well, for good measure
        size = shape_template.format(identifier=identifier, dimension=0)
        yield 'constant unsigned {identifier}_size = {prev} * {size};'.format(identifier=identifier, prev=prev, size=size)

    return '\n'.join(
        stride_expr
            for identifier,(dtype, shape) in arg_info.iteritems() if shape is not None
                for stride_expr in strides(identifier, shape))

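The stride recurrence used above, in plain numeric form (a sketch, assuming C-contiguous layout): the last axis has stride 1, and each earlier axis has stride equal to the next axis's stride times the next axis's size.

```python
def c_contiguous_strides(shape):
    strides = [1]                       #last axis: stride 1, for coalescing
    for size in reversed(shape[1:]):
        strides.append(strides[-1] * size)  #each axis: next stride * next size
    return tuple(reversed(strides))

# e.g. for the uint32[8,100,100] argument further below:
# c_contiguous_strides((8, 100, 100)) -> (10000, 100, 1)
```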

def replace_typing(source):
    """
    replace numpy types with c-types. this could be more efficient and intelligent:
    we do not do any semantic analysis here, just simple find and replace.
    but usage is optional anyway; we are fully backwards compatible, and free to use ctypes in our code
    """
    np_types = ' '.join(dtype_to_ctype.keys())
    type_grammar = oneOf(np_types)
    type_grammar.setParseAction(lambda s,l,t: dtype_to_ctype[t[0]])
    return type_grammar.transformString(source)

def replace_shape_syntax(source, arg_info):
    """
    replace arrayidentifier.shape[ndim] syntax with C named variables.
    silently fails to replace some wrong syntax, like a misspelled 'shape';
    dont worry, the cuda compiler is sure to complain about it :)
    would it be sufficient and correct to catch all instances of 'arrayidentifier.'+whatever
    that fail to match the whole syntax?
    """
    arrayidentifier = Word(alphas+'_', alphanums+'_').setResultsName('identifier')
    positive_integer = Word(nums)
    shape_expr = arrayidentifier + Suppress(Literal('.shape')) + nestedExpr('[', ']', positive_integer).setResultsName('dimension')

    def replace(s, l, t):    #string, location, parse results
        """if the match is correct, replace numpy syntax with c-compatible syntax"""
        identifier = t.identifier
        dimensions = t.dimension[0]
        if not len(dimensions)==1: raise Exception('only simple shape indexing allowed')
        dimension = dimensions[0]
        try:
            dtype, shape = arg_info[identifier]
        except KeyError:
            raise ParseFatalException("array '{identifier}' is not defined".format(identifier=identifier))
        try:
            size = shape[int(dimension)]
        except Exception:
            raise ParseFatalException('{identifier}.shape[{dimension}] is invalid'.format(identifier=identifier, dimension=dimension))

        return '{identifier}_shape_{dimension}'.format(identifier=identifier, dimension=dimension)
    shape_expr.setParseAction(replace)

    return shape_expr.transformString(source)

def replace_array_syntax(source, arg_info):
    """
    replace weave.blitz style array indexing with an inner product over strides.
    we could optionally insert bounds checking code here as well, as a debugging aid
    """
    arrayidentifier = oneOf(' '.join(arg_info.keys())).setResultsName('identifier')
    identifier = Word(alphas+'_', alphanums+'_')
    positive_integer = Word(nums)
    index = Or([identifier, positive_integer])
    index_expr = arrayidentifier + nestedExpr('(', ')', delimitedList(index)).setResultsName('indices')

    def replace(s, l, t):    #string, location, parse results
        """if the match is correct, replace numpy syntax with c-compatible syntax"""
        identifier = t.identifier
        indices = t.indices[0]

        try:
            dtype, shape = arg_info[identifier]
        except KeyError:
            raise ParseFatalException("array '{identifier}' is not defined".format(identifier=identifier))

        if not len(indices)==len(shape):
            raise Exception("indexing '{identifier}' requires {ndims} arguments".format(identifier=identifier, ndims=len(shape)))

        offset = '+'.join(
            '{identifier}_stride_{i}*{idx}'.format(identifier=identifier, i=i, idx=idx)
                for i,idx in enumerate(indices))
        return '{identifier}[{offset}]'.format(identifier=identifier, offset=offset)
    index_expr.setParseAction(replace)

    return index_expr.transformString(source)

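What the transform above produces, sketched with a plain regex substitution (a stand-alone illustration only; the real version validates identifiers and index counts against arg_info): `foo(i, j)` becomes `foo[foo_stride_0*i+foo_stride_1*j]`.

```python
import re

def replace_array_syntax_sketch(source, array_names):
    # match a known array name followed by a flat, paren-free index list
    pattern = re.compile(r'\b(%s)\(([^()]*)\)' % '|'.join(array_names))
    def replace(match):
        name, args = match.group(1), match.group(2)
        indices = [a.strip() for a in args.split(',')]
        # inner product of strides and indices, as named C constants
        offset = '+'.join('%s_stride_%d*%s' % (name, i, idx) for i, idx in enumerate(indices))
        return '%s[%s]' % (name, offset)
    return pattern.sub(replace, source)

# replace_array_syntax_sketch('r += foo(i,i0,i1);', ['foo'])
# -> 'r += foo[foo_stride_0*i+foo_stride_1*i0+foo_stride_2*i1];'
```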


def kernel_parsing(arguments, body, shape, axes):
    """
    transform syntactic sugar to valid CUDA-C code
    """

    print 'ORIGINAL CODE:'
    print
    print arguments
    print body
    print


    assert(axes[-1]==True)        #last (stride==1) axis should be parallel. maybe this should be a warning instead, but i cant think of any reason to allow otherwise

    #simple nd-array indexing. probably full of bugs
    for axis in np.flatnonzero(np.array(axes)==0):  #emit loops for serial axes
        loop = """
        unsigned i{n};
        for (i{n}=0; i{n} < kernel_shape_{n}; i{n}++ )
        """.format(n=axis)
        body = loop + '{\n' + body + '\n}'

    d = {0:'x', 1:'y', 2:'z'}
    for i,axis in enumerate(np.flatnonzero(np.array(axes)==1)):    #do thread/block arithmetic for parallel axes, and map from x,y,(z) to nd-enumeration
        try:
            parallel_axes = """
        unsigned i{n};
        i{n} = blockDim.{a}*blockIdx.{a} + threadIdx.{a};
        if (i{n} >= kernel_shape_{n}) return;
        """.format(n=axis, a=d[i])
            body = parallel_axes + body
        except KeyError:
            raise Exception('requested more parallel axes than are supported by your hardware!')
    #need some analogous code here to determine threadblock size


    arg_info = arg_parse(arguments)


    #parsing-based replacements
    body = replace_typing(body)
    body = replace_shape_syntax(body, arg_info)
    body = replace_array_syntax(body, arg_info)


    #macro-based substitutions
    generic_template = """
    __global__ void ${funcname}(${arguments})
    {
        ${body}
    }
    """
    from string import Template
    template = Template(generic_template)   #would be nice to have an indentation-aware template...

    c_arguments, variable_shape_arguments, constant_shape_arguments = build_argument_strings(arg_info)

    print 'TRANSFORMED CODE:'
    print template.substitute(
        funcname = 'templated_function',
        arguments = c_arguments+', '+variable_shape_arguments,
        body = '\n\n'.join([
                compile_kernel_shape(shape),
                constant_shape_arguments,
                compute_strides(arg_info),
                body
            ])
        )


class Kernel(object):
    """wrapper around a compiled kernel plus its parsed argument info"""
    def __init__(self, arg_info, kernel):
        self.arg_info = arg_info
        self.kernel = kernel

    def __call__(self, *args, **kwargs):
        """
        launch kernel.
        we need some logic here to extract runtime variable arguments
        and pass them on in correct order to the compiled kernel
        """
        self.kernel(*args, **kwargs)

class CheckedKernel(Kernel):

    def __call__(self, *args, **kwargs):
        #perform runtime checks
        for k,v in kwargs.iteritems():
            dtype, shape = self.arg_info[k]
            assert(v.dtype.name == dtype)           #type check on string representation of dtype; is that safe?
            if shape:
                assert(len(shape)==len(v.shape))    #assert correct number of dimensions
                for size, vsize in zip(shape, v.shape):
                    if size: assert(size == vsize)  #check all specified dimensions
            else:
                assert(v.shape == ())       #scalar numpy type
        super(CheckedKernel, self).__call__(*args, **kwargs)

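The runtime checks above, sketched stand-alone: a (dtype, shape) spec, with None marking an unconstrained dimension, is matched against an array's dtype name and shape. A tiny stub stands in for an ndarray so the sketch needs no numpy.

```python
from collections import namedtuple

FakeArray = namedtuple('FakeArray', ['dtype_name', 'shape'])    #stand-in for an ndarray

def check(spec, array):
    dtype, shape = spec
    if array.dtype_name != dtype:           #type check on string representation of dtype
        return False
    if shape is None:                       #scalar argument
        return array.shape == ()
    if len(shape) != len(array.shape):      #dimension count must match
        return False
    return all(size is None or size == vsize    #None matches any size
               for size, vsize in zip(shape, array.shape))

# check(('uint32', (None, 100, 100)), FakeArray('uint32', (8, 100, 100))) -> True
# check(('uint32', (None, 100, 100)), FakeArray('uint32', (8, 100, 99)))  -> False
```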



funky_product = kernel_parsing(
    #the (virtual) ndarray to arrange the kernel call over; as the example demonstrates, we can not generally deduce this from the arguments. if unspecified, it should be determined at runtime
    shape = (100,100),
    #specify parallel and serial axes manually (required for now); the last (stride==1) axis is made parallel here, the other serial
    axes = [0, 1],
    #note the use of numpy dtypes, and the dimension and type constraints in the argument list
    arguments = """float64[100,100] result, uint32[:,100,100] foo, uint32[8,100,100] bar, float32 scalar""",
    #note the numpythonic shape accessing, and the weave.blitz style array indexing. and the blissful absence of boilerplate in general :)
    body = """
    float64 r = 0;
    for (uint32 i = 0; i < foo.shape[0]; i++)
        for (uint32 j = 0; j < bar.shape[0]; j++)
            r += foo(i,i0,i1) * bar(j,i0,i1) * scalar;
    result(i0, i1) = r;
    """
    )