CC Neural Net

-- Obtain command-line arguments.
local args = { ... }
local training_mode = (args[1] == "training")
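
-- Usage (the program file name is whatever this paste is saved as; "shapenet"
-- below is just a placeholder):
--   shapenet training   -> training mode: after each prediction the user is
--                          asked for the correct label and the network learns
--   shapenet            -> inference-only mode: no weight updates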

-- Configuration
local WIDTH, HEIGHT = 100, 70 -- Drawing grid dimensions (cells)
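-- Note: a 100x70 cell grid is larger than the default 51x19 computer terminal
-- in ComputerCraft; reduce WIDTH/HEIGHT (and the layout offsets below) if the
-- drawing area must fit on a standard screen.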
local NEURON_WIDTH, NEURON_HEIGHT = 20, 20 -- Normalized grid dimensions
local DISPLAY_SCALE = 1 -- Not used for scaling now
local LEARNING_RATE = 0.1

-- Grid drawing offsets (for border)
local GRID_OFFSET_X = 2
local GRID_OFFSET_Y = 2

-- Neural network configuration
local INPUT_SIZE = NEURON_WIDTH * NEURON_HEIGHT -- 400 inputs
local HIDDEN_SIZE = 50 -- Hidden layer size
local OUTPUT_SIZE = 3 -- Three classes: square, circle, triangle

-- File for saving network parameters
local NETWORK_FILE = "network_weights.txt"

---------------------------------
-- Neural Network Functions
---------------------------------

local network = {}

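-- Weights start uniform in [-1, 1). With 400 inputs feeding each hidden
-- neuron, the initial pre-activations can sit deep in sigmoid saturation;
-- scaling initial weights by something like 1/math.sqrt(INPUT_SIZE) is a
-- common remedy if training stalls.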
local function initialize_network()
  network = { weights1 = {}, bias1 = {},
              weights2 = {}, bias2 = {} }
  for i = 1, HIDDEN_SIZE do
    network.weights1[i] = {}
    for j = 1, INPUT_SIZE do
      network.weights1[i][j] = math.random() * 2 - 1
    end
    network.bias1[i] = math.random() * 2 - 1
  end
  for k = 1, OUTPUT_SIZE do
    network.weights2[k] = {}
    for i = 1, HIDDEN_SIZE do
      network.weights2[k][i] = math.random() * 2 - 1
    end
    network.bias2[k] = math.random() * 2 - 1
  end
end

local function save_network()
  local file = fs.open(NETWORK_FILE, "w")
  file.write(textutils.serialize(network))
  file.close()
end

local function load_network()
  local file = fs.open(NETWORK_FILE, "r")
  if file then
    network = textutils.unserialize(file.readAll())
    file.close()
  else
    initialize_network()
  end
end

-- Activation functions
local function sigmoid(x)
  return 1 / (1 + math.exp(-x))
end

local function dsigmoid(y)
  return y * (1 - y) -- derivative assuming y = sigmoid(x)
end

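-- Softmax with the usual max-subtraction trick: subtracting the largest raw
-- output before exponentiating prevents math.exp from overflowing and does
-- not change the resulting probabilities.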
local function softmax(vec)
  local max_val = -math.huge
  for i, v in ipairs(vec) do
    if v > max_val then max_val = v end
  end
  local sum_exp = 0
  local exp_vec = {}
  for i, v in ipairs(vec) do
    exp_vec[i] = math.exp(v - max_val)
    sum_exp = sum_exp + exp_vec[i]
  end
  local out = {}
  for i, v in ipairs(exp_vec) do
    out[i] = v / sum_exp
  end
  return out
end

-- Forward pass: returns hidden activations and output probabilities
local function forward(input)
  local hidden = {}
  for i = 1, HIDDEN_SIZE do
    local sum = network.bias1[i]
    for j = 1, INPUT_SIZE do
      sum = sum + network.weights1[i][j] * input[j]
    end
    hidden[i] = sigmoid(sum)
  end
  local output_raw = {}
  for k = 1, OUTPUT_SIZE do
    local sum = network.bias2[k]
    for i = 1, HIDDEN_SIZE do
      sum = sum + network.weights2[k][i] * hidden[i]
    end
    output_raw[k] = sum
  end
  local output = softmax(output_raw)
  return hidden, output
end

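-- For a softmax output trained with cross-entropy loss, the gradient with
-- respect to the pre-softmax sums reduces to (output - target), so the output
-- delta below needs no separate derivative factor. Note that every call to
-- train_network() ends by writing the weights to disk via save_network().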
-- Train the network using one training sample (input vector and one-hot target)
local function train_network(input, target, learning_rate)
  learning_rate = learning_rate or LEARNING_RATE
  local hidden, output = forward(input)

  -- Compute output error (delta) for softmax with cross-entropy:
  local delta_output = {}
  for k = 1, OUTPUT_SIZE do
    delta_output[k] = output[k] - target[k]
  end

  -- Backpropagate to the hidden layer first, so the hidden deltas are computed
  -- with the pre-update output weights.
  local delta_hidden = {}
  for i = 1, HIDDEN_SIZE do
    local sum = 0
    for k = 1, OUTPUT_SIZE do
      sum = sum + network.weights2[k][i] * delta_output[k]
    end
    delta_hidden[i] = sum * dsigmoid(hidden[i])
  end

  -- Update weights2 and bias2
  for k = 1, OUTPUT_SIZE do
    for i = 1, HIDDEN_SIZE do
      network.weights2[k][i] = network.weights2[k][i] - learning_rate * delta_output[k] * hidden[i]
    end
    network.bias2[k] = network.bias2[k] - learning_rate * delta_output[k]
  end

  -- Update weights1 and bias1
  for i = 1, HIDDEN_SIZE do
    for j = 1, INPUT_SIZE do
      network.weights1[i][j] = network.weights1[i][j] - learning_rate * delta_hidden[i] * input[j]
    end
    network.bias1[i] = network.bias1[i] - learning_rate * delta_hidden[i]
  end
  save_network()
end

-- Predict function: returns predicted class and probabilities
local function predict(input)
  local _, output = forward(input)
  local max_val = -math.huge
  local max_index = 1
  for k, v in ipairs(output) do
    if v > max_val then
      max_val = v
      max_index = k
    end
  end
  local class = (max_index == 1 and "square") or (max_index == 2 and "circle") or (max_index == 3 and "triangle")
  return class, output
end

-- Convert user shape string to one-hot target vector.
local function target_vector(shape)
  if shape == "square" then
    return {1, 0, 0}
  elseif shape == "circle" then
    return {0, 1, 0}
  elseif shape == "triangle" then
    return {0, 0, 1}
  else
    return {0, 1, 0} -- default to circle if unrecognized
  end
end

---------------------------------
-- Drawing Grid Functions
---------------------------------

local function clear_grid(w, h)
  local grid = {}
  for y = 1, h do
    grid[y] = {}
    for x = 1, w do
      grid[y][x] = 0
    end
  end
  return grid
end

-- Global drawing grid (100x70)
local grid = clear_grid(WIDTH, HEIGHT)
local prev_grid = clear_grid(WIDTH, HEIGHT)

-- Draw a border around the grid.
local function draw_border(offset_x, offset_y, width, height)
  term.setBackgroundColor(colors.gray)
  -- Top border
  term.setCursorPos(offset_x - 1, offset_y - 1)
  term.write(string.rep(" ", width + 2))
  -- Bottom border
  term.setCursorPos(offset_x - 1, offset_y + height)
  term.write(string.rep(" ", width + 2))
  -- Left and right borders
  for row = offset_y, offset_y + height - 1 do
    term.setCursorPos(offset_x - 1, row)
    term.write(" ")
    term.setCursorPos(offset_x + width, row)
    term.write(" ")
  end
end

-- Draw a single cell for input grid.
local function draw_cell_input(x, y)
  local screen_x = GRID_OFFSET_X + x - 1
  local screen_y = GRID_OFFSET_Y + y - 1
  term.setCursorPos(screen_x, screen_y)
  if grid[y][x] == 1 then
    term.setBackgroundColor(colors.lime)
  else
    term.setBackgroundColor(colors.black)
  end
  term.write(" ")
end

-- Fully render the input drawing grid with border.
local function draw_full_grid_input()
  term.setBackgroundColor(colors.black)
  term.clear()
  for y = 1, HEIGHT do
    for x = 1, WIDTH do
      draw_cell_input(x, y)
    end
  end
  draw_border(GRID_OFFSET_X, GRID_OFFSET_Y, WIDTH, HEIGHT)
  term.setCursorPos(1, GRID_OFFSET_Y + HEIGHT + 2)
  term.setBackgroundColor(colors.black)
  term.write("Draw with mouse. Press ENTER when done:")
end

-- Capture drawing input from the user.
local function draw_input()
  draw_full_grid_input()
  while true do
    local event, button, x, y = os.pullEvent()
    if (event == "mouse_click" or event == "mouse_drag") and x and y then
      -- Adjust for grid offset
      local gridX = x - GRID_OFFSET_X + 1
      local gridY = y - GRID_OFFSET_Y + 1
      if gridX >= 1 and gridX <= WIDTH and gridY >= 1 and gridY <= HEIGHT then
        if grid[gridY][gridX] ~= 1 then
          grid[gridY][gridX] = 1
          draw_cell_input(gridX, gridY)
        end
      end
    elseif event == "key" and button == keys.enter then
      break
    end
  end
end

-- Detect the bounding box of the drawn shape.
local function get_bounding_box(input_grid)
  local minX, maxX = WIDTH + 1, 0
  local minY, maxY = HEIGHT + 1, 0
  for y = 1, HEIGHT do
    for x = 1, WIDTH do
      if input_grid[y][x] == 1 then
        if x < minX then minX = x end
        if x > maxX then maxX = x end
        if y < minY then minY = y end
        if y > maxY then maxY = y end
      end
    end
  end
  if maxX < minX then
    return 1, WIDTH, 1, HEIGHT -- No drawing detected; default to full grid.
  end
  return minX, maxX, minY, maxY
end

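-- Normalization uses nearest-neighbour sampling: each cell of the 20x20 grid
-- is mapped to the closest cell inside the drawing's bounding box, so the
-- network sees the same input regardless of where and how large the shape was
-- drawn.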
-- Normalize the drawn shape (within its bounding box) to a 20x20 grid.
local function normalize_shape_to_20x20(input_grid)
  local minX, maxX, minY, maxY = get_bounding_box(input_grid)
  local norm = clear_grid(NEURON_WIDTH, NEURON_HEIGHT)
  local box_width = maxX - minX + 1
  local box_height = maxY - minY + 1
  for ny = 1, NEURON_HEIGHT do
    for nx = 1, NEURON_WIDTH do
      local srcX = minX + math.floor(((nx - 1) / (NEURON_WIDTH - 1)) * (box_width - 1) + 0.5)
      local srcY = minY + math.floor(((ny - 1) / (NEURON_HEIGHT - 1)) * (box_height - 1) + 0.5)
      norm[ny][nx] = input_grid[srcY][srcX]
    end
  end
  return norm
end

-- Stretch a 20x20 grid back to the original 100x70 drawing area.
local function stretch_norm_to_drawing(norm, targetWidth, targetHeight)
  local stretched = clear_grid(targetWidth, targetHeight)
  for y = 1, targetHeight do
    for x = 1, targetWidth do
      local srcX = math.floor((x - 1) / (targetWidth - 1) * (NEURON_WIDTH - 1)) + 1
      local srcY = math.floor((y - 1) / (targetHeight - 1) * (NEURON_HEIGHT - 1)) + 1
      stretched[y][x] = norm[srcY][srcX]
    end
  end
  return stretched
end

-- Draw a grid (of size targetWidth x targetHeight) on screen with border.
local function draw_stretched_grid_on_screen(aGrid, targetWidth, targetHeight)
  term.setBackgroundColor(colors.black)
  term.clear()
  for y = 1, targetHeight do
    for x = 1, targetWidth do
      local screen_x = GRID_OFFSET_X + x - 1
      local screen_y = GRID_OFFSET_Y + y - 1
      term.setCursorPos(screen_x, screen_y)
      if aGrid[y][x] == 1 then
        term.setBackgroundColor(colors.lime)
      else
        term.setBackgroundColor(colors.black)
      end
      term.write(" ")
    end
  end
  draw_border(GRID_OFFSET_X, GRID_OFFSET_Y, targetWidth, targetHeight)
end

-- Flatten a 2D grid into a 1D vector.
local function flatten_grid(two_d_grid)
  local vector = {}
  for y = 1, #two_d_grid do
    for x = 1, #two_d_grid[y] do
      table.insert(vector, two_d_grid[y][x])
    end
  end
  return vector
end

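-- Per-iteration flow: capture a freehand drawing, normalize it to the 20x20
-- input grid, flatten it to a 400-element vector, run the forward pass, show
-- the prediction, and (in training mode) backpropagate against the label the
-- user types in.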
---------------------------------
-- Main Loop
---------------------------------

-- Initialize or load network parameters.
math.randomseed(os.time())
load_network()

while true do
  -- Reset the drawing grid.
  grid = clear_grid(WIDTH, HEIGHT)
  prev_grid = clear_grid(WIDTH, HEIGHT)
  term.setBackgroundColor(colors.black)
  term.clear()

  -- Capture user drawing input.
  draw_input()

  -- Normalize the drawn shape to a fixed 20x20 grid.
  local normalized_grid = normalize_shape_to_20x20(grid)

  -- Compute the network prediction using the flattened 20x20 input.
  local input_vector = flatten_grid(normalized_grid)
  local predicted_class, output_probs = predict(input_vector)

  -- Stretch the normalized grid back to the original drawing area (100x70).
  local stretched_grid = stretch_norm_to_drawing(normalized_grid, WIDTH, HEIGHT)
  draw_stretched_grid_on_screen(stretched_grid, WIDTH, HEIGHT)

  -- Display the network's prediction and probabilities.
  term.setCursorPos(1, GRID_OFFSET_Y + HEIGHT + 2)
  term.setBackgroundColor(colors.black)
  term.write(("Prediction: %s (Square: %.1f%%, Circle: %.1f%%, Triangle: %.1f%%)"):format(
    predicted_class,
    output_probs[1] * 100,
    output_probs[2] * 100,
    output_probs[3] * 100
  ))

  if training_mode then
    -- In training mode, prompt for correct shape label and train.
    term.setCursorPos(1, GRID_OFFSET_Y + HEIGHT + 3)
    term.write("Enter shape name (square, circle, triangle): ")
    sleep(0.2) -- brief pause to flush events
    local shape = read()
    local target = target_vector(shape)
    train_network(input_vector, target, LEARNING_RATE)
    term.setCursorPos(1, GRID_OFFSET_Y + HEIGHT + 4)
    term.write("Training complete. Press ENTER to continue...")
    repeat
      local event, key = os.pullEvent("key")
    until key == keys.enter
  else
    -- In inference-only mode, do not train; wait for user to continue.
    term.setCursorPos(1, GRID_OFFSET_Y + HEIGHT + 3)
    term.write("Inference-only mode. Press ENTER to draw again...")
    repeat
      local event, key = os.pullEvent("key")
    until key == keys.enter
  end
end