Advertisement
Guest User

Lua KAN experiment with Adam

a guest
May 1st, 2024
123
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 7.99 KB | None | 0 0
math.randomseed(os.time())  -- Seed the RNG so each run draws different initial spline coefficients.
  2.  
  3. -- Definition of Cubic B-Spline basis functions.
  4. function B0(t) return (1 - t)^3 / 6 end
  5. function B1(t) return (3*t^3 - 6*t^2 + 4) / 6 end
  6. function B2(t) return (-3*t^3 + 3*t^2 + 3*t + 1) / 6 end
  7. function B3(t) return t^3 / 6 end
  8.  
  9. -- Compute the value of the j-th cubic B-spline basis function at a given x.
  10. function splineBasis(j, x)
  11.     if j == 1 then
  12.         return B0(x)
  13.     elseif j == 2 then
  14.         return B1(x)
  15.     elseif j == 3 then
  16.         return B2(x)
  17.     elseif j == 4 then
  18.         return B3(x)
  19.     else
  20.         error("Invalid basis function index")
  21.     end
  22. end
  23.  
  24. -- Evaluate the cubic spline at point x using specified knots and coefficients.
  25. function evaluateSpline(t, c, x)
  26.     local n = #t - 1
  27.     local i = 1
  28.     while i < n and x > t[i+1] do
  29.         i = i + 1
  30.     end
  31.     local localX = (x - t[i]) / (t[i+1] - t[i])
  32.     return
  33.         B0(localX) * (i-2 >= 1 and i-2 <= #c and c[i-2] or 0) +
  34.         B1(localX) * (i-1 >= 1 and i-1 <= #c and c[i-1] or 0) +
  35.         B2(localX) * (i   >= 1 and i   <= #c and c[i]   or 0) +
  36.         B3(localX) * (i+1 >= 1 and i+1 <= #c and c[i+1] or 0)
  37. end
  38.  
  39. -- Node class definition.
  40. Node = {}
  41. Node.__index = Node
  42. -- Initialization of ADAM parameters inside the Node class constructor.
  43. function Node:new(numInputs)
  44.     local o = setmetatable({}, self)
  45.     o.coeffs = {}        -- Store spline coefficients for each input connection
  46.     o.lastInputs = {}
  47.     o.m = {}             -- First moment vector (mean)
  48.     o.v = {}             -- Second moment vector (uncentered variance)
  49.     o.beta1 = 0.9
  50.     o.beta2 = 0.999
  51.     o.epsilon = 1e-8
  52.     o.t = 0              -- Time step counter
  53.     for i = 1, numInputs do
  54.         o.coeffs[i] = {}
  55.         o.lastInputs[i] = 0  -- Initialize last input storage
  56.         o.m[i] = {}
  57.         o.v[i] = {}
  58.         for j = 1, 4 do
  59.             o.coeffs[i][j] = math.random() * 2 - 1
  60.             o.m[i][j] = 0
  61.             o.v[i][j] = 0
  62.         end
  63.     end
  64.     return o
  65. end
  66.  
  67. function Node:evaluate(inputs, knots)
  68.     local output = 0
  69.     for i, input in ipairs(inputs) do
  70.         self.lastInputs[i] = input  -- Store last inputs for gradient calculation.
  71.         output = output + evaluateSpline(knots, self.coeffs[i], input)
  72.     end
  73.     return output
  74. end
  75.  
-- Accessor: returns the live coefficient table (not a copy); callers mutate it in place.
function Node:getCoefficients()
    return self.coeffs
end
  79.  
  80. function Node:updateCoefficients(error, knots, learningRate, lambda)
  81.     for i = 1, #self.lastInputs do
  82.         local localX = (self.lastInputs[i] - knots[1]) / (knots[#knots] - knots[1])
  83.         for j = 1, 4 do
  84.             local grad = error * splineBasis(j, localX)
  85.             -- Include the L2 regularization term in the gradient calculation.
  86.             grad = grad + lambda * self.coeffs[i][j]
  87.             self.coeffs[i][j] = self.coeffs[i][j] - learningRate * grad
  88.         end
  89.     end
  90. end
  91.  
  92. -- Layer class definition.
  93. Layer = {}
  94. Layer.__index = Layer
  95. function Layer:new(numNodes, numInputsPerNode)
  96.     local o = setmetatable({}, self)
  97.     o.nodes = {}
  98.     for i = 1, numNodes do
  99.         o.nodes[i] = Node:new(numInputsPerNode)
  100.     end
  101.     return o
  102. end
  103.  
-- Accessor: returns the live node list (not a copy).
function Layer:getNodes()
    return self.nodes
end
  107.  
  108. function Layer:evaluate(inputs, knots)
  109.     local outputs = {}
  110.     for i, node in ipairs(self.nodes) do
  111.         outputs[i] = node:evaluate(inputs, knots)
  112.     end
  113.     return outputs
  114. end
  115.  
  116. function Layer:updateCoefficients(error, knots, learningRate)
  117.     for _, node in ipairs(self.nodes) do
  118.         node:updateCoefficients(error, knots, learningRate)
  119.     end
  120. end
  121.  
  122. -- KANetwork class definition.
  123. KANetwork = {}
  124. KANetwork.__index = KANetwork
  125. function KANetwork:new(layerSizes, numInputs)
  126.     local o = setmetatable({}, self)
  127.     o.layers = {}
  128.     local inputs = numInputs
  129.     for i, size in ipairs(layerSizes) do
  130.         o.layers[i] = Layer:new(size, inputs)
  131.         inputs = size
  132.     end
  133.     return o
  134. end
  135.  
  136. function KANetwork:evaluate(inputs, knots)
  137.     local outputs = inputs
  138.     for _, layer in ipairs(self.layers) do
  139.         outputs = layer:evaluate(outputs, knots)
  140.     end
  141.     return outputs
  142. end
  143.  
  144.  
  145. -- ADAM coefficients update function
  146. function Node:updateCoefficientsADAM(error, knots, learningRate)
  147.     self.t = self.t + 1  -- Increment time step
  148.     for i = 1, #self.lastInputs do
  149.         local localX = (self.lastInputs[i] - knots[1]) / (knots[#knots] - knots[1])
  150.         for j = 1, 4 do
  151.             local grad = error * splineBasis(j, localX)
  152.             -- ADAM update rules
  153.             self.m[i][j] = self.beta1 * self.m[i][j] + (1 - self.beta1) * grad
  154.             self.v[i][j] = self.beta2 * self.v[i][j] + (1 - self.beta2) * grad^2
  155.             local m_hat = self.m[i][j] / (1 - self.beta1^self.t)
  156.             local v_hat = self.v[i][j] / (1 - self.beta2^self.t)
  157.             self.coeffs[i][j] = self.coeffs[i][j] - learningRate * m_hat / (math.sqrt(v_hat) + self.epsilon)
  158.         end
  159.     end
  160. end
  161.  
  162.  
  163. function KANetwork:train(inputs, outputs, knots, learningRate, epochs, lambda)
  164.     for epoch = 1, epochs do
  165.         local totalLoss = 0
  166.         local epochOutputs = {}  -- Store outputs for printing.
  167.         for i, input in ipairs(inputs) do
  168.             local predictedOutputs = self:evaluate(input, knots)
  169.             local predicted = predictedOutputs[#predictedOutputs]  -- Assume single output node at last layer.
  170.             local error = predicted - outputs[i][1]
  171.             totalLoss = totalLoss + 0.5 * error^2 + 0.5 * lambda * self:sumOfSquaredCoefficients()  -- Include L2 loss in total loss.
  172.             epochOutputs[#epochOutputs + 1] = predicted  -- Collect outputs for this epoch.
  173.  
  174.             for _, layer in ipairs(self.layers) do
  175.                 for _, node in ipairs(layer:getNodes()) do
  176.                     node:updateCoefficientsADAM(error, knots, learningRate, lambda)
  177.                 end
  178.             end
  179.         end
  180.  
  181.         -- Print total loss and optionally the outputs every 100 epochs.
  182.         if epoch % 100 == 0 then
  183.             print(string.format("Epoch: %d, Total Loss: %.4f", epoch, totalLoss))
  184.             print("Outputs at epoch " .. epoch .. ":")
  185.             for i, output in ipairs(epochOutputs) do
  186.                 print(string.format("Input: {%d, %d}, Predicted: %.4f, True: %d", inputs[i][1], inputs[i][2], output, outputs[i][1]))
  187.             end
  188.         end
  189.     end
  190. end
  191.  
  192. function KANetwork:sumOfSquaredCoefficients()
  193.     local sum = 0
  194.     for _, layer in ipairs(self.layers) do
  195.         for _, node in ipairs(layer:getNodes()) do
  196.             for _, coeffs in ipairs(node:getCoefficients()) do
  197.                 for _, coeff in ipairs(coeffs) do
  198.                     sum = sum + coeff^2
  199.                 end
  200.             end
  201.         end
  202.     end
  203.     return sum
  204. end
  205.  
-- Accessor: returns the live layer list (not a copy).
function KANetwork:getLayers()
    return self.layers
end
  209.  
  210. function printCoefficients(network)
  211.     local layers = network:getLayers()
  212.     for i, layer in ipairs(layers) do
  213.         local nodes = layer:getNodes()
  214.         for j, node in ipairs(nodes) do
  215.             local coeffs = node:getCoefficients()
  216.             local coeffsString = ""
  217.             for _, c in ipairs(coeffs) do
  218.                 coeffsString = coeffsString .. "("
  219.                 for k, v in ipairs(c) do
  220.                     coeffsString = coeffsString .. string.format("%s%.2f", k > 1 and ", " or "", v)
  221.                 end
  222.                 coeffsString = coeffsString .. ")"
  223.             end
  224.             print(string.format("L%dN%d %s", i, j, coeffsString))
  225.         end
  226.     end
  227. end
  228.  
-- Example: learn XOR with a 2-layer network (3 hidden nodes, 2 output nodes;
-- training reads only the LAST output as the prediction).
local inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}}
local outputs = {{0}, {1}, {1}, {0}}
-- Knot vector for the splines. NOTE(review): input value 0 lies below
-- knots[1] = 0.2, so those samples are evaluated by extrapolation — confirm
-- this is intended before changing the knots.
local knots = {0.2, 0.5, 1.5, 2}
local myNetwork = KANetwork:new({3, 2}, 2)  -- A simple 2-layer network suitable for XOR.

-- Training hyperparameters.
local learningRate = 0.0015
local epochs = 10000
local trainLambda = 0.001  -- L2 regularization weight

-- Train the network, then show the learned coefficients.
myNetwork:train(inputs, outputs, knots, learningRate, epochs, trainLambda)
printCoefficients(myNetwork)
  243.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement