Advertisement
Code-Akisame

MarI/0 for Einfach Nerdig with a fix for loops

Mar 11th, 2018
1,391
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 40.05 KB | None | 0 0
  1. -- MarI/O by SethBling with some tweaks by Akisame
  2. -- I would like to stress that we tried to keep the original code intact as much as possible.
  3. -- The tweaks by Akisame and Anonymous that fix the pit bug are on lines 49 and 1237-1252.
  4. -- Further tweaks by Akisame can be found on lines 48-56, 68, 86-92, 231, 831, 860, 1076, 1191, 1194-1197, 1214-1263, 1271-1276, 1290-1292, 1298-1304, 1321 and 1324.
  5. -- The loading bug fix by Anonymous can be found on lines 1108-1120.
  6. -- These tweaks change the font size of the banner and add a counter that tracks how many genomes reached maxFitness - 30.
  7. -- They also fix the pipe fitness bug (fitness freezing when entering a pipe) and the fitness drop bug (a hard fitness drop after exiting a pipe), and they freeze the fitness when Mario is detected walking into a loop.
  8.  
  9. -- Feel free to use this code, but please link back to SethBling and to this code, and please do not redistribute it.
  10. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  11. -- For SMW, make sure you have a save state named "DP1.state" at the beginning of a level,
  12. -- and put a copy in both the Lua folder and the root directory of BizHawk.
  -- Select the savestate file and the controller buttons for the loaded ROM.
  -- Only "Super Mario World (USA)" and "Super Mario Bros." are supported;
  -- the rest of the script needs Filename and ButtonNames to be set.
  local romName = gameinfo.getromname()
  if romName == "Super Mario World (USA)" then
    Filename = "DP1.state"
    ButtonNames = {
      "A",
      "B",
      "X",
      "Y",
      "Up",
      "Down",
      "Left",
      "Right",
    }
  elseif romName == "Super Mario Bros." then
    Filename = "SMB1-1.state"
    ButtonNames = {
      "A",
      "B",
      "Up",
      "Down",
      "Left",
      "Right",
    }
  else
    -- Fail fast with a readable message. Otherwise the script dies later on
    -- "#ButtonNames" with an opaque nil-length error.
    error("MarI/O: unsupported ROM '" .. tostring(romName) .. "' (expected \"Super Mario World (USA)\" or \"Super Mario Bros.\")")
  end
  36.  
  -- Tile grid fed to the network: a square of side 2*BoxRadius+1 tiles
  -- centred on Mario.
  BoxRadius = 6
  local gridSide = BoxRadius*2+1
  InputSize = gridSide*gridSide

  -- One input per grid cell plus one bias input; one output per button.
  Inputs = InputSize+1
  Outputs = #ButtonNames

  -- Population size and speciation (compatibility distance) parameters.
  Population = 300
  DeltaDisjoint = 2.0
  DeltaWeights = 0.4
  DeltaThreshold = 1.0
  StaleSpecies = 25 -- increased slightly to allow for longer evolution

  -- Run-state flags and counters added by the tweaks.
  mariohole = false -- flag for mario falling into a hole
  mariopipe = false -- flag for mario exiting a pipe
  marioPipeEnter = false -- new flag for entering a pipe
  offset = 0 -- offset value for when mario exits a pipe
  timebonus = 0 -- timebonus to prevent the hard drops in fitness when exiting a pipe
  maxcounter = 0 -- species that reach the max fitness
  currentTracker = 0 -- tracks location on screen
  loop = false -- flag for when a loop is detected

  -- Default mutation chances/rates copied into every new genome.
  MutateConnectionsChance = 0.25
  PerturbChance = 0.90
  CrossoverChance = 0.75
  LinkMutationChance = 2.0
  NodeMutationChance = 0.50
  BiasMutationChance = 0.40
  StepSize = 0.1
  DisableMutationChance = 0.4
  EnableMutationChance = 0.2

  TimeoutConstant = 90 -- set to a higher value but see for yourself what is acceptable

  -- Output neurons are numbered MaxNodes+1 .. MaxNodes+Outputs.
  MaxNodes = 1000000
  71.  
  -- Read Mario's position from emulator RAM into globals.
  --
  -- Both ROMs set marioX, marioY, screenX and screenY.
  -- Super Mario Bros. also sets mariostate, currentscreen, nextscreen,
  -- CurrentWorld and CurrentLevel.
  -- Nothing is written for any other ROM.
  function getPositions()
    local rom = gameinfo.getromname()

    if rom == "Super Mario World (USA)" then
      marioX = memory.read_s16_le(0x94)
      marioY = memory.read_s16_le(0x96)
      -- Screen coordinates = level coordinates minus the layer-1 scroll.
      local scrollX = memory.read_s16_le(0x1A)
      local scrollY = memory.read_s16_le(0x1C)
      screenX = marioX - scrollX
      screenY = marioY - scrollY
      return
    end

    if rom == "Super Mario Bros." then
      -- X is a high byte (0x6D) times 256 plus a low byte (0x86).
      marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
      marioY = memory.readbyte(0x03B8) + 16
      -- Player state (entering/exiting pipes, dying, picking up a mushroom, ...).
      mariostate = memory.readbyte(0x0E)

      currentscreen = memory.readbyte(0x071A) -- current screen, for loop detection
      nextscreen = memory.readbyte(0x071B) -- next screen

      CurrentWorld = memory.readbyte(0x075F) -- world number minus one
      CurrentLevel = memory.readbyte(0x0760) -- current level index

      screenX = memory.readbyte(0x03AD)
      screenY = memory.readbyte(0x03B8)
    end
  end
  98.  
  -- Return the tile value at pixel offset (dx, dy) from Mario.
  -- Reads the globals marioX/marioY, which getPositions sets.
  --
  -- Super Mario World: returns the raw tile byte from the level map.
  -- Super Mario Bros.: returns 1 if the tile byte is non-zero, else 0. Rows
  --   outside the 13 rows stored per page return 0.
  -- Any other ROM: returns nil.
  function getTile(dx, dy)
    local rom = gameinfo.getromname()

    if rom == "Super Mario World (USA)" then
      -- These are locals now; the original leaked them as globals x and y.
      local x = math.floor((marioX+dx+8)/16)
      local y = math.floor((marioY+dy)/16)
      return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
    elseif rom == "Super Mario Bros." then
      local x = marioX + dx + 8
      local y = marioY + dy - 16
      local page = math.floor(x/256)%2
      local subx = math.floor((x%256)/16)
      local suby = math.floor((y - 32)/16)

      if suby >= 13 or suby < 0 then
        return 0
      end

      local addr = 0x500 + page*13*16 + suby*16 + subx
      if memory.readbyte(addr) ~= 0 then
        return 1
      end
      return 0
    end
  end
  125.  
  -- Collect the level positions of active sprites as a list of {x=..., y=...}.
  -- Super Mario World scans slots 0-11; Super Mario Bros. scans enemy slots
  -- 0-4 (y gets +24). Returns nil for any other ROM.
  function getSprites()
    local rom = gameinfo.getromname()

    if rom == "Super Mario World (USA)" then
      local sprites = {}
      for slot=0,11 do
        if memory.readbyte(0x14C8+slot) ~= 0 then
          -- Locals now; the original leaked spritex/spritey as globals.
          local spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
          local spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
          sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
        end
      end
      return sprites
    elseif rom == "Super Mario Bros." then
      local sprites = {}
      for slot=0,4 do
        if memory.readbyte(0xF+slot) ~= 0 then
          local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
          local ey = memory.readbyte(0xCF + slot)+24
          sprites[#sprites+1] = {["x"]=ex, ["y"]=ey}
        end
      end
      return sprites
    end
  end
  153.  
  -- Collect the positions of extended sprites as a list of {x=..., y=...}.
  -- Only Super Mario World is read (slots 0-11). Super Mario Bros. always
  -- returns an empty list, and any other ROM returns nil.
  function getExtendedSprites()
    local rom = gameinfo.getromname()

    if rom == "Super Mario World (USA)" then
      local extended = {}
      for slot=0,11 do
        if memory.readbyte(0x170B+slot) ~= 0 then
          -- Locals now; the original leaked spritex/spritey as globals.
          local spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
          local spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
          extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
        end
      end
      return extended
    elseif rom == "Super Mario Bros." then
      return {}
    end
  end
  171.  
  -- Build the network's input vector: one value per cell of the
  -- (2*BoxRadius+1)^2 grid around Mario, in rows (dy outer, dx inner),
  -- 16 pixels apart.
  --   1  when getTile() reports 1 and the cell's y is below 0x1B0
  --  -1  when a sprite or extended sprite lies near the cell (overrides 1)
  --   0  otherwise
  -- The bias input is not included here.
  function getInputs()
    getPositions()

    -- Locals now; the original leaked these (and tile/distx/disty) as globals.
    local sprites = getSprites()
    local extended = getExtendedSprites()

    local inputs = {}
    local reach = BoxRadius*16

    for dy=-reach,reach,16 do
      for dx=-reach,reach,16 do
        local cellX = marioX+dx
        local cellY = marioY+dy
        local value = 0

        local tile = getTile(dx, dy)
        if tile == 1 and cellY < 0x1B0 then
          value = 1
        end

        -- The bounds are asymmetric (<= for sprites, < for extended sprites).
        -- This is kept from the original so trained networks see the same inputs.
        for i=1,#sprites do
          local distx = math.abs(sprites[i]["x"] - cellX)
          local disty = math.abs(sprites[i]["y"] - cellY)
          if distx <= 8 and disty <= 8 then
            value = -1
          end
        end

        for i=1,#extended do
          local distx = math.abs(extended[i]["x"] - cellX)
          local disty = math.abs(extended[i]["y"] - cellY)
          if distx < 8 and disty < 8 then
            value = -1
          end
        end

        inputs[#inputs+1] = value
      end
    end

    --mariovx = memory.read_s8(0x7B)
    --mariovy = memory.read_s8(0x7D)

    return inputs
  end
  212.  
  -- Steepened logistic curve rescaled to (-1, 1): 2/(1+e^(-4.9x)) - 1.
  function sigmoid(x)
    local decay = math.exp(-4.9*x)
    return 2/(1+decay) - 1
  end
  216.  
  -- Advance the pool-wide innovation counter and return the new value.
  function newInnovation()
    local nextId = pool.innovation + 1
    pool.innovation = nextId
    return nextId
  end
  221.  
  -- Create an empty pool. The innovation counter starts at Outputs, and
  -- evaluation starts at species 1, genome 1.
  function newPool()
    return {
      species = {},
      generation = 0,
      innovation = Outputs,
      currentSpecies = 1,
      currentGenome = 1,
      currentFrame = 0,
      maxFitness = 0,
      maxcounter = 0, -- counter for %max species reached
    }
  end
  235.  
  -- Create an empty species with zeroed fitness/staleness bookkeeping.
  function newSpecies()
    return {
      topFitness = 0,
      staleness = 0,
      genomes = {},
      averageFitness = 0,
    }
  end
  245.  
  -- Create an empty genome that carries the global default mutation rates.
  function newGenome()
    -- The rates are inserted one by one, in the original order, so that
    -- pairs() traversal of mutationRates is unchanged.
    local rates = {}
    rates["connections"] = MutateConnectionsChance
    rates["link"] = LinkMutationChance
    rates["bias"] = BiasMutationChance
    rates["node"] = NodeMutationChance
    rates["enable"] = EnableMutationChance
    rates["disable"] = DisableMutationChance
    rates["step"] = StepSize

    return {
      genes = {},
      fitness = 0,
      adjustedFitness = 0,
      network = {},
      maxneuron = 0,
      globalRank = 0,
      mutationRates = rates,
    }
  end
  265.  
  -- Deep-copy a genome's genes, maxneuron and mutation rates into a fresh
  -- genome. Fitness, rank and network are not copied, so the clone starts
  -- unevaluated.
  function copyGenome(genome)
    local clone = newGenome()
    for g=1,#genome.genes do
      table.insert(clone.genes, copyGene(genome.genes[g]))
    end
    clone.maxneuron = genome.maxneuron

    -- Copy every mutation rate, like crossover does. The original listed the
    -- keys by hand and left out "step", so each clone's evolved step size
    -- was silently reset to the StepSize default.
    for mutation, rate in pairs(genome.mutationRates) do
      clone.mutationRates[mutation] = rate
    end

    return clone
  end
  281.  
  -- Create a starting genome. Its neuron counter is set to Inputs (so new
  -- hidden neurons are numbered after the inputs), and one round of mutation
  -- is applied so it usually begins with some links.
  function basicGenome()
    local genome = newGenome()
    genome.maxneuron = Inputs
    mutate(genome)
    return genome
  end
  291.  
  -- Create a blank connection gene: endpoints 0, weight 0, enabled, and
  -- innovation 0 (callers fill these in).
  function newGene()
    return {
      into = 0,
      out = 0,
      weight = 0.0,
      enabled = true,
      innovation = 0,
    }
  end
  302.  
  -- Return a new gene with the same into/out/weight/enabled/innovation.
  function copyGene(gene)
    local clone = newGene()
    for _, field in ipairs({"into", "out", "weight", "enabled", "innovation"}) do
      clone[field] = gene[field]
    end
    return clone
  end
  313.  
  -- Create a neuron with no incoming genes and an activation of 0.
  function newNeuron()
    return {incoming = {}, value = 0.0}
  end
  321.  
  -- Build genome.network from the genome's enabled genes.
  -- Neurons 1..Inputs and MaxNodes+1..MaxNodes+Outputs always exist. Any other
  -- neuron referenced by an enabled gene is created on demand. Each enabled
  -- gene is appended to the incoming list of its "out" neuron.
  -- Side effect: genome.genes is sorted in place by "out".
  function generateNetwork(genome)
    local neurons = {}

    for id=1,Inputs do
      neurons[id] = newNeuron()
    end
    for o=1,Outputs do
      neurons[MaxNodes+o] = newNeuron()
    end

    table.sort(genome.genes, function (a,b)
      return (a.out < b.out)
    end)

    -- Fetch a neuron, creating it if missing. Creation order matches the
    -- original (out first, then into), so pairs() order over neurons is kept.
    local function ensure(id)
      local neuron = neurons[id]
      if neuron == nil then
        neuron = newNeuron()
        neurons[id] = neuron
      end
      return neuron
    end

    for _, gene in ipairs(genome.genes) do
      if gene.enabled then
        table.insert(ensure(gene.out).incoming, gene)
        ensure(gene.into)
      end
    end

    genome.network = {neurons = neurons}
  end
  353.  
  -- Run one forward pass and return a button table {"P1 <name>" = bool}.
  -- A bias value of 1 is appended to `inputs`, which mutates the caller's
  -- table. If the resulting length is not Inputs, a message is logged and an
  -- empty table is returned. Neurons are updated in pairs() order, and only
  -- those with incoming genes receive sigmoid(weighted sum).
  function evaluateNetwork(network, inputs)
    table.insert(inputs, 1)
    if #inputs ~= Inputs then
      console.writeline("Incorrect number of neural network inputs.")
      return {}
    end

    local neurons = network.neurons
    for i=1,Inputs do
      neurons[i].value = inputs[i]
    end

    for _, neuron in pairs(neurons) do
      local links = neuron.incoming
      if #links > 0 then
        local total = 0
        for _, link in ipairs(links) do
          total = total + link.weight * neurons[link.into].value
        end
        neuron.value = sigmoid(total)
      end
    end

    -- An output neuron presses its button when its activation is positive.
    local outputs = {}
    for o=1,Outputs do
      outputs["P1 " .. ButtonNames[o]] = neurons[MaxNodes+o].value > 0
    end
    return outputs
  end
  390.  
  -- NEAT crossover: build a child genome from two parents.
  -- Genes are walked in the fitter parent's order. If the other parent has a
  -- gene with the same innovation number, a coin flip can pick the other
  -- parent's copy, but only when that copy is enabled; otherwise the fitter
  -- parent's gene is copied. The child takes the larger maxneuron and the
  -- fitter parent's mutation rates.
  function crossover(g1, g2)
    -- Make sure g1 is the higher-fitness genome. Swapping via multiple
    -- assignment replaces the original's leaked global `tempg`.
    if g2.fitness > g1.fitness then
      g1, g2 = g2, g1
    end

    local child = newGenome()

    local byInnovation = {}
    for _, gene in ipairs(g2.genes) do
      byInnovation[gene.innovation] = gene
    end

    for _, gene1 in ipairs(g1.genes) do
      local gene2 = byInnovation[gene1.innovation]
      if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
        table.insert(child.genes, copyGene(gene2))
      else
        table.insert(child.genes, copyGene(gene1))
      end
    end

    child.maxneuron = math.max(g1.maxneuron, g2.maxneuron)

    for mutation, rate in pairs(g1.mutationRates) do
      child.mutationRates[mutation] = rate
    end

    return child
  end
  425.  
  -- Pick a uniformly random neuron id from the candidate set.
  -- Candidates: all output neurons; plus all inputs and every gene endpoint
  -- when nonInput is false; otherwise only gene endpoints above Inputs.
  -- Candidates are inserted in the original order, so the pairs() walk (and
  -- therefore the result for a given random draw) is unchanged. Returns 0 if
  -- nothing is selected.
  function randomNeuron(genes, nonInput)
    local candidates = {}
    local includeInputs = not nonInput

    if includeInputs then
      for id=1,Inputs do
        candidates[id] = true
      end
    end
    for o=1,Outputs do
      candidates[MaxNodes+o] = true
    end
    for _, gene in ipairs(genes) do
      if includeInputs or gene.into > Inputs then
        candidates[gene.into] = true
      end
      if includeInputs or gene.out > Inputs then
        candidates[gene.out] = true
      end
    end

    local total = 0
    for _ in pairs(candidates) do
      total = total + 1
    end

    local target = math.random(1, total)
    local seen = 0
    for id in pairs(candidates) do
      seen = seen + 1
      if seen == target then
        return id
      end
    end

    return 0
  end
  460.  
  -- Return true if `genes` already has a gene with the same into/out pair as
  -- `link`; otherwise return nothing (nil).
  function containsLink(genes, link)
    local wantInto, wantOut = link.into, link.out
    local idx = 1
    while idx <= #genes do
      local candidate = genes[idx]
      if candidate.into == wantInto and candidate.out == wantOut then
        return true
      end
      idx = idx + 1
    end
  end
  469.  
  -- Mutate every gene weight. With probability PerturbChance the weight is
  -- nudged by a uniform amount in [-step, step); otherwise it is replaced by a
  -- fresh uniform value in [-2, 2).
  function pointMutate(genome)
    local step = genome.mutationRates["step"]
    local genes = genome.genes
    for idx = 1, #genes do
      local gene = genes[idx]
      local nudge = math.random() < PerturbChance
      if not nudge then
        gene.weight = math.random()*4-2
      else
        gene.weight = gene.weight + math.random() * step*2 - step
      end
    end
  end
  482.  
  -- Try to add one new link gene with a random weight in [-2, 2).
  -- The source is drawn from all neurons and the target from non-input
  -- neurons. Input-to-input pairs are rejected, and an input target is swapped
  -- to the source side. With forceBias the source becomes neuron Inputs (the
  -- bias). Nothing is added if that into/out pair already exists.
  function linkMutate(genome, forceBias)
    local source = randomNeuron(genome.genes, false)
    local target = randomNeuron(genome.genes, true)
    local link = newGene()

    if source <= Inputs and target <= Inputs then
      return
    end
    if target <= Inputs then
      source, target = target, source
    end

    link.into = forceBias and Inputs or source
    link.out = target

    if containsLink(genome.genes, link) then
      return
    end

    link.innovation = newInnovation()
    link.weight = math.random()*4-2
    table.insert(genome.genes, link)
  end
  513.  
  -- Split a random enabled gene by inserting a new hidden neuron.
  -- The chosen gene is disabled and replaced by into->new (weight 1.0) and
  -- new->out (original weight). maxneuron is incremented even when the chosen
  -- gene is already disabled, in which case nothing else changes.
  function nodeMutate(genome)
    local geneCount = #genome.genes
    if geneCount == 0 then
      return
    end

    genome.maxneuron = genome.maxneuron + 1

    local split = genome.genes[math.random(1, geneCount)]
    if not split.enabled then
      return
    end
    split.enabled = false

    local hidden = genome.maxneuron

    local front = copyGene(split)
    front.out = hidden
    front.weight = 1.0
    front.innovation = newInnovation()
    front.enabled = true
    table.insert(genome.genes, front)

    local back = copyGene(split)
    back.into = hidden
    back.innovation = newInnovation()
    back.enabled = true
    table.insert(genome.genes, back)
  end
  540.  
  -- Toggle one random gene that is currently in the opposite state from
  -- `enable`: enable=true turns a disabled gene on, enable=false turns an
  -- enabled gene off. Does nothing if there are no such genes.
  function enableDisableMutate(genome, enable)
    local eligible = {}
    for _, gene in pairs(genome.genes) do
      if gene.enabled == not enable then
        eligible[#eligible+1] = gene
      end
    end

    local count = #eligible
    if count > 0 then
      local chosen = eligible[math.random(1, count)]
      chosen.enabled = not chosen.enabled
    end
  end
  556.  
  -- Call `action` once per whole unit of p: each attempt succeeds when
  -- math.random() < remaining p, and p drops by 1 after each attempt.
  local function rollRepeatedly(p, action)
    while p > 0 do
      if math.random() < p then
        action()
      end
      p = p - 1
    end
  end

  -- Apply one round of mutation to a genome, in place.
  -- Each mutation rate first drifts by x0.95 or x1.05263 with equal odds.
  -- Then the weights are perturbed (chance "connections"), followed by
  -- link, bias-link, node, enable and disable mutations. Each of these makes
  -- repeated attempts governed by its own rate.
  function mutate(genome)
    local rates = genome.mutationRates

    for name, rate in pairs(rates) do
      rates[name] = (math.random(1,2) == 1) and 0.95*rate or 1.05263*rate
    end

    if math.random() < rates["connections"] then
      pointMutate(genome)
    end

    rollRepeatedly(rates["link"], function() linkMutate(genome, false) end)
    rollRepeatedly(rates["bias"], function() linkMutate(genome, true) end)
    rollRepeatedly(rates["node"], function() nodeMutate(genome) end)
    rollRepeatedly(rates["enable"], function() enableDisableMutate(genome, true) end)
    rollRepeatedly(rates["disable"], function() enableDisableMutate(genome, false) end)
  end
  610.  
  -- Normalised count of disjoint/excess genes between two gene lists: the
  -- number of innovation numbers present in only one list, divided by the
  -- longer list's length.
  -- Two empty lists return 0. The original computed 0/0 = NaN here, which made
  -- sameSpecies reject two identical (empty) genomes.
  function disjoint(genes1, genes2)
    local longest = math.max(#genes1, #genes2)
    if longest == 0 then
      return 0
    end

    local in1, in2 = {}, {}
    for _, gene in ipairs(genes1) do
      in1[gene.innovation] = true
    end
    for _, gene in ipairs(genes2) do
      in2[gene.innovation] = true
    end

    local disjointGenes = 0
    for _, gene in ipairs(genes1) do
      if not in2[gene.innovation] then
        disjointGenes = disjointGenes + 1
      end
    end
    for _, gene in ipairs(genes2) do
      if not in1[gene.innovation] then
        disjointGenes = disjointGenes + 1
      end
    end

    return disjointGenes / longest
  end
  643.  
  -- Mean absolute weight difference over genes that share an innovation
  -- number in both lists.
  -- Returns 0 when no genes coincide. The original computed 0/0 = NaN, which
  -- poisoned the sameSpecies distance.
  function weights(genes1, genes2)
    local byInnovation = {}
    for _, gene in ipairs(genes2) do
      byInnovation[gene.innovation] = gene
    end

    local sum = 0
    local coincident = 0
    for _, gene in ipairs(genes1) do
      local other = byInnovation[gene.innovation]
      if other ~= nil then
        sum = sum + math.abs(gene.weight - other.weight)
        coincident = coincident + 1
      end
    end

    if coincident == 0 then
      return 0
    end
    return sum / coincident
  end
  664.  
  -- Two genomes share a species when the weighted compatibility distance
  -- DeltaDisjoint*disjoint + DeltaWeights*weights is below DeltaThreshold.
  function sameSpecies(genome1, genome2)
    local distance = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
      + DeltaWeights*weights(genome1.genes, genome2.genes)
    return distance < DeltaThreshold
  end
  670.  
  -- Assign every genome in the pool a globalRank: 1 for the lowest fitness,
  -- up to the total number of genomes for the highest.
  function rankGlobally()
    local everyone = {}
    for _, species in ipairs(pool.species) do
      for _, genome in ipairs(species.genomes) do
        everyone[#everyone+1] = genome
      end
    end

    table.sort(everyone, function (a,b)
      return (a.fitness < b.fitness)
    end)

    for rank, genome in ipairs(everyone) do
      genome.globalRank = rank
    end
  end
  687.  
  -- Store the species' mean *global rank* (not raw fitness) in
  -- species.averageFitness.
  function calculateAverageFitness(species)
    local rankSum = 0
    for _, genome in ipairs(species.genomes) do
      rankSum = rankSum + genome.globalRank
    end
    species.averageFitness = rankSum / #species.genomes
  end
  698.  
  -- Sum of averageFitness over all species in the pool.
  function totalAverageFitness()
    local sum = 0
    for _, species in ipairs(pool.species) do
      sum = sum + species.averageFitness
    end
    return sum
  end
  708.  
  -- Sort each species' genomes by descending fitness, then drop the weaker
  -- ones. Keeps the top half (rounded up), or only the single best genome
  -- when cutToOne is truthy.
  function cullSpecies(cutToOne)
    for _, species in ipairs(pool.species) do
      local genomes = species.genomes
      table.sort(genomes, function (a,b)
        return (a.fitness > b.fitness)
      end)

      local keep = cutToOne and 1 or math.ceil(#genomes/2)
      for idx = #genomes, keep+1, -1 do
        genomes[idx] = nil
      end
    end
  end
  726.  
  -- Produce one mutated offspring from a species. With probability
  -- CrossoverChance, two random members (possibly the same one) are crossed
  -- over; otherwise one random member is copied.
  function breedChild(species)
    local genomes = species.genomes
    local child
    if math.random() < CrossoverChance then
      -- Locals now; the original leaked g1/g2/g as globals.
      local g1 = genomes[math.random(1, #genomes)]
      local g2 = genomes[math.random(1, #genomes)]
      child = crossover(g1, g2)
    else
      local g = genomes[math.random(1, #genomes)]
      child = copyGenome(g)
    end

    mutate(child)

    return child
  end
  742.  
  -- Update each species' topFitness/staleness from its best genome (genomes
  -- are sorted by descending fitness as a side effect). Drop species that have
  -- been stale for StaleSpecies generations, unless they hold the pool's
  -- maximum fitness.
  function removeStaleSpecies()
    local survivors = {}

    for _, species in ipairs(pool.species) do
      table.sort(species.genomes, function (a,b)
        return (a.fitness > b.fitness)
      end)

      local best = species.genomes[1].fitness
      if best > species.topFitness then
        species.topFitness = best
        species.staleness = 0
      else
        species.staleness = species.staleness + 1
      end

      local fresh = species.staleness < StaleSpecies
      local champion = species.topFitness >= pool.maxFitness
      if fresh or champion then
        survivors[#survivors+1] = species
      end
    end

    pool.species = survivors
  end
  766.  
  -- Drop species whose share of the population,
  -- floor(averageFitness / total * Population), would be less than one child.
  function removeWeakSpecies()
    local survived = {}

    local sum = totalAverageFitness()
    for _, species in ipairs(pool.species) do
      -- Local now; the original leaked `breed` as a global.
      local breed = math.floor(species.averageFitness / sum * Population)
      if breed >= 1 then
        table.insert(survived, species)
      end
    end

    pool.species = survived
  end
  781.  
  782.  
  -- Put `child` into the first species whose representative (its first
  -- genome) is compatible according to sameSpecies. If none matches, start a
  -- new species for it.
  function addToSpecies(child)
    for _, species in ipairs(pool.species) do
      if sameSpecies(child, species.genomes[1]) then
        table.insert(species.genomes, child)
        return
      end
    end

    local founded = newSpecies()
    table.insert(founded.genomes, child)
    table.insert(pool.species, founded)
  end
  799.  
  -- Advance the pool by one generation.
  -- Steps: cull and rank, drop stale and weak species, breed children in
  -- proportion to each species' average rank, reduce each species to its
  -- champion, top up to Population, re-speciate the children, then bump the
  -- generation, reset maxcounter and write a backup file.
  function newGeneration()
    cullSpecies(false) -- Cull the bottom half of each species
    rankGlobally()
    removeStaleSpecies()
    rankGlobally()
    for _, species in ipairs(pool.species) do
      calculateAverageFitness(species)
    end
    removeWeakSpecies()

    local sum = totalAverageFitness()
    local children = {}
    for _, species in ipairs(pool.species) do
      -- The -1 leaves room for the champion that cullSpecies(true) keeps.
      -- Local now; the original leaked `breed` as a global.
      local breed = math.floor(species.averageFitness / sum * Population) - 1
      for _ = 1, breed do
        table.insert(children, breedChild(species))
      end
    end
    cullSpecies(true) -- Cull all but the top member of each species
    while #children + #pool.species < Population do
      local species = pool.species[math.random(1, #pool.species)]
      table.insert(children, breedChild(species))
    end
    for _, child in ipairs(children) do
      addToSpecies(child)
    end

    pool.generation = pool.generation + 1
    pool.maxcounter = 0 -- reset max fitness counter
    writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  end
  834.  
  -- Create a fresh global pool of Population basic genomes, speciate them,
  -- and start the first run.
  function initializePool()
    pool = newPool()

    for _ = 1, Population do
      -- Passed straight through; the original leaked a global `basic`.
      addToSpecies(basicGenome())
    end

    initializeRun()
  end
  845.  
  -- Release every button: rebuild the global `controller` with all
  -- "P1 <button>" entries set to false, then send it to the joypad.
  function clearJoypad()
    controller = {}
    for _, name in ipairs(ButtonNames) do
      controller["P1 " .. name] = false
    end
    joypad.set(controller)
  end
  853.  
  -- Start evaluating the current genome from the savestate.
  -- Resets the globals rightmost/timeout and pool.currentFrame, clears the
  -- joypad, records the starting screen in currentTracker, builds the
  -- genome's network and applies its first outputs.
  function initializeRun()
    savestate.load(Filename)
    rightmost = 0
    pool.currentFrame = 0
    timeout = TimeoutConstant
    clearJoypad()
    -- Set the tracker to the current screen.
    currentTracker = memory.readbyte(0x071A)

    local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
    generateNetwork(current)
    evaluateCurrent()
  end
  867.  
  -- Feed the current inputs through the current genome's network and press
  -- the resulting buttons. Opposing directions (Left+Right, Up+Down) pressed
  -- together are both released. Sets the globals `inputs` and `controller`.
  function evaluateCurrent()
    local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]

    inputs = getInputs()
    controller = evaluateNetwork(current.network, inputs)

    local opposites = {{"P1 Left", "P1 Right"}, {"P1 Up", "P1 Down"}}
    for _, pair in ipairs(opposites) do
      local first, second = pair[1], pair[2]
      if controller[first] and controller[second] then
        controller[first] = false
        controller[second] = false
      end
    end

    joypad.set(controller)
  end
  886.  
  -- Build a brand-new population only when no pool exists yet.
  if pool == nil then initializePool() end
  890.  
  891.  
  -- Advance to the next genome. After the last genome of a species, move to
  -- the next species. After the last species, breed a new generation and
  -- restart at species 1.
  function nextGenome()
    pool.currentGenome = pool.currentGenome + 1
    if pool.currentGenome <= #pool.species[pool.currentSpecies].genomes then
      return
    end

    pool.currentGenome = 1
    pool.currentSpecies = pool.currentSpecies + 1
    if pool.currentSpecies <= #pool.species then
      return
    end

    newGeneration()
    pool.currentSpecies = 1
  end
  903.  
  -- True when the current genome already has a non-zero fitness.
  function fitnessAlreadyMeasured()
    local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
    return current.fitness ~= 0
  end
  910.  
  -- Draw a genome's network with BizHawk's gui API.
  --   * input grid: one cell per tile, 5 px apart, centred on (50, 70)
  --   * bias cell at (80, 110)
  --   * output cells at x = 220, with button labels
  --   * hidden neurons start at (140, 40) and are moved by 4 relaxation
  --     passes toward the cells they connect to, with x clamped to [90, 220]
  --   * one line per enabled gene, coloured by the weight's sign and
  --     |sigmoid(weight)|
  -- When the showMutationRates checkbox is ticked, the genome's mutation
  -- rates are also listed. Keys are inserted into `cells` in the same order as
  -- the original, so pairs() draw order is unchanged.
  function displayGenome(genome)
    local neurons = genome.network.neurons
    local cells = {}

    local index = 0
    for dy = -BoxRadius, BoxRadius do
      for dx = -BoxRadius, BoxRadius do
        index = index + 1
        cells[index] = {x = 50+5*dx, y = 70+5*dy, value = neurons[index].value}
      end
    end

    cells[Inputs] = {x = 80, y = 110, value = neurons[Inputs].value}

    for o = 1, Outputs do
      local value = neurons[MaxNodes + o].value
      cells[MaxNodes+o] = {x = 220, y = 30 + 8 * o, value = value}
      local labelColor = (value > 0) and 0xFF0000FF or 0xFF000000
      gui.drawText(223, 24+8*o, ButtonNames[o], labelColor, 9)
    end

    local function isHidden(id)
      return id > Inputs and id <= MaxNodes
    end

    for id, neuron in pairs(neurons) do
      if isHidden(id) then
        cells[id] = {x = 140, y = 40, value = neuron.value}
      end
    end

    local function clampX(cell)
      if cell.x < 90 then
        cell.x = 90
      end
      if cell.x > 220 then
        cell.x = 220
      end
    end

    for _ = 1, 4 do
      for _, gene in pairs(genome.genes) do
        if gene.enabled then
          local src, dst = cells[gene.into], cells[gene.out]
          if isHidden(gene.into) then
            src.x = 0.75*src.x + 0.25*dst.x
            if src.x >= dst.x then
              src.x = src.x - 40
            end
            clampX(src)
            src.y = 0.75*src.y + 0.25*dst.y
          end
          if isHidden(gene.out) then
            dst.x = 0.25*src.x + 0.75*dst.x
            if src.x >= dst.x then
              dst.x = dst.x + 40
            end
            clampX(dst)
            dst.y = 0.25*src.y + 0.75*dst.y
          end
        end
      end
    end

    gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)

    -- Cells: greyscale by value. Input cells with value 0 are skipped.
    for id, cell in pairs(cells) do
      if id > Inputs or cell.value ~= 0 then
        local shade = math.floor((cell.value+1)/2*256)
        if shade > 255 then shade = 255 end
        if shade < 0 then shade = 0 end
        local opacity = (cell.value == 0) and 0x50000000 or 0xFF000000
        local fill = opacity + shade*0x10000 + shade*0x100 + shade
        gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,fill)
      end
    end

    -- Links: faint when the source cell is 0; base colour depends on sign.
    for _, gene in pairs(genome.genes) do
      if gene.enabled then
        local src, dst = cells[gene.into], cells[gene.out]
        local alpha = (src.value == 0) and 0x20000000 or 0xA0000000
        local strength = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
        local lineColor
        if gene.weight > 0 then
          lineColor = alpha + 0x8000 + 0x10000*strength
        else
          lineColor = alpha + 0x800000 + 0x100*strength
        end
        gui.drawLine(src.x+1, src.y, dst.x-3, dst.y, lineColor)
      end
    end

    gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)

    if forms.ischecked(showMutationRates) then
      local row = 100
      for mutation, rate in pairs(genome.mutationRates) do
        gui.drawText(100, row, mutation .. ": " .. rate, 0xFF000000, 10)
        row = row + 8
      end
    end
  end
  1037.  
  -- Serialise the global pool to `filename` as line-oriented text:
  --   generation, maxFitness, #species, then for each species:
  --     topFitness, staleness, #genomes, then for each genome:
  --       fitness, maxneuron, mutation "name"/value line pairs, "done",
  --       #genes, one "into out weight innovation enabled(1|0)" line per gene
  -- Raises an error with the OS message if the file cannot be opened. The
  -- original crashed with an opaque "attempt to index a nil value".
  function writeFile(filename)
    local file, err = io.open(filename, "w")
    if not file then
      error("writeFile: could not open '" .. tostring(filename) .. "' for writing: " .. tostring(err))
    end

    file:write(pool.generation .. "\n")
    file:write(pool.maxFitness .. "\n")
    file:write(#pool.species .. "\n")
    for _, species in ipairs(pool.species) do
      file:write(species.topFitness .. "\n")
      file:write(species.staleness .. "\n")
      file:write(#species.genomes .. "\n")
      for _, genome in ipairs(species.genomes) do
        file:write(genome.fitness .. "\n")
        file:write(genome.maxneuron .. "\n")
        -- Order is whatever pairs() yields; each rate is written with its name.
        for mutation, rate in pairs(genome.mutationRates) do
          file:write(mutation .. "\n")
          file:write(rate .. "\n")
        end
        file:write("done\n")

        file:write(#genome.genes .. "\n")
        for _, gene in ipairs(genome.genes) do
          local enabledFlag = gene.enabled and "1" or "0"
          file:write(gene.into .. " " .. gene.out .. " " .. gene.weight .. " "
            .. gene.innovation .. " " .. enabledFlag .. "\n")
        end
      end
    end
    file:close()
  end
  1072.  
  -- Save the pool to the file named in the save/load text box, then report
  -- completion on the console.
  function savePool()
    writeFile(forms.gettext(saveLoadFile))
    console.writeline("\nsaving complete") -- suggestion by DeltaLeeds
  end
  1078.  
  -- Load a pool in writeFile's format into the global `pool`. Then skip
  -- genomes that already have a fitness, start a run, and advance
  -- currentFrame by one.
  -- Raises an error if the file cannot be opened.
  function loadFile(filename)
    local file, err = io.open(filename, "r")
    if not file then
      error("loadFile: could not open '" .. tostring(filename) .. "' for reading: " .. tostring(err))
    end

    -- file:read("*number") leaves the rest of its line (at least "\n")
    -- unread. A following read("*line") would then return that empty
    -- remainder instead of the next record, so blank lines are skipped here.
    local function readNonBlankLine()
      local line = file:read("*line")
      while line ~= nil and not line:find("%S") do
        line = file:read("*line")
      end
      return line
    end

    pool = newPool()
    pool.generation = file:read("*number")
    pool.maxFitness = file:read("*number")
    forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
    local numSpecies = file:read("*number")
    for s=1,numSpecies do
      local species = newSpecies()
      table.insert(pool.species, species)
      species.topFitness = file:read("*number")
      species.staleness = file:read("*number")
      local numGenomes = file:read("*number")
      for g=1,numGenomes do
        local genome = newGenome()
        table.insert(species.genomes, genome)
        genome.fitness = file:read("*number")
        genome.maxneuron = file:read("*number")

        -- Mutation rates: name line, value, ..., terminated by "done".
        -- Blank remainders are skipped instead of being fed to a failing
        -- read("*number"), as the original did.
        local name = readNonBlankLine()
        while name ~= nil and name ~= "done" do
          genome.mutationRates[name] = file:read("*number")
          name = readNonBlankLine()
        end

        local numGenes = file:read("*number")
        for n=1,numGenes do
          local gene = newGene()
          table.insert(genome.genes, gene)

          -- Gene line: "into out weight innovation enabled". Tokens are stored
          -- by position so that a token tonumber() rejects cannot shift the
          -- fields after it.
          local fields = {}
          local position = 0
          for token in (readNonBlankLine() or ""):gmatch("%S+") do
            position = position + 1
            fields[position] = tonumber(token)
          end
          gene.into = fields[1]
          gene.out = fields[2]
          gene.weight = fields[3]
          gene.innovation = fields[4]
          -- Only an explicit 0 disables the gene (a missing flag means enabled).
          gene.enabled = fields[5] ~= 0
        end
      end
    end
    file:close()

    while fitnessAlreadyMeasured() do
      nextGenome()
    end
    initializeRun()
    pool.currentFrame = pool.currentFrame + 1
  end
  1139.  
  1140. function loadPool()
  1141. local filename = forms.gettext(saveLoadFile)
  1142. loadFile(filename)
  1143. end
  1144.  
  1145. function playTop()
  1146. local maxfitness = 0
  1147. local maxs, maxg
  1148. for s,species in pairs(pool.species) do
  1149. for g,genome in pairs(species.genomes) do
  1150. if genome.fitness > maxfitness then
  1151. maxfitness = genome.fitness
  1152. maxs = s
  1153. maxg = g
  1154. end
  1155. end
  1156. end
  1157.  
  1158. pool.currentSpecies = maxs
  1159. pool.currentGenome = maxg
  1160. pool.maxFitness = maxfitness
  1161. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1162. initializeRun()
  1163. pool.currentFrame = pool.currentFrame + 1
  1164. return
  1165. end
  1166.  
  1167. function onExit()
  1168. forms.destroy(form)
  1169. end
  1170.  
  1171. writeFile("temp.pool")
  1172.  
  1173. event.onexit(onExit)
  1174.  
  1175. form = forms.newform(200, 260, "Fitness")
  1176. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  1177. showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  1178. showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  1179. restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  1180. saveButton = forms.button(form, "Save", savePool, 5, 102)
  1181. loadButton = forms.button(form, "Load", loadPool, 80, 102)
  1182. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  1183. saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  1184. playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  1185. hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  1186.  
  1187.  
  1188. while true do
  1189. local backgroundColor = 0xD0FFFFFF
  1190. if not forms.ischecked(hideBanner) then
  1191. gui.drawBox(0, 0, 300, 22, backgroundColor, backgroundColor)
  1192. end
  1193.  
  1194. if pool.maxcounter == nil then --allows for loading of old saves
  1195. pool.maxcounter = 0
  1196. console.writeline("caught missing variable")
  1197. end
  1198.  
  1199. local species = pool.species[pool.currentSpecies]
  1200. local genome = species.genomes[pool.currentGenome]
  1201.  
  1202. if forms.ischecked(showNetwork) then
  1203. displayGenome(genome)
  1204. end
  1205.  
  1206. if pool.currentFrame%5 == 0 then
  1207. evaluateCurrent()
  1208. end
  1209.  
  1210. joypad.set(controller)
  1211.  
  1212. getPositions()
  1213.  
  1214. if mariostate == 2 or mariostate == 3 then --detects when mario enters a pipe
  1215. marioPipeEnter = true
  1216. end
  1217. if (CurrentWorld == 3 and CurrentLevel == 4) or (CurrentWorld == 6 and CurrentLevel == 4) or (CurrentWorld == 7 and CurrentLevel == 3) then
  1218. if currentTracker == currentscreen or currentTracker == nextscreen and not mariopipe then --follows the progress of mario
  1219. currentTracker = nextscreen
  1220. elseif currentTracker > nextscreen and not mariopipe and not marioPipeEnter and mariostate ~= 7 then --if mario suddenly goes back trigger loop detection
  1221. loop = true
  1222. elseif currentscreen == 0 and nextscreen == 0 and mariopipe and mariostate ~= 7 then -- if mario enters a piperoom negate loop detection
  1223. loop = false
  1224. currentTracker = nextscreen
  1225. elseif nextscreen >= (currentTracker - 1) and mariopipe and mariostate ~= 7 then -- if mario goes forward in the level by entering a pipe negate loop detection
  1226. loop = false
  1227. currentTracker = nextscreen
  1228. pipeexit = true
  1229. else --if nothing is true then we are in a loop
  1230. if not pipeexit then --prevents the extra one frame flicker of the tracker
  1231. loop = true
  1232. else
  1233. pipeexit = false
  1234. end
  1235. end
  1236. end
  1237. if marioPipeEnter == true and mariostate == 7 then --set mariopipe when exiting the pipe
  1238. mariopipe = true
  1239. end
  1240. if mariopipe == true and mariostate ~= 7 and rightmost > marioX and not mariohole then --calculate offset
  1241. offset = rightmost - marioX
  1242. mariopipe = false
  1243. marioPipeEnter = false
  1244. elseif mariopipe == true and mariostate ~= 7 and rightmost <= marioX and not mariohole then --if the exit coordinates are higher than rightmost after exiting a pipe set offset to 0
  1245. offset = 0
  1246. mariopipe = false
  1247. marioPipeEnter = false
  1248. end
  1249.  
  1250. if marioY + memory.readbyte(0xB5)*255 > 512 then --added the fallen into a hole variable that will activate when mario goes below screen
  1251. mariohole = true
  1252. end
  1253. if marioX + offset > rightmost and mariohole == false and loop == false and marioPipeEnter == false then
  1254. rightmost = marioX + offset
  1255. if TimeoutConstant - timeout > 20 then --this is required for 4-4 so walking left will not give a big penalty
  1256. timebonus = timebonus + (TimeoutConstant - timeout) * 13/20
  1257. end
  1258. timeout = TimeoutConstant
  1259. end
  1260. if marioPipeEnter == true then --freeze fitness and timer when mario enters a pipe
  1261. timeout = timeout + 1
  1262. timebonus = timebonus + 2/3
  1263. end
  1264.  
  1265. timeout = timeout - 1
  1266.  
  1267.  
  1268. local timeoutBonus = pool.currentFrame / 4
  1269. if timeout + timeoutBonus <= 0 then
  1270. local fitness = rightmost - pool.currentFrame / 2 +20 + timebonus
  1271. --if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 + offset then
  1272. -- fitness = fitness + 1000
  1273. --end
  1274. --if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 + offset then
  1275. -- fitness = fitness + 1000
  1276. --end
  1277. if fitness == 0 then
  1278. fitness = -1
  1279. end
  1280.  
  1281. genome.fitness = fitness
  1282.  
  1283.  
  1284. if fitness > pool.maxFitness then
  1285. pool.maxFitness = fitness
  1286. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1287. --writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  1288. end
  1289.  
  1290. if fitness > pool.maxFitness - 30 then --counter for species that reach the max fitness
  1291. pool.maxcounter = pool.maxcounter + 1
  1292. end
  1293.  
  1294. console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  1295. pool.currentSpecies = 1
  1296. pool.currentGenome = 1
  1297. while fitnessAlreadyMeasured() do
  1298. offset = 0 --offset in x coordinates for pipe
  1299. timebonus = 0 --used to freeze fitness
  1300. mariohole = false --reset the fallen into hole variable
  1301. mariopipe = false --is mario exiting a pipe or just teleporting
  1302. loop = false --is this a loop?
  1303. marioPipeEnter = false --has mario entered a pipe
  1304. currentTracker = 1 --tracker for the loop
  1305. nextGenome()
  1306. end
  1307. initializeRun()
  1308. end
  1309.  
  1310. local measured = 0
  1311. local total = 0
  1312. for _,species in pairs(pool.species) do
  1313. for _,genome in pairs(species.genomes) do
  1314. total = total + 1
  1315. if genome.fitness ~= 0 then
  1316. measured = measured + 1
  1317. end
  1318. end
  1319. end
  1320. if not forms.ischecked(hideBanner) then
  1321. gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 7)
  1322. gui.drawText(0, 9, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 + - (timeout + timeoutBonus)*2/3 +20 + timebonus), 0xFF000000, 7) -- added a +20 so mario begins at 0 fitness
  1323. gui.drawText(100, 9, "Max Fitness:" .. math.floor(pool.maxFitness), 0xFF000000, 7)
  1324. gui.drawText(220, 9, "(" .. math.ceil(pool.maxcounter/Population*100) .. "%)", 0xFF000000, 7) --draws the % that reached the max fitness - 30
  1325. end
  1326.  
  1327. pool.currentFrame = pool.currentFrame + 1
  1328.  
  1329. emu.frameadvance();
  1330. end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement