Advertisement
chrisvarns

OpenCL

Apr 4th, 2012
224
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 42.88 KB | None | 0 0
  1. /*
  2.  * BackPropTrainer.java
  3.  *
  4.  * Copyright (C) August Mayer, 2001-2004. All rights reserved.
  5.  * Please consult the Boone LICENSE file for additional rights granted to you.
  6.  *
  7.  * Created on 19. November 2002, 15:06
  8.  */
  9.  
  10. package boone.training;
  11.  
  12. import boone.*;
  13. import boone.io.*;
  14. import boone.util.VarArray;
  15. import java.nio.ByteBuffer;
  16.  
  17. import static org.jocl.CL.*;
  18. import org.jocl.*;
  19.  
  20. /**
  21.  *  Backpropagation trainer, with momentum (optional).
  22.  *
  23.  *  @author August Mayer
  24.  *  @version $Id: BackpropTrainer.java 2028 2010-05-05 08:27:07Z amayer $
  25.  */
  26. public
  27. class BackpropTrainer extends Trainer
  28. {
  29.    
  30.     /** minimum relevant error, default 0 . */
  31.     protected double minError = 0;
  32.            
  33.     /** the momentum. 0 by default, so momentum is turned off by standard. */
  34.     public double momentum = 0.0;
  35.    
  36.     protected cl_program program_iterate, program_errorcalc, program_backprop = null;
  37.     protected cl_kernel kernel_iterate, kernel_errorcalc, kernel_backprop = null;
  38.     protected cl_context clContext = null;
  39.     protected cl_command_queue clCommandQueue = null;
  40.    
  41.     protected cl_mem memObjects[] = null;
  42.    
  43.     //Number of elements = number of layers
  44.     //Element value = number of nodes in that layer
  45.     protected int GPUTickList[] = null;
  46.    
  47.     //Arrays for input and target patterns to network
  48.     protected double GPUInputPatterns[] = null;
  49.     protected double GPUTargetPatterns[] = null;
  50.     protected int GPUInputPatternSize[] = null;
  51.     protected int GPUTargetPatternSize[] = null;
  52.     //Weights for each link
  53.     protected double GPUWeights[] = null;
  54.     protected double GPULastWeightChange[] = null;
  55.     //Number of inputs to each neuron.
  56.     protected int GPUNumInputs[] = null;
  57.     //Activation functions
  58.     protected int GPUActFuncs[] = null;
  59.     //Bias
  60.     protected double GPUBias[] = null;
  61.     protected int GPUUsingBias[] = null;
  62.     protected double GPULastBiasChange[] = null;
  63.     //Is input neuron / External input trigger
  64.     protected int GPUIsInputNeuron[] = null;
  65.     //Is output neuron / error gen trigger?
  66.     protected int GPUIsOutputNeuron[] = null;
  67.     //Current Pattern
  68.     protected int GPUCurrentPattern[] = {0};
  69.     //Current index offset
  70.     protected int GPUIndexOffset[] = {0};
  71.     //Inputs index array
  72.     protected int GPUInputs[] = null;
  73.     //Error
  74.     //protected double GPUError[] = null;
  75.     //Learning rate
  76.     protected double GPULearningRate[] = {this.learnRate};
  77.     //Momentum
  78.     protected double GPUMomentum[] = {this.momentum};
  79.     //Minimum error
  80.     protected double GPUMinError[] = {this.minError};
  81.    
  82.    
  83.     //Maximum number of doubles associated with one neuron, used for stride
  84.     //while indexing for example, inputs to a neuron. i * maxnumfloats
  85.     protected int GPUMaxNumFloats[] = {0};
  86.    
  87.     protected ByteBuffer GPUMapCurrentPattern = null;
  88.     protected ByteBuffer GPUMapIndexOffset = null;
  89.    
  90.    
  91.     public static final String strKernIterate =
  92.         "#pragma OPENCL EXTENSION cl_khr_fp64 : enable                              \n" +
  93.         "__kernel void Neuron(__global const double *inputPatterns,                 \n" +
  94.         "                       __global double *weights,                           \n" +
  95.         "                       __global const int *numInputs,                      \n" +
  96.         "                       __global const int *activation,                     \n" +
  97.         "                       __global const double *bias,                        \n" +
  98.         "                       __global const int *usingBias,                      \n" +
  99.         "                       __global double *values,                            \n" +
  100.         "                       __global const int *maxNumFloats,                   \n" +
  101.         "                       __global const int *patternIndex,                   \n" +
  102.         "                       __global const int *inputPatternSize,               \n" +
  103.         "                       __global const int *indexOffset,                    \n" +
  104.         "                       __global const int *isInputNeuron,                  \n" +
  105.         "                       __global const int *inputs)                         \n" +
  106.         "{                                                                          \n" +
  107.         "   int gid = get_global_id(0);                                             \n" +
  108.         "   double sum = 0.0;                                                       \n" +
  109.         //"   for(int i = 0; i < numInputs[gid+indexOffset[0]]; i++)                  \n" +
  110.         "   for(int i = 0; i < maxNumFloats[0]; i++)                                \n" +
  111.         "   {                                                                       \n" +
  112.         "       if(i < numInputs[gid+indexOffset[0]])                               \n" +
  113.         "          sum += values[inputs[(gid+indexOffset[0]) * maxNumFloats[0] + i]] * \n" +
  114.         "               weights[(gid+indexOffset[0]) * maxNumFloats[0] + i];        \n" +
  115.         "   }                                                                       \n" +
  116.         "   if(usingBias[gid+indexOffset[0]])                                       \n" +
  117.         "       sum += bias[gid+indexOffset[0]];                                    \n" +
  118.         "   if(isInputNeuron[gid+indexOffset[0]])                                   \n" +
  119.         "       sum += inputPatterns[gid+indexOffset[0]+(patternIndex[0] * inputPatternSize[0])];   \n" +
  120.         "   if(activation[gid+indexOffset[0]] == 1)                                 \n" +
  121.         "       sum = 1.0 / (1.0 + exp(-sum));                                      \n" +
  122.         "   values[gid + indexOffset[0]] = sum;                                     \n" +
  123.         "}                                                                          \n"
  124.         ;
  125.     public static final String strKernErrorCalc =
  126.         "#pragma OPENCL EXTENSION cl_khr_fp64 : enable                              \n" +
  127.         "__kernel void Neuron(__global const double *targetPatterns,                \n" +
  128.         "                       __global double *values,                            \n" +
  129.         "                       __global const double *minError,                    \n" +
  130.         "                       __global const int *indexOffset,                    \n" +
  131.         "                       __global const int *patternIndex,                   \n" +
  132.         "                       __global const int *targetPatternSize,              \n" +
  133.         "                       __global double *error)                             \n" +
  134.         "{                                                                          \n" +
  135.         "   int gid = get_global_id(0);                                             \n" +
  136.         "   double errorCalc =                                                      \n" +
  137.         "       targetPatterns[(targetPatternSize[0]*patternIndex[0])+gid+indexOffset[0]]\n" +
  138.         "       - values[gid + indexOffset[0]];                                      \n" +
  139.         "   if(fabs(errorCalc) < minError[0])                                        \n" +
  140.         "       error[gid + indexOffset[0]] = 0;                                    \n" +
  141.         "   else error[gid + indexOffset[0]] = errorCalc;                           \n" +
  142.         "}                                                                          \n"
  143.         ;
  144.     public static final String strKernBackProp =
  145.         "#pragma OPENCL EXTENSION cl_khr_fp64 : enable                              \n" +
  146.         "__kernel void Neuron(__global const int *usingBias,                        \n" +
  147.         "                       __global const int *indexOffset,                    \n" +
  148.         "                       __global const double *learningRate,                \n" +
  149.         "                       __global double *error,                             \n" +
  150.         "                       __global double *values,                            \n" +
  151.         "                       __global const double *momentum,                    \n" +
  152.         "                       __global double *lastBiasChange,                    \n" +
  153.         "                       __global double *lastWeightChange,                  \n" +
  154.         "                       __global const double *numInputs,                   \n" +
  155.         "                       __global double *weights,                           \n" +
  156.         "                       __global const int *maxNumFloats,                \n" +
  157.         "                       __global const int *inputs,                         \n" +
  158.         "                       __global double *bias)                        \n" +
  159.         "{                                                                          \n" +
  160.         "   int gid = get_global_id(0);                                             \n" +
  161.         "   if(usingBias[gid + indexOffset[0]])                                     \n" +
  162.         "   {                                                                       \n" +
  163.         "       double biasChange = learningRate[0]                                 \n" +
  164.         "           * error[gid + indexOffset[0]]                                   \n" +
  165.         "           * (values[gid+indexOffset[0]] * (1.0 - values[gid+indexOffset[0]]))  \n" +
  166.         "           + momentum[0] * lastBiasChange[gid+indexOffset[0]];             \n" +
  167.         "       bias[gid + indexOffset[0]] += biasChange;                           \n" +
  168.         "       lastBiasChange[gid + indexOffset[0]] = biasChange;                  \n" +
  169.         "   }                                                                       \n" +
  170.         "                                                                           \n" +
  171.         "   for(int i = 0; i < numInputs[gid + indexOffset[0]]; i++)                \n" +
  172.         "   {                                                                       \n" +
  173.         "       double ces = error[gid+indexOffset[0]] * (values[gid+indexOffset[0]]\n" +
  174.         "           * (1.0 - values[gid+indexOffset[0]]));                          \n" +
  175.         "       double wes = ces                                                    \n" +
  176.         "           * weights[((gid+indexOffset[0]) * maxNumFloats[0]) + i];        \n" +
  177.         "                                                                           \n" +
  178.         "       error[inputs[((gid+indexOffset[0]) * maxNumFloats[0]) + i]] += wes; \n" +
  179.         "                                                                           \n" +
  180.         "       double wc = learningRate[0] * ces                                   \n" +
  181.         "           * values[inputs[((gid+indexOffset[0]) * maxNumFloats[0]) +i]] + momentum[0]\n" +
  182.         "           * lastWeightChange[((gid+indexOffset[0]) * maxNumFloats[0]) +i];\n" +
  183.         "       weights[((gid+indexOffset[0]) * maxNumFloats[0]) + i] += wc;        \n" +
  184.         "       lastWeightChange[((gid+indexOffset[0]) * maxNumFloats[0]) + i] = wc;\n" +
  185.         "   }                                                                       \n" +
  186.         "}                                                                          \n"
  187.         ;
  188.    
  189.     /*
  190.      * Takes pattern set and number of loops, converts patterns to GPU
  191.      * arrays, works out GPU tick list, sets initial weights, and trains until
  192.      * each input pattern has been done n times. At this point, the weights and
  193.      * error are read back.
  194.      */
  195.     public void GPUTrain(PatternSet pSet, int eachTimes)
  196.     {
  197.         setTraining(true);
  198.        
  199.         if(GPUTickList == null || GPUTickList != net.generateGPUTickList())
  200.         {
  201.             //Setup tick list
  202.             GPUTickList = net.generateGPUTickList();
  203.             int biggestLoop = 0;
  204.             for(int i = 0; i < GPUTickList.length; i++)
  205.                 biggestLoop = Math.max(biggestLoop, GPUTickList[i]);
  206.         }
  207.         //Get number of input/target patterns
  208.         int numPatterns = Math.min(pSet.inputPatterns.size, pSet.targetPatterns.size);
  209.        
  210.         if(GPUNumInputs == null || GPUNumInputs.length < net.getNeuronCount())
  211.         {
  212.             //Set up stuff that only depends on number of nodes
  213.             //Number of inputs to each node
  214.             GPUNumInputs = new int[net.getNeuronCount()];
  215.             //Activation functions for each node
  216.             GPUActFuncs = new int[net.getNeuronCount()];
  217.             //Bias for each node
  218.             GPUBias = new double[net.getNeuronCount()];
  219.             GPUUsingBias = new int[net.getNeuronCount()];
  220.             GPULastBiasChange = new double[net.getNeuronCount()];
  221.             //Is input/output for each node
  222.             GPUIsInputNeuron = new int[net.getNeuronCount()];
  223.             GPUIsOutputNeuron = new int[net.getNeuronCount()];
  224.         }
  225.        
  226.         //Setup 2D arrays
  227.         //Set up input/target arrays based on number of patterns.
  228.         GPUInputPatterns = new double[pSet.getInputPatternSize() * numPatterns];
  229.         GPUTargetPatterns = new double[pSet.getTargetPatternSize() * numPatterns];
  230.        
  231.         GPUInputPatternSize = new int[] {pSet.getInputPatternSize()};
  232.         GPUTargetPatternSize = new int[] {pSet.getTargetPatternSize()};
  233.         //Populate arrays with pattern data
  234.         for(int i = 0; i < numPatterns; i++)
  235.         {
  236.             java.lang.System.arraycopy(pSet.inputPatterns.get(i), 0,
  237.                     GPUInputPatterns, i * net.getInputNeuronCount(),
  238.                     net.getInputNeuronCount());
  239.             java.lang.System.arraycopy(pSet.targetPatterns.get(i), 0,
  240.                     GPUTargetPatterns, i * net.getOutputNeuronCount(),
  241.                     net.getOutputNeuronCount());
  242.         }
  243.        
  244.         //Set up 1D arrays for each node, and work out MaxNumFloats
  245.         GPUMaxNumFloats[0] = 0;
  246.         for(int i = 0; i < net.getNeuronCount(); i++)
  247.         {
  248.             Neuron neuron = net.getNeuron(i);
  249.             //Set fields to vals for each neuron
  250.             GPUNumInputs[i] = neuron.getInputLinkCount();
  251.             GPUActFuncs[i] = neuron.getActivationFn().getClass() ==
  252.                                 Function.Sigmoid.class ? 1 : 0;
  253.             GPUBias[i] = neuron.getBias();
  254.             GPUUsingBias[i] = neuron.isUsingBias() ? 1 : 0;
  255.             GPUIsInputNeuron[i] = neuron.isInputNeuron() ? 1 : 0;
  256.             GPUIsOutputNeuron[i] = neuron.isOutputNeuron() ? 1 : 0;
  257.            
  258.             //Check for GPUMaxNumFloats
  259.             GPUMaxNumFloats[0] = Math.max(GPUMaxNumFloats[0], GPUNumInputs[i]);
  260.         }
  261.        
  262.         //Having worked out MaxNumFloats, work out 2D arrays
  263.         GPUWeights = new double[net.getNeuronCount() * GPUMaxNumFloats[0]];
  264.         GPULastWeightChange = new double[net.getNeuronCount() * GPUMaxNumFloats[0]];
  265.         GPUInputs = new int[net.getNeuronCount() * GPUMaxNumFloats[0]];
  266.        
  267.         for(int i = 0; i < net.getNeuronCount(); i++)
  268.         {
  269.             Neuron neuron = net.getNeuron(i);
  270.             for(int j = 0; j < GPUMaxNumFloats[0]; j++)
  271.             {
  272.                 if(j < neuron.getInputLinkCount())
  273.                 {
  274.                     Link link = neuron.getInputLink(j);
  275.                     GPUWeights[(i * GPUMaxNumFloats[0]) + j] = link.getWeight();
  276.                     GPUInputs[(i * GPUMaxNumFloats[0]) + j] = net.getNeuronIndex(link.getSource());
  277.                 }
  278.                 else
  279.                 {
  280.                     GPUWeights[(i * GPUMaxNumFloats[0]) + j] = 0;
  281.                     GPUInputs[(i * GPUMaxNumFloats[0]) + j] = 0;
  282.                 }
  283.             }
  284.         }
  285.        
  286.         ///////////
  287.         //GPU STUFF
  288.        
  289.         //Setup opencl context and command queue if not already done.
  290.         if(clCommandQueue == null || clContext == null)
  291.             SetupCL();
  292.        
  293.         //Setup Buffers and programs
  294.         if(memObjects == null)
  295.         {
  296.             memObjects = new cl_mem[25];
  297.            
  298.             //Calculate
  299.             //Input Patterns
  300.             memObjects[0] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  301.                     Sizeof.cl_double * GPUInputPatterns.length,
  302.                     Pointer.to(GPUInputPatterns), null);
  303.             //Target Patterns
  304.             memObjects[1] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  305.                     Sizeof.cl_double * GPUTargetPatterns.length,
  306.                     Pointer.to(GPUTargetPatterns), null);
  307.             //Initial Weights
  308.             memObjects[2] = clCreateBuffer(clContext, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
  309.                     Sizeof.cl_double * GPUWeights.length,
  310.                     Pointer.to(GPUWeights), null);
  311.             //Number of Inputs
  312.             memObjects[3] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  313.                     Sizeof.cl_int * GPUNumInputs.length,
  314.                     Pointer.to(GPUNumInputs), null);
  315.             //Activation functions
  316.             memObjects[4] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  317.                     Sizeof.cl_int * GPUActFuncs.length,
  318.                     Pointer.to(GPUActFuncs), null);
  319.             //Bias
  320.             memObjects[5] = clCreateBuffer(clContext, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
  321.                     Sizeof.cl_double * GPUBias.length,
  322.                     Pointer.to(GPUBias), null);
  323.             //Values
  324.             memObjects[6] = clCreateBuffer(clContext, CL_MEM_READ_WRITE,
  325.                     Sizeof.cl_double * net.getNeuronCount(), null, null);
  326.             //Max num floats
  327.             memObjects[7] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  328.                     Sizeof.cl_int,
  329.                     Pointer.to(GPUMaxNumFloats), null);
  330.             //Current Pattern - Device
  331.             memObjects[8] = clCreateBuffer(clContext, CL_MEM_READ_ONLY,
  332.                     Sizeof.cl_int, null, null);
  333.             //Current Pattern - PINNED
  334.             memObjects[9] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_ALLOC_HOST_PTR,
  335.                     Sizeof.cl_int, null, null);
  336.             //Input Pattern Size
  337.             memObjects[10] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  338.                     Sizeof.cl_int, Pointer.to(GPUInputPatternSize), null);
  339.             //Target Pattern Size
  340.             memObjects[11] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  341.                     Sizeof.cl_int, Pointer.to(GPUTargetPatternSize), null);
  342.             //Index Offset - Device
  343.             memObjects[12] = clCreateBuffer(clContext, CL_MEM_READ_ONLY,
  344.                     Sizeof.cl_int, null, null);
  345.             //Index Offset - PINNED
  346.             memObjects[13] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_ALLOC_HOST_PTR,
  347.                     Sizeof.cl_int, null, null);
  348.             //Input Neuron?
  349.             memObjects[14] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  350.                     Sizeof.cl_int * GPUIsInputNeuron.length, Pointer.to(GPUIsInputNeuron), null);
  351.             //Output Neuron?
  352.             memObjects[15] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  353.                     Sizeof.cl_int * GPUIsOutputNeuron.length, Pointer.to(GPUIsOutputNeuron), null);
  354.             //Inputs
  355.             memObjects[16] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  356.                     Sizeof.cl_int * GPUInputs.length, Pointer.to(GPUInputs), null);
  357.             //Using bias
  358.             memObjects[18] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  359.                     Sizeof.cl_int * GPUUsingBias.length, Pointer.to(GPUUsingBias), null);
  360.  
  361.             GPUMapCurrentPattern = clEnqueueMapBuffer(clCommandQueue, memObjects[9], CL_FALSE,
  362.                     CL_MAP_WRITE, 0, Sizeof.cl_int, 0, null, null, null);
  363.             GPUMapIndexOffset = clEnqueueMapBuffer(clCommandQueue, memObjects[13], CL_FALSE,
  364.                     CL_MAP_WRITE, 0, Sizeof.cl_int, 0, null, null, null);
  365.            
  366.             ///////////
  367.             //BackProp
  368.             //Error
  369.             memObjects[17] = clCreateBuffer(clContext, CL_MEM_READ_WRITE,
  370.                     Sizeof.cl_double * GPUUsingBias.length, null, null);
  371.             //MinError
  372.             memObjects[19] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  373.                     Sizeof.cl_double, Pointer.to(GPUMinError), null);
  374.             //Learning Rate
  375.             memObjects[20] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  376.                     Sizeof.cl_double, Pointer.to(GPULearningRate), null);
  377.             //Last Bias Change
  378.             memObjects[21] = clCreateBuffer(clContext, CL_MEM_READ_WRITE,
  379.                     Sizeof.cl_double * GPUUsingBias.length, null, null);
  380.             //Momentum
  381.             memObjects[22] = clCreateBuffer(clContext, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
  382.                     Sizeof.cl_double, Pointer.to(GPUMomentum), null);
  383.             //Last Weight Change
  384.             memObjects[23] = clCreateBuffer(clContext, CL_MEM_READ_WRITE,
  385.                     Sizeof.cl_double * GPUWeights.length, null, null);
  386.            
  387.         }
  388.        
  389.        
  390.         if(program_iterate != null)
  391.             clReleaseProgram(program_iterate);
  392.         if(program_errorcalc != null)
  393.             clReleaseProgram(program_errorcalc);
  394.         if(program_backprop != null)
  395.             clReleaseProgram(program_backprop);
  396.         if(kernel_iterate != null)
  397.             clReleaseKernel(kernel_iterate);
  398.         if(kernel_errorcalc != null)
  399.             clReleaseKernel(kernel_errorcalc);
  400.         if(kernel_backprop != null)
  401.             clReleaseKernel(kernel_backprop);
  402.        
  403.         program_iterate = clCreateProgramWithSource(clContext, 1,
  404.                 new String[] { strKernIterate }, null, null);
  405.         clBuildProgram(program_iterate, 0, null, null, null, null);
  406.         program_errorcalc = clCreateProgramWithSource(clContext, 1,
  407.                 new String[] { strKernErrorCalc }, null, null);
  408.         clBuildProgram(program_errorcalc, 0, null, null, null, null);
  409.         program_backprop = clCreateProgramWithSource(clContext, 1,
  410.                 new String[] { strKernBackProp }, null, null);
  411.         clBuildProgram(program_backprop, 0, null, null, null, null);
  412.        
  413.         kernel_iterate = clCreateKernel(program_iterate, "Neuron", null);
  414.         kernel_errorcalc = clCreateKernel(program_errorcalc, "Neuron", null);
  415.         kernel_backprop = clCreateKernel(program_backprop, "Neuron", null);
  416.        
  417.         //Set kernel args
  418.         //Input Patterns
  419.         clSetKernelArg(kernel_iterate, 0, Long.valueOf(Sizeof.cl_mem),
  420.                 Pointer.to(memObjects[0]));
  421.         //Target patterns
  422.         //clSetKernelArg(kern_calc, 1, Long.valueOf(Sizeof.cl_mem),
  423.                 //Pointer.to(memObjects[1]));
  424.         //Initial Weights
  425.         clSetKernelArg(kernel_iterate, 1, Long.valueOf(Sizeof.cl_mem),
  426.                 Pointer.to(memObjects[2]));
  427.         //Number of inputs
  428.         clSetKernelArg(kernel_iterate, 2, Long.valueOf(Sizeof.cl_mem),
  429.                 Pointer.to(memObjects[3]));
  430.         //Activation functions
  431.         clSetKernelArg(kernel_iterate, 3, Long.valueOf(Sizeof.cl_mem),
  432.                 Pointer.to(memObjects[4]));
  433.         //Bias
  434.         clSetKernelArg(kernel_iterate, 4, Long.valueOf(Sizeof.cl_mem),
  435.                 Pointer.to(memObjects[5]));
  436.         //Using Bias
  437.         clSetKernelArg(kernel_iterate, 5, Long.valueOf(Sizeof.cl_mem),
  438.                 Pointer.to(memObjects[18]));
  439.         //Values
  440.         clSetKernelArg(kernel_iterate, 6, Long.valueOf(Sizeof.cl_mem),
  441.                 Pointer.to(memObjects[6]));
  442.         //Max Num Floats
  443.         clSetKernelArg(kernel_iterate, 7, Long.valueOf(Sizeof.cl_mem),
  444.                 Pointer.to(memObjects[7]));
  445.         //Current Pattern
  446.         clSetKernelArg(kernel_iterate, 8, Long.valueOf(Sizeof.cl_mem),
  447.                 Pointer.to(memObjects[8]));
  448.         //InputPatternSize
  449.         clSetKernelArg(kernel_iterate, 9, Long.valueOf(Sizeof.cl_mem),
  450.                 Pointer.to(memObjects[10]));
  451.         //TargetPatternSize
  452.         //clSetKernelArg(kern_calc, 11, Long.valueOf(Sizeof.cl_mem),
  453.                 //Pointer.to(memObjects[11]));
  454.         //Index Offset
  455.         clSetKernelArg(kernel_iterate, 10, Long.valueOf(Sizeof.cl_mem),
  456.                 Pointer.to(memObjects[12]));
  457.         //Is input neuron
  458.         clSetKernelArg(kernel_iterate, 11, Long.valueOf(Sizeof.cl_mem),
  459.                 Pointer.to(memObjects[14]));
  460.         //Is output neuron
  461.         //clSetKernelArg(kern_calc, 14, Long.valueOf(Sizeof.cl_mem),
  462.                 //Pointer.to(memObjects[15]));
  463.         //Inputs
  464.         clSetKernelArg(kernel_iterate, 12, Long.valueOf(Sizeof.cl_mem),
  465.                 Pointer.to(memObjects[16]));
  466.        
  467.         //////////
  468.         //Error calc kernel arguments
  469.         //Target patterns
  470.         clSetKernelArg(kernel_errorcalc, 0, Long.valueOf(Sizeof.cl_mem),
  471.                 Pointer.to(memObjects[1]));
  472.         //Values
  473.         clSetKernelArg(kernel_errorcalc, 1, Long.valueOf(Sizeof.cl_mem),
  474.                 Pointer.to(memObjects[6]));
  475.         //MinError
  476.         clSetKernelArg(kernel_errorcalc, 2, Long.valueOf(Sizeof.cl_mem),
  477.                 Pointer.to(memObjects[19]));
  478.         //Index Offset
  479.         clSetKernelArg(kernel_errorcalc, 3, Long.valueOf(Sizeof.cl_mem),
  480.                 Pointer.to(memObjects[12]));
  481.         //Pattern Index
  482.         clSetKernelArg(kernel_errorcalc, 4, Long.valueOf(Sizeof.cl_mem),
  483.                 Pointer.to(memObjects[8]));
  484.         //Target Pattern Size
  485.         clSetKernelArg(kernel_errorcalc, 5, Long.valueOf(Sizeof.cl_mem),
  486.                 Pointer.to(memObjects[11]));
  487.         //Error
  488.         clSetKernelArg(kernel_errorcalc, 6, Long.valueOf(Sizeof.cl_mem),
  489.                 Pointer.to(memObjects[17]));
  490.        
  491.         //////////
  492.         //BackProp
  493.         //Using Bias
  494.         clSetKernelArg(kernel_backprop, 0, Long.valueOf(Sizeof.cl_mem),
  495.                 Pointer.to(memObjects[18]));
  496.         //indexOffset
  497.         clSetKernelArg(kernel_backprop, 1, Long.valueOf(Sizeof.cl_mem),
  498.                 Pointer.to(memObjects[12]));
  499.         //Learning Rate
  500.         clSetKernelArg(kernel_backprop, 2, Long.valueOf(Sizeof.cl_mem),
  501.                 Pointer.to(memObjects[20]));
  502.         //Error
  503.         clSetKernelArg(kernel_backprop, 3, Long.valueOf(Sizeof.cl_mem),
  504.                 Pointer.to(memObjects[17]));
  505.         //Values
  506.         clSetKernelArg(kernel_backprop, 4, Long.valueOf(Sizeof.cl_mem),
  507.                 Pointer.to(memObjects[6]));
  508.         //Momentum
  509.         clSetKernelArg(kernel_backprop, 5, Long.valueOf(Sizeof.cl_mem),
  510.                 Pointer.to(memObjects[22]));
  511.         //lastBiasChange
  512.         clSetKernelArg(kernel_backprop, 6, Long.valueOf(Sizeof.cl_mem),
  513.                 Pointer.to(memObjects[21]));
  514.         //lastWeightChange
  515.         clSetKernelArg(kernel_backprop, 7, Long.valueOf(Sizeof.cl_mem),
  516.                 Pointer.to(memObjects[23]));
  517.         //Num inputs
  518.         clSetKernelArg(kernel_backprop, 8, Long.valueOf(Sizeof.cl_mem),
  519.                 Pointer.to(memObjects[3]));
  520.         //Weights
  521.         clSetKernelArg(kernel_backprop, 9, Long.valueOf(Sizeof.cl_mem),
  522.                 Pointer.to(memObjects[2]));
  523.         //Max Num Floats
  524.         clSetKernelArg(kernel_backprop, 10, Long.valueOf(Sizeof.cl_mem),
  525.                 Pointer.to(memObjects[7]));
  526.         //Inputs
  527.         clSetKernelArg(kernel_backprop, 11, Long.valueOf(Sizeof.cl_mem),
  528.                 Pointer.to(memObjects[16]));
  529.         //Bias
  530.         clSetKernelArg(kernel_backprop, 12, Long.valueOf(Sizeof.cl_mem),
  531.                 Pointer.to(memObjects[5]));
  532.        
  533.         clFinish(clCommandQueue);
  534.        
  535.         long global_work_offset[] = new long[]{0};
  536.         long global_work_size[] = new long[1];
  537.         long local_work_size[] = new long[]{1};
  538.        
  539.         int ret = 0;
  540.        
  541.         //Loop over each pattern n times
  542.         for(int i = 0; i < eachTimes; i++)
  543.             for(int j = 0; j < numPatterns; j++)
  544.             {
  545.                 //Set current pattern index
  546.                 GPUMapCurrentPattern.asIntBuffer().put(0, j);
  547.                 ret = clEnqueueWriteBuffer(clCommandQueue, memObjects[8], CL_TRUE, 0,
  548.                         Sizeof.cl_int, Pointer.to(GPUMapCurrentPattern), 0, null, null);
  549.                
  550.                 //Calc
  551.                 for(int k = 0; k < GPUTickList.length; k++)
  552.                 {
  553.                     clFlush(clCommandQueue);
  554.                     clFinish(clCommandQueue);
  555.                     //If input nodes
  556.                     if(k == 0)
  557.                         //Set index offset to 0
  558.                         GPUMapIndexOffset.asIntBuffer().put(0, 0);
  559.                     else
  560.                         //Update index offset
  561.                         GPUMapIndexOffset.asIntBuffer().put(0,
  562.                             GPUMapIndexOffset.asIntBuffer().get(0) + GPUTickList[k-1]);
  563.                     //Write index offset to GPU buffer
  564.                     ret = clEnqueueWriteBuffer(clCommandQueue, memObjects[12], CL_TRUE, 0,
  565.                             Sizeof.cl_int, Pointer.to(GPUMapIndexOffset.position(0)), 0, null, null);
  566.                    
  567.                     //Set work size (width of layer)
  568.                     global_work_size[0] = GPUTickList[k];
  569.                     ret = clEnqueueNDRangeKernel(clCommandQueue, kernel_iterate, 1,
  570.                         global_work_offset, global_work_size, local_work_size,
  571.                         0, null, null);
  572.                     //clFinish(clCommandQueue);
  573.                    
  574.                     //clEnqueueReadBuffer(clCommandQueue, memObjects[6], CL_TRUE, 0,
  575.                     //        Sizeof.cl_double * net.getNeuronCount(),
  576.                     //        Pointer.to(GPUBias), 0, null, null);
  577.                 }
  578.                 //Error calc
  579.                 clFlush(clCommandQueue);
  580.                 clFinish(clCommandQueue);
  581.                 ret = clEnqueueNDRangeKernel(clCommandQueue, kernel_errorcalc, 1,
  582.                     global_work_offset, global_work_size, local_work_size,
  583.                     0, null, null);
  584.                
  585.                 //Back prop
  586.                 for(int k = GPUTickList.length -1; k >= 0; k--)
  587.                 {
  588.                     global_work_size[0] = GPUTickList[k];
  589.                     //Do backprop
  590.                     clFlush(clCommandQueue);
  591.                     clFinish(clCommandQueue);
  592.                     ret = clEnqueueNDRangeKernel(clCommandQueue, kernel_backprop, 1,
  593.                             global_work_offset, global_work_size, local_work_size,
  594.                             0, null, null);
  595.                     //If there is another level to go, update offset
  596.                     if(k > 0)
  597.                         GPUMapIndexOffset.asIntBuffer().put(0,
  598.                             GPUMapIndexOffset.asIntBuffer().get(0) - GPUTickList[k-1]);
  599.                 }
  600.                
  601.             }
  602.        
  603.         //Read Weights
  604.         clFlush(clCommandQueue);
  605.         clFinish(clCommandQueue);
  606.         ret = clEnqueueReadBuffer(clCommandQueue, memObjects[2], CL_TRUE, 0,
  607.                 Sizeof.cl_double * GPUWeights.length, Pointer.to(GPUWeights),
  608.                 0, null, null);
  609.         //Read Bias
  610.         ret = clEnqueueReadBuffer(clCommandQueue, memObjects[5], CL_TRUE, 0,
  611.                 Sizeof.cl_double * GPUBias.length, Pointer.to(GPUBias),
  612.                 0, null, null);
  613.        
  614.         for(int i = 0; i < net.getNeuronCount(); i++)
  615.         {
  616.             Neuron neuron = net.getNeuron(i);
  617.             neuron.setBias(GPUBias[i]);
  618.             for(int j = 0; j < neuron.getInputLinkCount(); j++)
  619.             {
  620.                 Link link = neuron.getInputLink(j);
  621.                 link.setWeight(GPUWeights[(i * GPUMaxNumFloats[0]) + j]);
  622.             }
  623.         }
  624.         setTraining(false);
  625.     }
  626.    
  627.     protected int SetupCL()
  628.     {
  629.         try
  630.         {
  631.             final int platformIndex = 0;
  632.             final long deviceType = CL_DEVICE_TYPE_GPU;
  633.             final int deviceIndex = 0;
  634.  
  635.             CL.setExceptionsEnabled(true);
  636.  
  637.             //Obtain number of platforms
  638.             int numPlatformsArray[] = new int[1];
  639.             clGetPlatformIDs(0, null, numPlatformsArray);
  640.             int numPlatforms = numPlatformsArray[0];
  641.  
  642.             //Obtain a platform ID
  643.             cl_platform_id platforms[] = new cl_platform_id[numPlatforms];
  644.             clGetPlatformIDs(platforms.length, platforms, null);
  645.             cl_platform_id platform = platforms[platformIndex];
  646.  
  647.             //Initialize the context properties
  648.             cl_context_properties contextProperties = new cl_context_properties();
  649.             contextProperties.addProperty(CL_CONTEXT_PLATFORM, platform);
  650.  
  651.             //Obtain the number of devices for the platform
  652.             int numDevicesArray[] = new int[1];
  653.             clGetDeviceIDs(platform, deviceType, 0, null, numDevicesArray);
  654.             int numDevices = numDevicesArray[0];
  655.  
  656.             //Obtain a device ID
  657.             cl_device_id devices[] = new cl_device_id[numDevices];
  658.             clGetDeviceIDs(platform, deviceType, numDevices, devices, null);
  659.             cl_device_id device = devices[deviceIndex];
  660.  
  661.             //Create a context for the selected device
  662.             clContext = clCreateContext(contextProperties, 1, new cl_device_id[]{device},
  663.                     null, null, null);
  664.  
  665.             //Create a command-queue for the selected device
  666.             clCommandQueue = clCreateCommandQueue(clContext, device, 0, null);
  667.         } catch (CLException e) {
  668.            
  669.         }
  670.         return 0;
  671.     }
  672.     //
  673.     // creation & persistance
  674.    
  675.     /**
  676.      *  Create a new Backprop trainer.
  677.      *  Uses a {@link boone.TrainingSignalGenerator.SquareError} .
  678.      */
  679.     public
  680.     BackpropTrainer()
  681.     {
  682.         setTrainingSignalGenerator(new TrainingSignalGenerator.SquareError());
  683.     }
  684.    
  685.     /** create standard NeuronTrainers, LinkTrainers. */
  686.     public
  687.     PartTrainer createPartTrainer(BrainPart part)
  688.     {
  689.         if(part instanceof Neuron)
  690.             return new BackpropNeuronTrainer((Neuron) part, this);
  691.         if(part instanceof Link)
  692.             return new BackpropLinkTrainer((Link) part, this);
  693.         return null;
  694.     }
  695.    
  696.  
  697.     //
  698.     // training
  699.    
  700.     /**
  701.      *  supervised training. an input and an output pattern are given. <p>
  702.      *  
  703.      *  Backpropagation Training: <ul>
  704.      *      <li> set the training flag
  705.      *      <li> test the network
  706.      *      <li> calculate the error of the output neurons, train them
  707.      *      <li> propagate backwards through the network, using the reversed tickList.
  708.      *      <li> unset the training flag.
  709.      *      <li> return the error. </ul>
  710.      *
  711.      *  This method uses a {@link boone.TrainingSignalGenerator} to calculate the error signals.
  712.      *  Please see {@link boone.TrainingSignalGenerator.SquareError} for details on what the
  713.      *  TrainingSignalGenerator instance is supposed to do.
  714.      */
  715.     public
  716.     void trainTurn(double[] input, double[] target)
  717.     {
  718.         setTraining(true);
  719.        
  720.         // initialize all the neurons
  721.         for(int i=net.getNeuronCount()-1; i>=0; i--)
  722.             ((BackpropNeuronTrainer) net.getNeuron(i).getPartTrainer()).beginTurn();
  723.        
  724.         // test the network, calculating the error.
  725.         net.setInput(input);
  726.         net.innervate();
  727.        
  728.         // calculate neuron error, error sum for output neurons
  729.         trainingSignalGenerator.calculateSignal(net, input, target, minError);
  730.        
  731.         // train: propagate backwards
  732.         VarArray<Neuron> tlist = net.getTickList();
  733.         for(int i=tlist.size-1; i>=0; i--)
  734.         {
  735.             PartTrainer partTrainer = ((Neuron) tlist.array[i]).partTrainer;
  736.             if(partTrainer != null)
  737.                 partTrainer.train();
  738.         }
  739.        
  740.         setTraining(false);
  741.     }
  742.    
  743.  
  744.    
  745.     //
  746.     // getters & setters
  747.    
  748.    
  749.     /** @return Value of property minError. */
  750.     public double getMinError() { return minError; }
  751.  
  752.     /** @param minError New value of property minError. */
  753.     public void setMinError(double minError) { this.minError = minError; }
  754.  
  755.     /** @return Value of property momentum. */
  756.     public double getMomentum() { return momentum; }    
  757.    
  758.     /** @param momentum New value of property momentum. */
  759.     public void setMomentum(double momentum) { this.momentum = momentum; }
  760.    
  761.    
  762.     //
  763.     // storing, loading
  764.    
  765.     /** Store the state into the IOElement. */
  766.     public
  767.     void store(IOElement node)
  768.     {
  769.         super.store(node);
  770.         node.putAttribute("minError", minError);
  771.         node.putAttribute("momentum", momentum);
  772.     }
  773.    
  774.     /** Load the state from the IOElement. */
  775.     public
  776.     void load(IOElement node) throws IOElement.LoadException
  777.     {
  778.         super.load(node);
  779.         minError = node.getDoubleAttribute("minError", minError);
  780.         momentum = node.getDoubleAttribute("momentum", momentum);
  781.     }
  782.    
  783.     //
  784.     //
  785.     // subclasses
  786.     //
  787.    
  788.     /**
  789.      * PartTrainer for Neurons
  790.      * @author August Mayer
  791.      */
  792.     public static
  793.     class BackpropNeuronTrainer
  794.     extends PartTrainer
  795.     {
  796.        
  797.         /** neuron error term */
  798.         protected double errorSignal = 0.0;
  799.        
  800.         /** last bias weight change */
  801.         protected double lastBiasChange = 0;
  802.        
  803.         //
  804.         // construction, persistence
  805.        
  806.         /** create a new null trainer. Needed by persistence. */
  807.         public BackpropNeuronTrainer() {}
  808.        
  809.         /** new NeuronTrainer. */
  810.         public BackpropNeuronTrainer(Neuron neuron, BackpropTrainer trainer) { super(neuron, trainer); }
  811.        
  812.        
  813.         //
  814.         // Training
  815.        
  816.         /** reset lastBiasChange to 0 */
  817.         public void resetTraining() { lastBiasChange = 0.0; }
  818.        
  819.         /** reset the error signal each turn */
  820.         public void beginTurn() { errorSignal = 0.0; }
  821.        
  822.         /**
  823.          *  train the neuron. <br>
  824.          *  This is called in Neuron.run() as well as in the Trainer.
  825.          */
  826.         public
  827.         void train()
  828.         {
  829.             Neuron neuron = (Neuron) part;
  830.            
  831.             // train the bias (if the neuron uses bias)
  832.             if(neuron.isUsingBias())
  833.             {
  834.                 BackpropTrainer bptrainer = (BackpropTrainer) trainer;
  835.                 double biasChange = bptrainer.learnRate
  836.                     * errorSignal
  837.                     * neuron.getActivationFn().mapDerivative(neuron.getInput())
  838.                     + bptrainer.momentum * lastBiasChange;
  839.                 neuron.addToBias(biasChange);
  840.                 lastBiasChange = biasChange;
  841.             }
  842.  
  843.             // train the preceding links
  844.             for(int i=neuron.getInputLinkCount()-1; i>=0; i--)
  845.             {
  846.                 PartTrainer partTrainer = neuron.getInputLink(i).partTrainer;
  847.                 if(partTrainer != null)
  848.                     partTrainer.train();
  849.             }
  850.         }
  851.        
  852.        
  853.         //
  854.         // getters & setters
  855.        
  856.         /** return the current errorSignal */
  857.         public double getErrorSignal() { return errorSignal; }
  858.        
  859.         /** set the error signal */
  860.         public void setErrorSignal(double errorSignal) { this.errorSignal = errorSignal; }
  861.        
  862.         /** return the last bias change. */
  863.         public double getLastBiasChange() { return lastBiasChange; }
  864.        
  865.         /**  set the last bias change */
  866.         public void setLastBiasChange(double changeVal) { this.lastBiasChange = changeVal; }
  867.        
  868.         /** Store the state. */
  869.         public
  870.         void store(IOElement node)
  871.         {
  872.             node.putAttribute("errorSignal", errorSignal);
  873.             node.putAttribute("lastBiasChange", lastBiasChange);
  874.         }
  875.        
  876.         /** Load the state. */
  877.         public
  878.         void load(IOElement node)
  879.         {
  880.             errorSignal = node.getDoubleAttribute("errorSignal", errorSignal);
  881.             lastBiasChange = node.getDoubleAttribute("lastBiasChange", lastBiasChange);
  882.         }
  883.     }
  884.    
  885.    
  886.    
  887.     /**
  888.      *  PartTrainer for Links
  889.      *  @author August Mayer
  890.      */
  891.     public static
  892.     class BackpropLinkTrainer extends PartTrainer
  893.     {
  894.        
  895.         /** last link weight change */
  896.         protected double lastWeightChange = 0;
  897.  
  898.         /** create a new null object. Needed by persistence. */
  899.         public BackpropLinkTrainer() {}
  900.        
  901.         /** Creates a new instance of BPLinkTrainer */
  902.         public BackpropLinkTrainer(Link link, BackpropTrainer trainer) { super(link, trainer); }
  903.  
  904.         /** reset the lastWeightChange to 0 */
  905.         public void resetTraining() { lastWeightChange = 0; }
  906.        
  907.         /**
  908.          *  Train the link. <br>
  909.          *  This is called in Neuron.run() as well as in the NetTrainer.
  910.          */
  911.         public
  912.         void train()
  913.         {
  914.             Link link = (Link) part;
  915.             BackpropTrainer bptrainer = (BackpropTrainer) trainer;
  916.             Neuron src = link.getSource();
  917.             BackpropNeuronTrainer srcTrainer = (BackpropNeuronTrainer) src.partTrainer;
  918.             Neuron sink = link.getSink();
  919.             BackpropNeuronTrainer sinkTrainer = (BackpropNeuronTrainer) sink.partTrainer;
  920.            
  921.             // get the sink error
  922.             double completeErrorSignal = sinkTrainer.getErrorSignal()
  923.                 * sink.getActivationFn().mapDerivative(sink.getInput());
  924.            
  925.             double weightedErrorSignal = completeErrorSignal * link.getWeight();
  926.            
  927.             // propagate error to source neuron:
  928.             srcTrainer.errorSignal += weightedErrorSignal;
  929.  
  930.             // modify the link weight
  931.             double weightChange = bptrainer.learnRate
  932.                 * completeErrorSignal
  933.                 * src.getOutput()
  934.                 + bptrainer.momentum * lastWeightChange;
  935.             link.addToWeight(weightChange);
  936.  
  937.             lastWeightChange = weightChange;
  938.         }
  939.        
  940.         /** return the last link weight change. */
  941.         public double getLastWeightChange() { return lastWeightChange; }
  942.  
  943.         /**  set the last link weight change */
  944.         public void setLastWeightChange(double changeVal) { this.lastWeightChange = changeVal; }
  945.  
  946.         /** Store the state. */
  947.         public
  948.         void store(IOElement node)
  949.         {
  950.             node.putAttribute("lastWeightChange", lastWeightChange);
  951.         }
  952.        
  953.         /** Load the state. */
  954.         public
  955.         void load(IOElement node)
  956.         {
  957.             lastWeightChange = node.getDoubleAttribute("lastWeightChange", lastWeightChange);
  958.         }
  959.        
  960.     } // end LinkTrainer
  961.    
  962. } // end BackpropTrainer
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement