Advertisement
SethBling

NeuralEvolve.lua

Mar 14th, 2015
1,083
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.54 KB | None | 0 0
  console.clear()

  -- Savestate that every trial starts from.
  filename = "DP1.state"

  -- Half-width, in 16px cells, of the square input grid around Mario
  -- (getInputs samples 2*boxRadius+1 cells per side).
  boxRadius = 6

  -- Controller-1 buttons driven by the network, one output neuron per button.
  buttonNames = {"A", "B", "X", "Y", "Up", "Down", "Left", "Right"}

  -- Neuron count of each computed layer, hidden layers first and the output
  -- layer last; the input width is taken from getInputs().
  layerSizes = {30, 10, 10, 10, #buttonNames}

  --- Reads the level tile byte at Mario's position offset by (dx, dy) pixels.
  -- Tile coordinates are pixel / 16. The map at 0xC800 is addressed in strips
  -- 16 tiles wide: floor(tileX / 16) selects a 0x1B0-byte strip, then
  -- row * 16 + (tileX % 16) selects the byte inside it.
  -- NOTE(review): no bounds check; offsets landing left of / above the level
  -- (negative tile coordinates) read unrelated bytes. getInputs filters only
  -- the bottom edge (y >= 0x1B0).
  -- @param dx horizontal pixel offset from Mario
  -- @param dy vertical pixel offset from Mario
  -- @return the tile byte at that position
  function getTile(dx, dy)
    -- Kept global on purpose: the main loop reads marioX/marioY after
    -- getInputs() (which calls this) to see the current frame's position.
    marioX = memory.read_s16_le(0x94)
    marioY = memory.read_s16_le(0x96)

    local tileX = math.floor((marioX + dx) / 16)
    local tileY = math.floor((marioY + dy) / 16)

    return memory.readbyte(0xC800 + math.floor(tileX / 0x10) * 0x1B0 + tileY * 0x10 + tileX % 0x10)
  end

  --- Collects the positions of all active normal sprites.
  -- Scans sprite slots 0..11; a slot is active when its byte at 0x14C8 + slot
  -- is non-zero. X = byte(0xE4 + slot) + 256 * byte(0x14E0 + slot),
  -- Y = byte(0xD8 + slot) + 256 * byte(0x14D4 + slot).
  -- @return array of {x = ..., y = ...} tables in slot order
  function getSprites()
    local sprites = {}
    for slot = 0, 11 do
      local status = memory.readbyte(0x14C8 + slot)
      if status ~= 0 then
        local spritex = memory.readbyte(0xE4 + slot) + memory.readbyte(0x14E0 + slot) * 256
        local spritey = memory.readbyte(0xD8 + slot) + memory.readbyte(0x14D4 + slot) * 256
        sprites[#sprites + 1] = {["x"] = spritex, ["y"] = spritey}
      end
    end

    return sprites
  end

  --- Collects the positions of all active extended sprites.
  -- The extended-sprite tables used here (0x170B number, 0x1715 Y low,
  -- 0x171F X low, 0x1729 Y high, 0x1733 X high) sit 10 bytes apart, so there
  -- are 10 slots (0..9). Scanning further would read the next table's bytes as
  -- a "number" and report phantom sprites. A slot is active when its number
  -- byte is non-zero.
  -- @return array of {x = ..., y = ...} tables in slot order
  function getExtendedSprites()
    local extended = {}
    for slot = 0, 9 do
      local number = memory.readbyte(0x170B + slot)
      if number ~= 0 then
        local spritex = memory.readbyte(0x171F + slot) + memory.readbyte(0x1733 + slot) * 256
        local spritey = memory.readbyte(0x1715 + slot) + memory.readbyte(0x1729 + slot) * 256
        extended[#extended + 1] = {["x"] = spritex, ["y"] = spritey}
      end
    end

    return extended
  end

  --- Returns true when some {x = ..., y = ...} entry of `list` is strictly
  -- closer than 8 px to (px, py) on both axes.
  local function isNearAny(list, px, py)
    for _, entry in ipairs(list) do
      if math.abs(entry["x"] - px) < 8 and math.abs(entry["y"] - py) < 8 then
        return true
      end
    end
    return false
  end

  --- Builds the network's input vector for the current frame.
  -- First come (2*boxRadius+1)^2 grid points spaced 16 px apart and centred on
  -- Mario, row by row (dy outer, dx inner). Each point is:
  --   -1 if a normal or extended sprite is within 8 px of it,
  --    1 otherwise if getTile() is not 0x25 and the point's y is below 0x1B0,
  --    0 otherwise.
  -- NOTE(review): 0x25 is presumably the empty tile and 0x1B0 the level
  -- height in pixels; confirm against an SMW RAM map.
  -- Then two values: the signed bytes at 0x7B and 0x7D (Mario's X/Y speed)
  -- divided by 70.
  -- @return array of numbers
  function getInputs()
    -- Kept global on purpose: the main loop reads marioX/marioY after this
    -- call (progress check, input overlay), including right after a reset.
    marioX = memory.read_s16_le(0x94)
    marioY = memory.read_s16_le(0x96)

    local sprites = getSprites()
    local extended = getExtendedSprites()

    local inputs = {}
    local span = boxRadius * 16

    for dy = -span, span, 16 do
      for dx = -span, span, 16 do
        local px, py = marioX + dx, marioY + dy
        local value = 0
        if getTile(dx, dy) ~= 0x25 and py < 0x1B0 then
          value = 1
        end
        -- Sprites take precedence over terrain.
        if isNearAny(sprites, px, py) or isNearAny(extended, px, py) then
          value = -1
        end
        inputs[#inputs + 1] = value
      end
    end

    -- NOTE(review): 70 presumably approximates the maximum speed so these
    -- land roughly in [-1, 1]; confirm.
    inputs[#inputs + 1] = memory.read_s8(0x7B) / 70
    inputs[#inputs + 1] = memory.read_s8(0x7D) / 70

    return inputs
  end


  --- Forward pass of the fully-connected network encoded in `chromosome`.
  -- Genes are consumed in order: for each entry of layerSizes, for each
  -- neuron, one weight per neuron of the previous layer, then one bias.
  -- Every neuron, output layer included, applies math.atan to its weighted
  -- sum plus bias.
  -- @param inputs activations feeding the first computed layer
  -- @param chromosome flat array of weights and biases
  -- @return activations of the last layer ({} when layerSizes is empty)
  function evaluate(inputs, chromosome)
    local activations = inputs
    local result = {}
    local gene = 1
    for _, width in ipairs(layerSizes) do
      result = {}
      for neuron = 1, width do
        local sum = 0
        for k = 1, #activations do
          sum = sum + chromosome[gene] * activations[k]
          gene = gene + 1
        end
        result[neuron] = math.atan(sum + chromosome[gene])
        gene = gene + 1
      end
      activations = result
    end
    return result
  end

  --- Builds a random chromosome sized for the current network layout.
  -- The input width is the length of a live getInputs() sample. For each
  -- entry of layerSizes, each neuron gets one weight per previous-layer neuron
  -- followed by a bias, in the order evaluate() reads them. Every gene is
  -- uniform in [-1, 1).
  -- @return flat array of genes
  function randomChromosome()
    local c = {}

    local prevSize = #getInputs()
    for i = 1, #layerSizes do
      for _ = 1, layerSizes[i] do
        -- prevSize incoming weights, then the bias
        for _ = 1, prevSize + 1 do
          c[#c + 1] = math.random() * 2 - 1
        end
      end
      prevSize = layerSizes[i]
    end

    return c
  end

  -- Initial population: 20 random individuals, all with fitness 0.
  pool = {}
  for idx = 1, 20 do
    local individual = {}
    individual["chromosome"] = randomChromosome()
    individual["fitness"] = 0
    pool[idx] = individual
  end

  -- Index into pool of the individual currently being played.
  currentChromosome = 1

  --- Starts a new trial: reloads the savestate named by `filename` and resets
  -- the per-run globals (furthest X reached, frame counter, timeout counter).
  function initializeRun()
    savestate.load(filename)
    rightmost, frame, timeout = 0, 0, 20
  end

  --- Produces a child by multi-point crossover of two parents.
  -- Genes are copied index by index from one parent at a time, starting with
  -- c1. Before every gene (the first included) there is a 1-in-k chance of
  -- switching parents, where k = floor(#genes / 4) clamped to at least 1.
  -- The clamp/floor matter: #genes / 4 is usually not an integer, which makes
  -- math.random error on Lua 5.3+, and is below 1 for tiny chromosomes.
  -- NOTE(review): the child's length follows c1; c2 is assumed to be at
  -- least as long.
  -- @param c1 first parent ({chromosome = genes, ...})
  -- @param c2 second parent
  -- @return new individual {chromosome = childGenes, fitness = 0}
  function crossover(c1, c2)
    local genes1, genes2 = c1["chromosome"], c2["chromosome"]
    local childGenes = {}
    local switchOdds = math.max(1, math.floor(#genes1 / 4))
    local fromFirst = true
    for i = 1, #genes1 do
      if math.random(switchOdds) == 1 then
        fromFirst = not fromFirst
      end
      if fromFirst then
        childGenes[i] = genes1[i]
      else
        childGenes[i] = genes2[i]
      end
    end

    return {["chromosome"] = childGenes, ["fitness"] = 0}
  end

  --- Mutates an individual in place: each gene independently has a 1-in-50
  -- chance of being replaced by a fresh uniform value in [-1, 1).
  -- @param c individual table whose "chromosome" field is the gene array
  function mutate(c)
    local genes = c["chromosome"]
    local count = #genes
    for idx = 1, count do
      local hit = math.random(50) == 1
      if hit then
        genes[idx] = math.random() * 2 - 1
      end
    end
  end

  --- Replaces the weaker half of the global pool with offspring of the
  -- stronger half, then increments the global generation counter.
  -- pool is sorted by descending fitness; slots 1..floor(#pool/2) are kept
  -- unchanged, and every later slot receives a mutate()d crossover() child of
  -- two parents drawn uniformly, with replacement, from the kept slots.
  function createNewGeneration()
    table.sort(pool, function(a, b)
      return a["fitness"] > b["fitness"]
    end)

    local half = math.floor(#pool / 2)
    -- Breed into half + 1 .. #pool only: this keeps every elite slot intact,
    -- guarantees parents are evaluated individuals, and matches the main loop,
    -- which resumes scoring at #pool / 2 + 1. With fewer than two individuals
    -- there is no parent pool, so nothing is bred.
    if half >= 1 then
      for i = half + 1, #pool do
        local parent1 = pool[math.random(half)]
        local parent2 = pool[math.random(half)]
        pool[i] = crossover(parent1, parent2)
        mutate(pool[i])
      end
    end

    generation = generation + 1
  end

  --- Releases every button listed in buttonNames on controller 1.
  function clearJoypad()
    local released = {}
    for _, name in ipairs(buttonNames) do
      released["P1 " .. name] = false
    end
    joypad.set(released)
  end

  --- "Show Top" button handler: releases all buttons, points the trial loop at
  -- pool[1] (the highest-fitness entry as of the last sort), and restarts the
  -- run from the savestate.
  function showTop()
    clearJoypad()
    currentChromosome = 1 -- replay from the head of the pool
    initializeRun()
  end

  -- Status window: best fitness so far, a "Show Top" button, a toggle for the
  -- on-screen input overlay, and the buttons currently pressed. Controls are
  -- stacked 22 px apart starting at y = 8.
  form = forms.newform(200, 142, "Fitness")
  local ROW = 22
  maxFitnessLabel = forms.label(form, "Top Fitness: ", 5, 8)
  goButton = forms.button(form, "Show Top", showTop, 5, 8 + ROW)
  showUI = forms.checkbox(form, "Show Inputs", 5, 8 + 2 * ROW)
  inputsLabel = forms.label(form, "Inputs", 5, 8 + 3 * ROW)

  --- Closes the status window when the script stops.
  onExit = function()
    forms.destroy(form)
  end
  event.onexit(onExit)

  --- Returns the sum of squared genes of an individual. The main loop
  -- subtracts one tenth of this from fitness as a weight-size penalty.
  -- @param chromosome individual table whose "chromosome" field is the gene array
  -- @return sum of gene^2 (0 for an empty chromosome)
  function connectionCost(chromosome)
    local total = 0
    for _, gene in ipairs(chromosome["chromosome"]) do
      total = total + gene * gene
    end

    return total
  end

  generation = 0
  maxfitness = 0
  initializeRun()

  -- Button states sent to joypad.set every frame; refreshed by the network
  -- every 5th frame. Initialised so joypad.set never receives nil.
  controller = {}

  --- Scores the finished run of pool[currentChromosome], reports it, moves to
  -- the next individual (breeding a new generation after the last one), and
  -- restarts from the savestate.
  local function finishRun()
    -- Reward progress; penalise elapsed frames and large weights.
    local fitness = rightmost - frame / 10 - connectionCost(pool[currentChromosome]) / 10
    pool[currentChromosome]["fitness"] = fitness

    if fitness > maxfitness then
      forms.settext(maxFitnessLabel, "Top Fitness: " .. math.floor(fitness))
      maxfitness = fitness
    end

    console.writeline("Generation " .. generation .. " chromosome " .. currentChromosome .. " fitness: " .. math.floor(fitness))
    if currentChromosome == #pool then
      createNewGeneration()
      -- Resume scoring just past the top half. math.floor keeps the index an
      -- integer ("/" yields a float on Lua 5.3+, which would log as "11.0").
      currentChromosome = math.floor(#pool / 2) + 1
    else
      currentChromosome = currentChromosome + 1
    end
    initializeRun()
  end

  --- Runs the current individual's network on `inputs`, stores the button
  -- states in the global controller (output > 0 means pressed), and shows
  -- the pressed buttons in the status window.
  local function updateController(inputs)
    local outputs = evaluate(inputs, pool[currentChromosome]["chromosome"])
    controller = {}
    local pressed = ""
    for n = 1, #buttonNames do
      local down = outputs[n] > 0
      controller["P1 " .. buttonNames[n]] = down
      if down then
        pressed = pressed .. buttonNames[n]
      end
    end
    forms.settext(inputsLabel, pressed)
  end

  --- Draws the grid part of `inputs` on screen, one number per 16px cell,
  -- positioned from Mario's coordinates minus layer1x/layer1y (read from
  -- 0x1A/0x1C; presumably the layer-1 camera scroll).
  local function drawInputGrid(inputs)
    local layer1x = memory.read_s16_le(0x1A)
    local layer1y = memory.read_s16_le(0x1C)
    local side = boxRadius * 2 + 1
    for dy = 0, boxRadius * 2 do
      for dx = 0, boxRadius * 2 do
        local input = inputs[dy * side + dx + 1]
        gui.drawText(marioX + (dx - boxRadius) * 16 - layer1x, marioY + (dy - boxRadius) * 16 - layer1y, string.format("%i", input), 0x80FFFFFF, 11)
      end
    end
  end

  while true do
    marioX = memory.read_s16_le(0x94)
    marioY = memory.read_s16_le(0x96)

    -- The run ends once timeout + frame/4 drops to 0; timeout is reset to 20
    -- whenever Mario reaches a new rightmost X.
    local timeoutBonus = frame / 4
    if timeout + timeoutBonus <= 0 then
      finishRun()
      -- The savestate was just reloaded: refresh Mario's position so the
      -- progress check below cannot credit the new run with the previous
      -- run's final X.
      marioX = memory.read_s16_le(0x94)
      marioY = memory.read_s16_le(0x96)
      timeoutBonus = frame / 4
    end

    local inputs = getInputs()
    if timeout + timeoutBonus > 2 and frame % 5 == 0 then
      updateController(inputs)
    end
    joypad.set(controller)

    -- In the last frames before timing out, release every button.
    if timeout + timeoutBonus <= 2 then
      clearJoypad()
    end

    if marioX > rightmost then
      timeout = 20
      rightmost = marioX
    end

    timeout = timeout - 1
    frame = frame + 1

    if forms.ischecked(showUI) then
      drawInputGrid(inputs)
    end

    emu.frameadvance()
  end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement