Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
-- Seed the PRNG from the wall clock so each run draws different initial
-- spline coefficients.
do
  local seed = os.time()
  math.randomseed(seed)
end
-- Uniform cubic B-spline basis functions of the local coordinate u
-- (u in [0, 1] inside a knot segment). Algebraically the four polynomials
-- sum to 1 for every u (partition of unity).
B0 = function(u) return (1 - u)^3 / 6 end
B1 = function(u) return (3*u^3 - 6*u^2 + 4) / 6 end
B2 = function(u) return (-3*u^3 + 3*u^2 + 3*u + 1) / 6 end
B3 = function(u) return u^3 / 6 end
--- Evaluate basis function number j (1-based: 1 -> B0, ..., 4 -> B3) at x.
-- @tparam number j basis index in 1..4
-- @tparam number x local coordinate
-- @treturn number
-- @raise "Invalid basis function index" when j is not 1, 2, 3 or 4
function splineBasis(j, x)
  if j ~= 1 and j ~= 2 and j ~= 3 and j ~= 4 then
    error("Invalid basis function index")
  end
  -- Look the globals up at call time, exactly like a direct call would.
  local basis = ({ B0, B1, B2, B3 })[j]
  return basis(x)
end
-- Coefficient c[k], or 0 when k falls outside 1..#c (missing neighbours at
-- either end of the coefficient list contribute nothing).
local function coeffAt(c, k)
  if k >= 1 and k <= #c then
    return c[k] or 0
  end
  return 0
end

--- Evaluate a cubic spline at point x.
-- Finds the knot segment [t[i], t[i+1]] for x (x below the first knot uses the
-- first segment, x above the last knot uses the last one, so the basis
-- polynomials extrapolate there), maps x to the segment's local coordinate,
-- and blends coefficients c[i-2] .. c[i+1] with the basis B0 .. B3.
-- @tparam table t knot vector with at least two entries (presumably non-decreasing)
-- @tparam table c coefficient list
-- @tparam number x evaluation point
-- @treturn number
-- @raise when t has fewer than two knots
function evaluateSpline(t, c, x)
  local n = #t - 1  -- number of knot segments
  if n < 1 then
    error("evaluateSpline: need at least two knots", 2)
  end
  local i = 1
  while i < n and x > t[i + 1] do
    i = i + 1
  end
  local span = t[i + 1] - t[i]
  -- A repeated knot gives a zero-width segment; dividing by it would yield
  -- NaN/inf, so anchor x at the segment start instead.
  local localX = 0
  if span ~= 0 then
    localX = (x - t[i]) / span
  end
  return B0(localX) * coeffAt(c, i - 2)
    + B1(localX) * coeffAt(c, i - 1)
    + B2(localX) * coeffAt(c, i)
    + B3(localX) * coeffAt(c, i + 1)
end
-- Node class definition.
--- One network unit: input connection i owns a cubic spline with four
-- coefficients coeffs[i][1..4]; the node output is the sum of those splines
-- evaluated at the corresponding inputs.
Node = {}
Node.__index = Node

--- Create a node with `numInputs` input connections.
-- Coefficients start uniformly in [-1, 1); ADAM state (m, v, t) starts at zero.
-- @tparam number numInputs number of input connections
-- @treturn Node
function Node:new(numInputs)
  local o = setmetatable({}, self)
  o.coeffs = {}      -- spline coefficients for each input connection
  o.lastInputs = {}  -- inputs from the most recent evaluate(), used for gradients
  o.m = {}           -- ADAM first-moment estimates
  o.v = {}           -- ADAM second-moment estimates
  o.beta1 = 0.9
  o.beta2 = 0.999
  o.epsilon = 1e-8
  o.t = 0            -- ADAM time-step counter
  for i = 1, numInputs do
    o.coeffs[i] = {}
    o.lastInputs[i] = 0
    o.m[i] = {}
    o.v[i] = {}
    for j = 1, 4 do
      o.coeffs[i][j] = math.random() * 2 - 1
      o.m[i][j] = 0
      o.v[i][j] = 0
    end
  end
  return o
end

--- Node output: sum over inputs of evaluateSpline(knots, coeffs[i], inputs[i]).
-- Records each input in lastInputs for the next coefficient update.
-- @tparam table inputs input values
-- @tparam table knots knot vector shared by all splines
-- @treturn number
function Node:evaluate(inputs, knots)
  local output = 0
  for i, input in ipairs(inputs) do
    self.lastInputs[i] = input
    output = output + evaluateSpline(knots, self.coeffs[i], input)
  end
  return output
end

--- The live coefficient table (coeffs[input][1..4]), not a copy.
function Node:getCoefficients()
  return self.coeffs
end

--- Plain gradient-descent step on this node's coefficients.
-- The data gradient of each coefficient is residual times the basis weight
-- that evaluateSpline gives it for the last stored input. The step therefore
-- uses the same knot segment and local coordinate as the forward pass.
-- Coefficients outside the active window receive only the L2 term.
-- @tparam number residual dLoss/dOutput supplied by the caller
-- @tparam table knots knot vector used in evaluate()
-- @tparam number learningRate step size
-- @tparam[opt=0] number lambda L2 regularization strength
function Node:updateCoefficients(residual, knots, learningRate, lambda)
  lambda = lambda or 0  -- a missing lambda used to raise nil arithmetic
  local numSegments = #knots - 1
  for i = 1, #self.lastInputs do
    local x = self.lastInputs[i]
    local coeffs = self.coeffs[i]
    -- Locate the knot segment the same way evaluateSpline does.
    local seg = 1
    while seg < numSegments and x > knots[seg + 1] do
      seg = seg + 1
    end
    local span = knots[seg + 1] - knots[seg]
    local localX = 0
    if span ~= 0 then
      localX = (x - knots[seg]) / span
    end
    for k = 1, #coeffs do
      -- evaluateSpline weights coeffs[seg - 3 + j] by basis j (1..4), so
      -- coefficient k is active only when j = k - seg + 3 lies in 1..4.
      local j = k - seg + 3
      local grad = lambda * coeffs[k]
      if j >= 1 and j <= 4 then
        grad = grad + residual * splineBasis(j, localX)
      end
      coeffs[k] = coeffs[k] - learningRate * grad
    end
  end
end
-- Layer class definition.
--- A layer is a list of nodes that all receive the same input vector.
Layer = {}
Layer.__index = Layer

--- Create a layer of `numNodes` nodes with `numInputsPerNode` inputs each.
-- @tparam number numNodes
-- @tparam number numInputsPerNode
-- @treturn Layer
function Layer:new(numNodes, numInputsPerNode)
  local o = setmetatable({}, self)
  o.nodes = {}
  for i = 1, numNodes do
    o.nodes[i] = Node:new(numInputsPerNode)
  end
  return o
end

--- The live list of nodes in this layer.
function Layer:getNodes()
  return self.nodes
end

--- Evaluate every node on the same inputs.
-- @treturn table outputs[i] is node i's output
function Layer:evaluate(inputs, knots)
  local outputs = {}
  for i, node in ipairs(self.nodes) do
    outputs[i] = node:evaluate(inputs, knots)
  end
  return outputs
end

--- Gradient-descent step on every node with the same residual.
-- @tparam number residual dLoss/dOutput supplied by the caller
-- @tparam table knots knot vector used in evaluate()
-- @tparam number learningRate step size
-- @tparam[opt=0] number lambda L2 regularization strength
function Layer:updateCoefficients(residual, knots, learningRate, lambda)
  -- Always forward a number: the node update multiplies coefficients by
  -- lambda, and dropping it (as before) raised nil arithmetic on every call.
  lambda = lambda or 0
  for _, node in ipairs(self.nodes) do
    node:updateCoefficients(residual, knots, learningRate, lambda)
  end
end
-- KANetwork class definition.
--- Feed-forward stack of layers: each layer's outputs feed the next layer.
KANetwork = {}
KANetwork.__index = KANetwork

--- Build a network whose idx-th layer has layerSizes[idx] nodes.
-- The first layer takes `numInputs` inputs; each later layer takes as many
-- inputs as the previous layer has nodes.
-- @tparam table layerSizes node count per layer
-- @tparam number numInputs width of the network input
-- @treturn KANetwork
function KANetwork:new(layerSizes, numInputs)
  local network = setmetatable({ layers = {} }, self)
  local fanIn = numInputs
  for idx, width in ipairs(layerSizes) do
    network.layers[idx] = Layer:new(width, fanIn)
    fanIn = width
  end
  return network
end

--- Forward pass: thread the input vector through every layer in order.
-- With no layers, the input table itself is returned.
-- @treturn table outputs of the final layer
function KANetwork:evaluate(inputs, knots)
  local signal = inputs
  for _, layer in ipairs(self.layers) do
    signal = layer:evaluate(signal, knots)
  end
  return signal
end
--- One ADAM step on this node's coefficients, using the inputs recorded by
-- the most recent evaluate() call.
-- The data gradient of each coefficient is residual times the basis weight
-- that evaluateSpline gives it for that input. The step therefore uses the
-- same knot segment and local coordinate as the forward pass. Coefficients
-- outside the active window receive only the L2 term.
-- @tparam number residual dLoss/dOutput supplied by the caller
-- @tparam table knots knot vector used in evaluate()
-- @tparam number learningRate ADAM step size
-- @tparam[opt=0] number lambda L2 regularization strength (previously ignored)
function Node:updateCoefficientsADAM(residual, knots, learningRate, lambda)
  lambda = lambda or 0
  self.t = self.t + 1
  local beta1, beta2 = self.beta1, self.beta2
  -- Bias corrections depend only on the step count: compute once per call.
  local correction1 = 1 - beta1 ^ self.t
  local correction2 = 1 - beta2 ^ self.t
  local numSegments = #knots - 1
  for i = 1, #self.lastInputs do
    local x = self.lastInputs[i]
    local coeffs, m, v = self.coeffs[i], self.m[i], self.v[i]
    -- Locate the knot segment the same way evaluateSpline does.
    local seg = 1
    while seg < numSegments and x > knots[seg + 1] do
      seg = seg + 1
    end
    local span = knots[seg + 1] - knots[seg]
    local localX = 0
    if span ~= 0 then
      localX = (x - knots[seg]) / span
    end
    for k = 1, #coeffs do
      -- evaluateSpline weights coeffs[seg - 3 + j] by basis j (1..4).
      local j = k - seg + 3
      local grad = lambda * coeffs[k]
      if j >= 1 and j <= 4 then
        grad = grad + residual * splineBasis(j, localX)
      end
      m[k] = beta1 * m[k] + (1 - beta1) * grad
      v[k] = beta2 * v[k] + (1 - beta2) * grad ^ 2
      local mHat = m[k] / correction1
      local vHat = v[k] / correction2
      coeffs[k] = coeffs[k] - learningRate * mHat / (math.sqrt(vHat) + self.epsilon)
    end
  end
end
--- Train the network with per-sample ADAM updates, printing progress every
-- 100 epochs.
-- The prediction for a sample is the LAST node of the final layer; the
-- residual (predicted - target) is handed to every node in every layer.
-- NOTE(review): the residual is not back-propagated through the layers (no
-- chain rule), so hidden-layer updates are a heuristic rather than a true
-- gradient.
-- @tparam table inputs list of input vectors
-- @tparam table outputs list of target vectors; only outputs[i][1] is used
-- @tparam table knots knot vector shared by every spline
-- @tparam number learningRate ADAM step size
-- @tparam number epochs number of passes over the data
-- @tparam[opt=0] number lambda L2 regularization strength
function KANetwork:train(inputs, outputs, knots, learningRate, epochs, lambda)
  lambda = lambda or 0
  -- "%g" prints integers and fractional values alike; "%d" raises on
  -- non-integral floats in Lua 5.3+.
  local function fmtNumber(value)
    return string.format("%g", value)
  end
  local function fmtVector(values)
    local parts = {}
    for k, value in ipairs(values) do
      parts[k] = fmtNumber(value)
    end
    return table.concat(parts, ", ")
  end

  for epoch = 1, epochs do
    local totalLoss = 0
    local epochOutputs = {}  -- predictions collected for printing
    for i, input in ipairs(inputs) do
      local predictedOutputs = self:evaluate(input, knots)
      local predicted = predictedOutputs[#predictedOutputs]
      local residual = predicted - outputs[i][1]
      -- Squared error plus the L2 penalty on the coefficients as they stand
      -- before this sample's update.
      totalLoss = totalLoss + 0.5 * residual^2
        + 0.5 * lambda * self:sumOfSquaredCoefficients()
      epochOutputs[#epochOutputs + 1] = predicted
      for _, layer in ipairs(self.layers) do
        for _, node in ipairs(layer:getNodes()) do
          node:updateCoefficientsADAM(residual, knots, learningRate, lambda)
        end
      end
    end
    if epoch % 100 == 0 then
      print(string.format("Epoch: %d, Total Loss: %.4f", epoch, totalLoss))
      print("Outputs at epoch " .. epoch .. ":")
      for i, output in ipairs(epochOutputs) do
        print(string.format("Input: {%s}, Predicted: %.4f, True: %s",
          fmtVector(inputs[i]), output, fmtNumber(outputs[i][1])))
      end
    end
  end
end
--- Sum of the squares of every spline coefficient in the network: the
-- quantity that train scales by 0.5 * lambda as the L2 penalty.
-- @treturn number
function KANetwork:sumOfSquaredCoefficients()
  local total = 0
  -- Add the squares of one node's coefficients, input by input.
  local function accumulate(node)
    for _, perInput in ipairs(node:getCoefficients()) do
      for _, value in ipairs(perInput) do
        total = total + value^2
      end
    end
  end
  for _, layer in ipairs(self.layers) do
    for _, node in ipairs(layer:getNodes()) do
      accumulate(node)
    end
  end
  return total
end
--- Accessor for the network's layer list (the live table, not a copy).
-- @treturn table
KANetwork.getLayers = function(self)
  return self.layers
end
--- Print every node's coefficients, one line per node.
-- Format: "L<layer>N<node> (c, c, ...)(c, ...)" with one parenthesized group
-- per input connection and each value shown with two decimals.
-- @param network any object with getLayers(); layers need getNodes() and
--   nodes need getCoefficients()
function printCoefficients(network)
  for i, layer in ipairs(network:getLayers()) do
    for j, node in ipairs(layer:getNodes()) do
      -- Build the line with table.concat instead of repeated `..`, which is
      -- quadratic in the number of pieces.
      local groups = {}
      for g, c in ipairs(node:getCoefficients()) do
        local values = {}
        for k, v in ipairs(c) do
          values[k] = string.format("%.2f", v)
        end
        groups[g] = "(" .. table.concat(values, ", ") .. ")"
      end
      print(string.format("L%dN%d %s", i, j, table.concat(groups)))
    end
  end
end
-- Example XOR problem configuration.
local inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}}
local outputs = {{0}, {1}, {1}, {0}}
-- Knot vector shared by every spline.
-- NOTE(review): an input of 0 lies below the first knot (0.2), so those points
-- are evaluated by extrapolating the first segment's basis polynomials.
local knots = {0.2, 0.5, 1.5, 2}
-- 2 inputs -> 3 hidden nodes -> 1 output node. KANetwork:train reads only the
-- last node of the final layer, so a single output node avoids training an
-- unused extra output on a residual computed from a different node.
local myNetwork = KANetwork:new({3, 1}, 2)
-- Parameters for training.
local learningRate = 0.0015
local epochs = 10000
local trainLambda = 0.001
-- Train the network, then dump the learned coefficients.
myNetwork:train(inputs, outputs, knots, learningRate, epochs, trainLambda)
printCoefficients(myNetwork)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement