Neural Network (Lua)
Terrah, Feb 19th, 2019

NeuralNetwork = {

    -- Sigmoid transfer (activation) function: squashes any real input into (0, 1).
    --Transfer = function(x) return 1 / (1 + math.exp(-x / 1)) end; -- variant with an explicit temperature divisor
    Transfer = function(x) return 1 / (1 + math.exp(-x)); end;
};

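--[[ Note on the sigmoid (standard math, stated here for context): its
     derivative can be written in terms of its own output,

         sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)),

     which is why the backpropagation code below multiplies by
     result * (1 - result). ]]
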
function NeuralNetwork:Create(Inputs, Outputs, HiddenLayers, NeuronsPerLayer, LearningRate)

    assert(Inputs > 0, "Inputs must be higher than 0");
    assert(Outputs > 0, "Outputs must be higher than 0");

    -- Fall back to sensible defaults when the optional parameters are omitted.
    HiddenLayers = HiddenLayers or math.ceil(Inputs/2)
    NeuronsPerLayer = NeuronsPerLayer or math.ceil(Inputs*0.66666+Outputs)
    LearningRate = LearningRate or 0.5

    print(Inputs, Outputs, HiddenLayers, NeuronsPerLayer, LearningRate); -- debug output

    local network = setmetatable({
        learningRate = LearningRate,
        numBackPropagates = 0,
    },{ __index = NeuralNetwork});

    -- Layer 1 is the input layer: one weightless placeholder neuron per input.
    table.insert(network, {});
    for i = 1,Inputs do
        table.insert(network[1], {});
    end

    for i = 2,HiddenLayers+2 do -- plus 2 accounts for the output layer (and skips the input layer)

        table.insert(network, {});

        local neuronsInLayer = NeuronsPerLayer;

        if i == HiddenLayers+2 then
            neuronsInLayer = Outputs;
        end

        for j = 1,neuronsInLayer do

            -- Each neuron gets a random bias and one random weight per neuron
            -- in the previous layer, all drawn uniformly from [-1, 1).
            table.insert(network[i], {bias = math.random()*2-1});

            local numNeuronInputs = #network[i-1];

            for k = 1,numNeuronInputs do
                table.insert(network[i][j], math.random()*2-1);
            end
        end
    end

    return network;
end
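
--[[ Usage sketch (illustrative, not part of the original paste): every argument
     after Inputs and Outputs is optional and falls back to the defaults above.

    local net = NeuralNetwork:Create(2, 1);              -- defaults: 1 hidden layer of 3 neurons, rate 0.5
    local net2 = NeuralNetwork:Create(2, 1, 2, 4, 0.25); -- 2 hidden layers of 4 neurons, learning rate 0.25
]]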

-- MySQL schema expected by the Save and Load functions below:

--[[

CREATE TABLE `networks` (
  `ID` varchar(64) NOT NULL,
  `LearningRate` double DEFAULT NULL,
  `NumBackPropagates` int(11) DEFAULT NULL,
  PRIMARY KEY (`ID`)
) ENGINE=InnoDB;

CREATE TABLE `neurons` (
  `ID` varchar(64) NOT NULL,
  `LayerID` int(11) NOT NULL,
  `NeuronID` int(11) NOT NULL,
  `Bias` double DEFAULT NULL,
  `Weights` longtext,
  PRIMARY KEY (`ID`,`LayerID`,`NeuronID`),
  KEY `layer_id` (`LayerID`),
  KEY `neuron_id` (`NeuronID`)
) ENGINE=InnoDB;

]]

function NeuralNetwork:Load(db, name)

    assert(db:Query("SELECT * FROM networks WHERE ID='"..db:EscapeString(name).."';"));
    assert(db:Fetch(), "No neural network found with name "..tostring(name));

    local network = setmetatable({
        learningRate = db:GetRow().LearningRate,
        numBackPropagates = db:GetRow().NumBackPropagates
    },{ __index = NeuralNetwork});

    assert(db:Query("SELECT MAX(LayerID) FROM neurons WHERE ID='"..db:EscapeString(name).."';"));
    assert(db:Fetch(), "No neural network found with name "..tostring(name));
    local layers = tonumber(db:GetRow(1)); -- MAX(LayerID), the first column of the result row
    assert(layers >= 2, "Stored network "..tostring(name).." is incomplete (fewer than 2 layers)");
    local row;
    local layer;
    local neuron;

    for n = 1,layers do

        layer = {};
        table.insert(network, layer);

        assert(db:Query("SELECT * FROM neurons WHERE ID='"..db:EscapeString(name).."' AND LayerID="..n.." ORDER BY NeuronID;"));

        while db:Fetch() do
            row = db:GetRow();
            neuron = {
                bias = row.Bias
            };

            -- Weights are stored as a pipe-delimited string, e.g. "|0.5|-0.25|".
            for value in row.Weights:gmatch("|?(.-)|") do
                table.insert(neuron, tonumber(value));
            end

            table.insert(layer, neuron);
        end
    end

    return network;
end

function NeuralNetwork:Save(db, name)

    -- Remove any previous copy of this network before re-inserting it.
    assert(db:Query("DELETE FROM networks WHERE ID='"..db:EscapeString(name).."';"));
    assert(db:Query("DELETE FROM neurons WHERE ID='"..db:EscapeString(name).."';"));

    assert(db:Query("INSERT INTO networks (`ID`, `LearningRate`, `NumBackPropagates`) VALUES ('"..db:EscapeString(name).."', "..tostring(tonumber(self.learningRate))..", "..tostring(tonumber(self.numBackPropagates))..");"));

    for l = 1,#self do

        for n = 1,#self[l] do

            local neuron = self[l][n];
            -- Serialize the weights as a pipe-delimited string, e.g. "|0.5|-0.25|".
            local weights = "|";

            for w = 1,#neuron do
                weights = weights .. neuron[w] .. "|";
            end

            assert(db:Query("INSERT INTO `neurons` (`ID`, `LayerID`, `NeuronID`, `Bias`, `Weights`) VALUES ('"..db:EscapeString(name).."', "..l..", "..n..", "..tonumber(neuron.bias or 0)..", '"..db:EscapeString(weights).."');"));
        end
    end
end
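
--[[ Usage sketch (illustrative, not part of the original paste): persist a
     network and restore it later. "db" stands for whatever database object
     implements the Query/Fetch/GetRow/EscapeString interface used above.

    net:Save(db, "xor-net");
    local restored = NeuralNetwork:Load(db, "xor-net");
]]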

function NeuralNetwork:ForwardPropagate(inputs)

    assert(#inputs == #self[1], "Number of inputs does not match the network's input layer");

    local outputs = {}
    for i = 1,#self do

        for j = 1,#self[i] do

            if i == 1 then
                -- Input layer: a neuron's result is simply the raw input value.
                self[i][j].result = inputs[j];
            else

                -- Start from the bias, add the weighted results of the
                -- previous layer, then squash through the transfer function.
                self[i][j].result = self[i][j].bias;

                for k = 1,#self[i][j] do
                    self[i][j].result = self[i][j].result + (self[i][j][k]*self[i-1][k].result);
                end

                self[i][j].result = NeuralNetwork.Transfer(self[i][j].result);

                if i == #self then
                    table.insert(outputs, self[i][j].result);
                end
            end
        end
    end

    return outputs;
end
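
--[[ In equation form, each non-input neuron j in layer i computes (standard
     notation, matching the loops above):

         result[i][j] = Transfer( bias[i][j] + sum over k of weight[i][j][k] * result[i-1][k] )

     where Transfer is the sigmoid defined at the top of the file and k ranges
     over the neurons of the previous layer. ]]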

function NeuralNetwork:BackwardPropagate(inputs, outputs)

    assert(#inputs == #self[1], "Number of inputs does not match the network's input layer");
    assert(#outputs == #self[#self], "Number of outputs does not match the network's output layer");

    self.numBackPropagates = self.numBackPropagates + 1; -- count completed training iterations

    self:ForwardPropagate(inputs);

    -- Walk backwards from the output layer, computing each neuron's delta
    -- (error term) using the sigmoid derivative result * (1 - result).
    for i = #self,2,-1 do

        for j = 1,#self[i] do

            if i == #self then

                -- Output layer: the error is the distance from the target output.
                self[i][j].delta = (outputs[j] - self[i][j].result) * self[i][j].result * (1 - self[i][j].result);

            else

                -- Hidden layers: the error is the weighted sum of the deltas
                -- of the neurons this one feeds into.
                local weightDelta = 0;

                for k = 1,#self[i+1] do
                    weightDelta = weightDelta + self[i+1][k][j]*self[i+1][k].delta;
                end

                self[i][j].delta = self[i][j].result * (1 - self[i][j].result) * weightDelta;
            end
        end
    end

    -- Apply the deltas: nudge every bias and weight in the direction that
    -- reduces the error, scaled by the learning rate.
    for i = 2,#self do

        for j = 1,#self[i] do

            self[i][j].bias = self[i][j].bias + self[i][j].delta * self.learningRate;

            for k = 1,#self[i][j] do
                self[i][j][k] = self[i][j][k] + self[i][j].delta * self.learningRate * self[i-1][k].result;
            end
        end
    end
end
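
--[[ Usage sketch (illustrative, not part of the original paste): train the
     network on XOR, then print its outputs. The topology and the iteration
     count are assumptions, not values from the original code.

    math.randomseed(os.time());

    local net = NeuralNetwork:Create(2, 1, 1, 4, 0.5);

    local samples = {
        { inputs = {0, 0}, targets = {0} },
        { inputs = {0, 1}, targets = {1} },
        { inputs = {1, 0}, targets = {1} },
        { inputs = {1, 1}, targets = {0} },
    };

    for iter = 1, 100000 do
        local s = samples[(iter % #samples) + 1];
        net:BackwardPropagate(s.inputs, s.targets);
    end

    for _, s in ipairs(samples) do
        print(s.inputs[1], s.inputs[2], "->", net:ForwardPropagate(s.inputs)[1]);
    end
]]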