Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
--- Debug switch for this module.
-- Kept local (it used to leak into _G), and `print` is shadowed locally
-- instead of reassigning the global, so turning debugging off here cannot
-- silence output of every other module in the program.
local debugMode = true
local print = print
if not debugMode then
  print = function() end
end
-- Local aliases for math functions used on hot paths.
local exp, max = math.exp, math.max

--- Logistic (sigmoid) activation.
-- @param z weighted input sum
-- @return 1 / (1 + e^-z), a value between 0 and 1 (saturating to exactly
--   0 or 1 for very large |z|)
local function K(z)
  local denominator = 1 + exp(-z)
  return 1 / denominator
end
--- Multilayer perceptron (fully connected feed-forward network) factory.
-- Every node outside the input layer outputs K(weighted sum of the previous
-- layer's outputs + its bias weight), K being the logistic function defined
-- earlier in this file. All weights start at 0.

-- `unpack` is a global only in Lua 5.1/LuaJIT; Lua 5.2+ has table.unpack.
local unpack = table.unpack or unpack

--- Build a network.
-- @param N array of positive integers: N[i] is the number of nodes in layer i
--   (layer 1 is the input layer, layer #N the output layer).
-- @param E error function, called as E(target, output) for each output node;
--   only needed by `learn`.
-- @param dE derivative of E with respect to `output`, called as
--   dE(target, output); only needed by `learn`.
-- @return a table with these fields:
--   mirror: live snapshot keyed by node id (the bias node has id 0). Each
--     entry is {position, value, outgoing}: position is {layer, index} ({} for
--     the bias), value is the node's last output (nil before the first input,
--     always 1 for the bias), outgoing maps target node id -> weight. It is
--     refreshed after every forward pass and every setter call.
--   input(...): feed N[1] input values; returns the N[#N] output values.
--   learn(input, target, eps, alpha, timeout): gradient descent with learning
--     rate `alpha` on one sample until max_k E(target[k], output[k]) < eps or,
--     if `timeout` is given, until that many CPU seconds (os.clock) elapsed.
--     Always takes at least one step. Returns converged, steps, worstError.
--   setWeight(id, weightsById): overwrite some outgoing weights of node `id`;
--     weightsById maps target node id -> weight (same shape as mirror[id][3]).
--   setBias(id, value): set the bias weight of a non-input node.
--   randomize(nobias, lo, hi): draw every weight uniformly from [lo, hi)
--     (default [0, 1)); bias weights are skipped when `nobias` is true.
return function(N, E, dE)
  local m = #N
  if m < 1 then
    error("N must contain at least one layer size", 2)
  end
  for i = 1, m do
    local size = N[i]
    if type(size) ~= "number" or size < 1 or size % 1 ~= 0 then
      error("N[" .. i .. "] must be a positive integer", 2)
    end
  end

  local nodes = {}    -- nodes[i][j]: closure computing node j of layer i
  local weights = {}  -- weights[from][to]: weight of the link from -> to
  local values = {}   -- values[node]: output of the node's last evaluation
  local delta = {}    -- delta[node]: backpropagated error term
  local index = {}    -- index[node]: {layer, position}; the bias has none
  local ids = {}      -- ids[node]: numeric id used as the public key
  local idToNode = {} -- inverse of ids
  local mirror = {}   -- public serialization, see updateMirror
  local myInput       -- input array from the last doInput call

  -- Ids are assigned sequentially rather than parsed from tostring(node):
  -- that string's format is not specified by the language, and
  -- tonumber(s, 16) rejects a "0x" prefix on Lua 5.3+.
  local nextId = 0
  local function register(node)
    ids[node] = nextId
    idToNode[nextId] = node
    nextId = nextId + 1
    weights[node] = {}
    delta[node] = 0
  end

  -- Bias node: constant output 1; its outgoing weights are the biases.
  local function bias()
    values[bias] = 1
  end
  register(bias)
  bias()

  -- Input layer: each node copies its entry of the current input.
  nodes[1] = {}
  for j = 1, N[1] do
    local function node()
      values[node] = myInput[j]
    end
    nodes[1][j] = node
    index[node] = { 1, j }
    register(node)
  end

  -- Remaining layers: fully connected to the previous layer, plus a bias.
  for i = 2, m do
    local last = nodes[i - 1]
    nodes[i] = {}
    for j = 1, N[i] do
      local function node()
        local sum = 0
        for _, linked in ipairs(last) do
          sum = sum + weights[linked][node] * values[linked]
        end
        values[node] = K(sum + weights[bias][node])
      end
      nodes[i][j] = node
      index[node] = { i, j }
      register(node)
      for _, linked in ipairs(last) do
        weights[linked][node] = 0
      end
      weights[bias][node] = 0
    end
  end

  -- Copy current values and weights into the mirror, updating its entry
  -- tables in place rather than replacing them.
  local function updateMirror()
    for id, entry in pairs(mirror) do
      local node = idToNode[id]
      entry[2] = values[node]
      local outgoing = entry[3]
      for key in pairs(outgoing) do
        outgoing[key] = nil
      end
      for target, w in pairs(weights[node]) do
        outgoing[ids[target]] = w
      end
    end
  end

  for id = 0, nextId - 1 do
    local node = idToNode[id]
    local position = index[node]
    mirror[id] = {
      position and { position[1], position[2] } or {},
      values[node],
      {},
    }
  end
  updateMirror()

  -- Evaluate layers `first`..m in order (default: all), then sync the mirror.
  local function fire(first)
    for i = first or 1, m do
      for _, node in ipairs(nodes[i]) do
        node()
      end
    end
    updateMirror()
  end

  -- Re-evaluate from layer `first` once an input exists (before that there
  -- is nothing to evaluate); always sync the mirror.
  local function refresh(first)
    if myInput ~= nil then
      fire(first)
    else
      updateMirror()
    end
  end

  -- Store X as the current input and run a full forward pass. Error level 3
  -- points at the user code that called `input` or `learn`.
  local function doInput(X)
    if type(X) ~= "table" then
      error("input must be a table", 3)
    end
    for j = 1, N[1] do
      if X[j] == nil then
        error("missing input " .. j .. " (expected " .. N[1] .. ")", 3)
      end
    end
    myInput = X
    fire(1)
  end

  -- Backpropagation: deltas of the output layer come from dE, those of each
  -- hidden layer from the deltas of the following layer. Input-layer deltas
  -- are skipped because no weight update reads them.
  local function computeDeltas(target)
    for k, node in ipairs(nodes[m]) do
      local out = values[node]
      delta[node] = dE(target[k], out) * out * (1 - out)
    end
    for i = m - 1, 2, -1 do
      for _, node in ipairs(nodes[i]) do
        local sum = 0
        for linked, w in pairs(weights[node]) do
          sum = sum + delta[linked] * w
        end
        local out = values[node]
        delta[node] = sum * out * (1 - out)
      end
    end
  end

  local function learn(input, target, eps, alpha, timeout)
    doInput(input)
    local deadline = timeout and os.clock() + timeout
    local steps, worst = 0, nil
    repeat
      computeDeltas(target)
      for from, outgoing in pairs(weights) do
        local v = values[from]
        for to, w in pairs(outgoing) do
          outgoing[to] = w - alpha * delta[to] * v
        end
      end
      fire(1)
      worst = nil
      for k, node in ipairs(nodes[m]) do
        local e = E(target[k], values[node])
        if worst == nil or worst < e then
          worst = e
        end
      end
      steps = steps + 1
    until worst < eps or (deadline and os.clock() >= deadline)
    return worst < eps, steps, worst
  end

  local function setWeight(hex, val)
    local node = idToNode[hex]
    if node == nil then
      error("unknown node id " .. tostring(hex), 2)
    end
    if type(val) ~= "table" then
      error("expected a table mapping target node id -> weight", 2)
    end
    local outgoing = weights[node]
    -- Validate everything first so a bad entry leaves the network unchanged.
    for targetId, w in pairs(val) do
      local target = idToNode[targetId]
      if target == nil or outgoing[target] == nil then
        error("node " .. tostring(hex) .. " has no link to node "
          .. tostring(targetId), 2)
      end
      if type(w) ~= "number" then
        error("weight for node " .. tostring(targetId) .. " is not a number", 2)
      end
    end
    for targetId, w in pairs(val) do
      outgoing[idToNode[targetId]] = w
    end
    -- Outgoing weights only influence the layers after the node's own.
    refresh(node == bias and 2 or index[node][1] + 1)
  end

  local function setBias(hex, val)
    local node = idToNode[hex]
    if node == nil or weights[bias][node] == nil then
      error("node " .. tostring(hex) .. " has no bias weight", 2)
    end
    if type(val) ~= "number" then
      error("bias must be a number", 2)
    end
    weights[bias][node] = val
    refresh(index[node][1])
  end

  local function randomize(nobias, lo, hi)
    lo, hi = lo or 0, hi or 1
    local span = hi - lo
    for from, outgoing in pairs(weights) do
      if not (nobias and from == bias) then
        for to in pairs(outgoing) do
          outgoing[to] = lo + span * math.random()
        end
      end
    end
    refresh(2)
  end

  local function input(...)
    doInput({ ... })
    local outputs = {}
    for j, node in ipairs(nodes[m]) do
      outputs[j] = values[node]
    end
    return unpack(outputs, 1, #nodes[m])
  end

  return {
    mirror = mirror,
    input = input,
    learn = learn,
    setWeight = setWeight,
    setBias = setBias,
    randomize = randomize,
  }
end
Add Comment
Please, Sign In to add comment