suremarc

MLP

Mar 18th, 2016
133
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 4.64 KB | None | 0 0
  1. debugMode = true;
  2. if not debugMode then
  3.     print = function()end;
  4. end
  5.  
  6. local exp, max = math.exp, math.max;
  7. local K = function(z) return 1/(1+exp(-z)); end;
  8. -- multilayer perceptron neural network
  9. return function(N, E, dE)
  10.     --[[
  11.             N is an array whose elements are positive integers
  12.             E is the error function
  13.             dE is the total derivative of the error function
  14.     ]]
  15.     local m = #N;
  16.    
  17.     -- retrieve memory address of an object
  18.     -- assigns names to nodes for identification
  19.     local function getHex(o)
  20.         return tonumber(tostring(o):sub(-8),16);
  21.     end
  22.    
  23.     local nodes = {};
  24.     local weights = {};
  25.     local values = {};  -- value at each node
  26.     local mirror;       -- serialization
  27.     local myInput;
  28.    
  29.     nodes[1] = {};
  30.     for j = 1, N[1] do
  31.         local node;
  32.         node = function()
  33.             values[node] = myInput[j];
  34.             --print('firing node: ' .. 1 .. ', ' .. j .. ' @ ' .. values[node]);
  35.         end;
  36.         nodes[1][j] = node;
  37.         weights[node] = {};
  38.     end
  39.    
  40.     -- bias node
  41.     local bias;
  42.     bias = function()
  43.         values[bias] = 1;
  44.     end;
  45.     bias();
  46.     weights[bias] = {};
  47.    
  48.     for i = 2, m do
  49.         nodes[i] = {};
  50.         values[i] = {};
  51.         for j = 1, N[i] do
  52.             local last = nodes[i-1];
  53.             local node;
  54.             node =  function()
  55.                         local sum = 0;
  56.                         for k, linked in ipairs(last) do
  57.                             sum = sum + weights[linked][node]*values[linked];
  58.                         end
  59.                         values[node] = K(sum+weights[bias][node]);
  60.                         --print('firing node: ' .. i .. ', ' .. j .. ' @ ' .. values[node]);
  61.                     end;
  62.             nodes[i][j] = node;
  63.             weights[node] = {};
  64.             for k, linked in ipairs(last) do
  65.                 weights[linked][node] = 0;
  66.             end
  67.             weights[bias][node] = 0;
  68.         end
  69.     end
  70.    
  71.     local hexToNode = {}; -- get node from mem. address
  72.     local index = {}; -- indices from object
  73.     for i = 1, m do
  74.         for j = 1, #nodes[i] do
  75.             hexToNode[getHex(nodes[i][j])] = nodes[i][j];
  76.             index[nodes[i][j]] = {i, j};
  77.         end
  78.     end
  79.     hexToNode[getHex(bias)] = bias;
  80.    
  81.     local function getIndex(node)
  82.         return unpack(index[node]);
  83.     end
  84.    
  85.     local function updateMirror() -- updates serialization
  86.         for hex, data in pairs(mirror) do
  87.             local node = hexToNode[hex];
  88.             data[2] = values[node];
  89.             local w = data[3];
  90.             for i, j in pairs(w) do
  91.                 w[i] = nil;
  92.             end
  93.             for node, val in pairs(weights[node]) do
  94.                 w[getHex(node)] = val;
  95.             end
  96.         end
  97.     end
  98.        
  99.    
  100.     local function fire(i) -- feedforward from layer 'i'
  101.         i = i or 1;
  102.         for j = i, m do
  103.             for k, node in ipairs(nodes[j]) do
  104.                 node();
  105.             end
  106.         end
  107.         updateMirror();
  108.     end;
  109.    
  110.     local function doInput(X)
  111.         myInput = X;
  112.         fire(1);
  113.     end
  114.    
  115.    
  116.     local delta = {};
  117.     for i = 1, m do
  118.         delta[i] = {};
  119.         for j = 1, #nodes[i] do
  120.             delta[nodes[i][j]] = 0;
  121.         end
  122.     end
  123.     delta[bias] = 0; -- always zero
  124.    
  125.     local function d(t, i) -- compute der. coefficients backwards to layer 'i'
  126.         i = i or 1;
  127.         local val;
  128.         for k, node in ipairs(nodes[m]) do
  129.             val = values[node];
  130.             delta[node] = dE(t[k], val)*val*(1-val);
  131.         end
  132.         local sum;
  133.         for j = m-1, i, -1 do
  134.             for k, node in ipairs(nodes[j]) do
  135.                 sum = 0;
  136.                 for linked, w in pairs(weights[node]) do
  137.                     sum = sum + delta[linked]*w
  138.                 end
  139.                 delta[node] = sum*values[node]*(1-values[node]);
  140.             end
  141.         end
  142.     end
  143.    
  144.     local function learn(input, target, eps, alpha, timeout) -- train with input and target
  145.         doInput(input);
  146.         local err = {};
  147.         repeat
  148.             d(target);
  149.             for n0, myWeights in pairs(weights) do
  150.                 for n1, w in pairs(myWeights) do
  151.                     myWeights[n1] = w-alpha*delta[n1]*values[n0];
  152.                 end
  153.             end
  154.             fire(1);
  155.             for i, node in ipairs(nodes[m]) do
  156.                 err[i] = E(target[i], values[node]);
  157.             end
  158.         until max(unpack(err))<eps;
  159.     end
  160.    
  161.     mirror = {};
  162.     local w = {};
  163.     for node, val in pairs(weights[bias]) do
  164.         w[getHex(node)] = val;
  165.     end
  166.     mirror[getHex(bias)] = {{}, 1, w};
  167.     w = nil;
  168.     for i = 1, m do
  169.         for j, node in ipairs(nodes[i]) do
  170.             local w = {};
  171.             for node, val in pairs(weights[node]) do
  172.                 w[getHex(node)] = val;
  173.             end
  174.             mirror[getHex(node)] = {{getIndex(node)}, values[node], w}
  175.         end
  176.     end
  177.    
  178.     return {
  179.         ['mirror'] = mirror;
  180.         ['input'] = function(...)
  181.             doInput{...};
  182.             local t = {};
  183.             for j, node in ipairs(nodes[m]) do
  184.                 table.insert(t, values[node]);
  185.             end
  186.             return unpack(t);
  187.         end;
  188.         ['learn'] = learn;
  189.         ['setWeight'] = function(hex, val)
  190.             local node = hexToNode[hex];
  191.             weights[node] = val;
  192.             node();
  193.             fire(1+getIndex(node));
  194.         end;
  195.         ['setBias'] = function(hex, val)
  196.             weights[bias][hexToNode[hex]] = val;
  197.         end;
  198.         ['randomize'] = function(nobias)
  199.             for n0, w in pairs(weights) do
  200.                 if nobias and (n0~=bias) then
  201.                     for n1, _ in pairs(w) do
  202.                         w[n1] = math.random();
  203.                     end
  204.                 end
  205.             end
  206.             updateMirror();
  207.         end;
  208.     };
  209. end;
Add Comment
Please, Sign In to add comment