Code-Akisame

FCEUX script for Einfach Nerdig's stream

Apr 22nd, 2018
598
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 39.49 KB | None | 0 0
  1. -- MarI/O by SethBling rewritten for use with FCEUX by Akisame (used 3DI70R's version as a reference to save time but I have been told by him that he used juvester's script as a reference https://github.com/juvester/mari-o-fceux.)
  2. -- I would like to stress that we tried to keep the algorithm as close to the version of SethBling as possible.
  3. -- Pressing N will show the Neural net, pressing M will show the mutation rates, pressing L will load the last saved generation and pressing T will display the genome with the top fitness
  4. -- This file includes all the fixes from my bizhawk version (https://pastebin.com/dKYFKg2E) and an extra fix to hopefully prevent the formation of a monospecies when we encounter a major bottleneck
  5. -- If you want to start your own stream with this script I would appreciate it if you would include this file as well as SethBling's original version in the description
  6. -- To use this script just download FCEUX and create a savestate in state 1 for smb. click file -> lua -> 'new lua script window' and load the script.
  7. -- Have fun!
  8.  
  9. --FCEUX doesn't allow for user string input but if you press 'L' on your keyboard it should load the latest generation that was saved
  10. --for loading specific generations replace nil with the filename as shown in the example below
  11. savedpool = nil
  12. --savedpool = 'backups/backup.429.SMB1-1.state.pool'
  13. --savedpool = "backups/backup.5.SMB1-1.state.pool"
  14.  
  15. -- HUD options. These can be changed on the fly by pressing 'M' for the mutation rates and 'N' for the network
  16. mutation_rates_disp = false
  17. neural_net_disp = true
  18.  
  19. --set to false if you want warning popups before loading anything that might cause you to lose your current progress
  20. nopopup = true
  21.  
  22. savestate_slot = 1 -- Load the level from FCEUX savestate slot
  23.  
  24. SAVE_LOAD_FILE = "SMB1-1.state.pool" --filename to be used
  25.  
  26. os.execute("mkdir backups") --create the folder to save the pools in
  27.  
  28. SavestateObj = savestate.object(savestate_slot)
  29.  
  30. ButtonNames = {
  31. "A",
  32. "B",
  33. "up",
  34. "down",
  35. "left",
  36. "right",
  37. }
  38.  
  39.  
  40. BoxRadius = 6
  41. InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  42.  
  43. Inputs = InputSize+1
  44. Outputs = #ButtonNames
  45.  
  46. Population = 300
  47. DeltaDisjoint = 2.0
  48. DeltaWeights = 0.4
  49. DeltaThreshold = 1.0
  50.  
  51. keyflag = false --used for turning the neural network on and off by pressin 'n'
  52.  
  53. StaleSpecies = 25 --increased slightly to allow for longer evolution
  54. mariohole = false --flag for mario falling into a hole
  55. mariopipe = false --flag for mario exiting a pipe
  56. marioPipeEnter = false --new flag for entering a pipe
  57. offset = 0 --offset value for when mario exits a pipe
  58. timebonus = 0 --timebonus to prevent the hard drops in fitness when exiting a pipe
  59. maxcounter = 0 --species that reach the max fitness
  60. currentTracker = 0 --tracks location on screen
  61. loop = false --flag for when a loop is detected
  62.  
  63. MutateConnectionsChance = 0.25
  64. PerturbChance = 0.90
  65. CrossoverChance = 0.75
  66. LinkMutationChance = 2.0
  67. NodeMutationChance = 0.50
  68. BiasMutationChance = 0.40
  69. StepSize = 0.1
  70. DisableMutationChance = 0.4
  71. EnableMutationChance = 0.2
  72.  
  73. TimeoutConstant = 90 --set to a higher value but see for yourself what is acceptable
  74.  
  75. MaxNodes = 1000000
  76.  
  77. function getPositions()
  78. marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
  79. marioY = memory.readbyte(0x03B8)+16
  80. mariostate = memory.readbyte(0x0E) --get mario's state. (entering exiting pipes, dying, picking up a mushroom etc etc)
  81.  
  82. currentscreen = memory.readbyte(0x071A) --finds current screen for loop detection
  83. nextscreen = memory.readbyte(0x071B) --finds next screen
  84.  
  85. CurrentWorld = memory.readbyte(0x075F) --finds the current world (value + 1 is world)
  86. CurrentLevel = memory.readbyte(0x0760) --finds the current level (0=1 2=2 3=3 4=4)
  87.  
  88. screenX = memory.readbyte(0x03AD)
  89. screenY = memory.readbyte(0x03B8)
  90. end
  91.  
  92. function getTile(dx, dy)
  93. local x = marioX + dx + 8
  94. local y = marioY + dy - 16
  95. local page = math.floor(x/256)%2
  96.  
  97. local subx = math.floor((x%256)/16)
  98. local suby = math.floor((y - 32)/16)
  99. local addr = 0x500 + page*13*16+suby*16+subx
  100.  
  101. if suby >= 13 or suby < 0 then
  102. return 0
  103. end
  104.  
  105. if memory.readbyte(addr) ~= 0 then
  106. return 1
  107. else
  108. return 0
  109. end
  110. end
  111.  
  112. function getSprites()
  113. local sprites = {}
  114. for slot=0,4 do
  115. local enemy = memory.readbyte(0xF+slot)
  116. if enemy ~= 0 then
  117. local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
  118. local ey = memory.readbyte(0xCF + slot)+36 --changed this to stop sprites from floating in mid air on the neural network
  119. sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
  120. end
  121. end
  122.  
  123. return sprites
  124. end
  125.  
  126. function getInputs()
  127. getPositions()
  128.  
  129. sprites = getSprites()
  130.  
  131. local inputs = {}
  132.  
  133. for dy=-BoxRadius*16,BoxRadius*16,16 do
  134. for dx=-BoxRadius*16,BoxRadius*16,16 do
  135. inputs[#inputs+1] = 0
  136.  
  137. tile = getTile(dx, dy)
  138. if tile == 1 and marioY+dy < 0x1B0 then
  139. inputs[#inputs] = 1
  140. end
  141.  
  142. for i = 1,#sprites do
  143. distx = math.abs(sprites[i]["x"] - (marioX+dx))
  144. disty = math.abs(sprites[i]["y"] - (marioY+dy+16))
  145. if distx <= 8 and disty <= 8 then
  146. inputs[#inputs] = -1
  147. end
  148. end
  149. end
  150. end
  151.  
  152. --mariovx = memory.read_s8(0x7B)
  153. --mariovy = memory.read_s8(0x7D)
  154.  
  155. return inputs
  156. end
  157.  
  158. function sigmoid(x)
  159. return 2/(1+math.exp(-4.9*x))-1
  160. end
  161.  
  162. function newInnovation()
  163. pool.innovation = pool.innovation + 1
  164. return pool.innovation
  165. end
  166.  
  167. function newPool()
  168. local pool = {}
  169. pool.species = {}
  170. pool.generation = 0
  171. pool.innovation = Outputs
  172. pool.currentSpecies = 1
  173. pool.currentGenome = 1
  174. pool.currentFrame = 0
  175. pool.maxFitness = 0
  176. pool.maxcounter = 0 --counter for %max species reached
  177.  
  178. return pool
  179. end
  180.  
  181. function newSpecies()
  182. local species = {}
  183. species.topFitness = 0
  184. species.staleness = 0
  185. species.genomes = {}
  186. species.averageFitness = 0
  187.  
  188. return species
  189. end
  190.  
  191. function getCurrentGenome()
  192. local species = pool.species[pool.currentSpecies]
  193. return species.genomes[pool.currentGenome]
  194. end
  195.  
  196. function toRGBA(ARGB)
  197. return bit.lshift(ARGB, 8) + bit.rshift(ARGB, 24)
  198. end
  199.  
  200.  
  201. function newGenome()
  202. local genome = {}
  203. genome.genes = {}
  204. genome.fitness = 0
  205. genome.adjustedFitness = 0
  206. genome.network = {}
  207. genome.maxneuron = 0
  208. genome.globalRank = 0
  209. genome.mutationRates = {}
  210. genome.mutationRates["connections"] = MutateConnectionsChance
  211. genome.mutationRates["link"] = LinkMutationChance
  212. genome.mutationRates["bias"] = BiasMutationChance
  213. genome.mutationRates["node"] = NodeMutationChance
  214. genome.mutationRates["enable"] = EnableMutationChance
  215. genome.mutationRates["disable"] = DisableMutationChance
  216. genome.mutationRates["step"] = StepSize
  217.  
  218. return genome
  219. end
  220.  
  221. function copyGenome(genome)
  222. local genome2 = newGenome()
  223. for g=1,#genome.genes do
  224. table.insert(genome2.genes, copyGene(genome.genes[g]))
  225. end
  226. genome2.maxneuron = genome.maxneuron
  227. genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  228. genome2.mutationRates["link"] = genome.mutationRates["link"]
  229. genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  230. genome2.mutationRates["node"] = genome.mutationRates["node"]
  231. genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  232. genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  233.  
  234. return genome2
  235. end
  236.  
  237. function basicGenome()
  238. local genome = newGenome()
  239. local innovation = 1
  240.  
  241. genome.maxneuron = Inputs
  242. mutate(genome)
  243.  
  244. return genome
  245. end
  246.  
  247. function newGene()
  248. local gene = {}
  249. gene.into = 0
  250. gene.out = 0
  251. gene.weight = 0.0
  252. gene.enabled = true
  253. gene.innovation = 0
  254.  
  255. return gene
  256. end
  257.  
  258. function copyGene(gene)
  259. local gene2 = newGene()
  260. gene2.into = gene.into
  261. gene2.out = gene.out
  262. gene2.weight = gene.weight
  263. gene2.enabled = gene.enabled
  264. gene2.innovation = gene.innovation
  265.  
  266. return gene2
  267. end
  268.  
  269. function newNeuron()
  270. local neuron = {}
  271. neuron.incoming = {}
  272. neuron.value = 0.0
  273.  
  274. return neuron
  275. end
  276.  
  277. function generateNetwork(genome)
  278. local network = {}
  279. network.neurons = {}
  280.  
  281. for i=1,Inputs do
  282. network.neurons[i] = newNeuron()
  283. end
  284.  
  285. for o=1,Outputs do
  286. network.neurons[MaxNodes+o] = newNeuron()
  287. end
  288.  
  289. table.sort(genome.genes, function (a,b)
  290. return (a.out < b.out)
  291. end)
  292. for i=1,#genome.genes do
  293. local gene = genome.genes[i]
  294. if gene.enabled then
  295. if network.neurons[gene.out] == nil then
  296. network.neurons[gene.out] = newNeuron()
  297. end
  298. local neuron = network.neurons[gene.out]
  299. table.insert(neuron.incoming, gene)
  300. if network.neurons[gene.into] == nil then
  301. network.neurons[gene.into] = newNeuron()
  302. end
  303. end
  304. end
  305.  
  306. genome.network = network
  307. end
  308.  
  309. function evaluateNetwork(network, inputs)
  310. table.insert(inputs, 1)
  311. if #inputs ~= Inputs then
  312. emu.print("Incorrect number of neural network inputs.")
  313. return {}
  314. end
  315.  
  316. for i=1,Inputs do
  317. network.neurons[i].value = inputs[i]
  318. end
  319.  
  320. for _,neuron in pairs(network.neurons) do
  321. local sum = 0
  322. for j = 1,#neuron.incoming do
  323. local incoming = neuron.incoming[j]
  324. local other = network.neurons[incoming.into]
  325. sum = sum + incoming.weight * other.value
  326. end
  327.  
  328. if #neuron.incoming > 0 then
  329. neuron.value = sigmoid(sum)
  330. end
  331. end
  332.  
  333. local outputs = {}
  334. for o=1,Outputs do
  335. local button = ButtonNames[o]
  336. if network.neurons[MaxNodes+o].value > 0 then
  337. outputs[button] = true
  338. else
  339. outputs[button] = false
  340. end
  341. end
  342.  
  343. return outputs
  344. end
  345.  
  346. function crossover(g1, g2)
  347. -- Make sure g1 is the higher fitness genome
  348. if g2.fitness > g1.fitness then
  349. tempg = g1
  350. g1 = g2
  351. g2 = tempg
  352. end
  353.  
  354. local child = newGenome()
  355.  
  356. local innovations2 = {}
  357. for i=1,#g2.genes do
  358. local gene = g2.genes[i]
  359. innovations2[gene.innovation] = gene
  360. end
  361.  
  362. for i=1,#g1.genes do
  363. local gene1 = g1.genes[i]
  364. local gene2 = innovations2[gene1.innovation]
  365. if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  366. table.insert(child.genes, copyGene(gene2))
  367. else
  368. table.insert(child.genes, copyGene(gene1))
  369. end
  370. end
  371.  
  372. child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  373.  
  374. for mutation,rate in pairs(g1.mutationRates) do
  375. child.mutationRates[mutation] = rate
  376. end
  377.  
  378. return child
  379. end
  380.  
  381. function randomNeuron(genes, nonInput)
  382. local neurons = {}
  383. if not nonInput then
  384. for i=1,Inputs do
  385. neurons[i] = true
  386. end
  387. end
  388. for o=1,Outputs do
  389. neurons[MaxNodes+o] = true
  390. end
  391. for i=1,#genes do
  392. if (not nonInput) or genes[i].into > Inputs then
  393. neurons[genes[i].into] = true
  394. end
  395. if (not nonInput) or genes[i].out > Inputs then
  396. neurons[genes[i].out] = true
  397. end
  398. end
  399.  
  400. local count = 0
  401. for _,_ in pairs(neurons) do
  402. count = count + 1
  403. end
  404. local n = math.random(1, count)
  405.  
  406. for k,v in pairs(neurons) do
  407. n = n-1
  408. if n == 0 then
  409. return k
  410. end
  411. end
  412.  
  413. return 0
  414. end
  415.  
  416. function containsLink(genes, link)
  417. for i=1,#genes do
  418. local gene = genes[i]
  419. if gene.into == link.into and gene.out == link.out then
  420. return true
  421. end
  422. end
  423. end
  424.  
  425. function pointMutate(genome)
  426. local step = genome.mutationRates["step"]
  427.  
  428. for i=1,#genome.genes do
  429. local gene = genome.genes[i]
  430. if math.random() < PerturbChance then
  431. gene.weight = gene.weight + math.random() * step*2 - step
  432. else
  433. gene.weight = math.random()*4-2
  434. end
  435. end
  436. end
  437.  
  438. function linkMutate(genome, forceBias)
  439. local neuron1 = randomNeuron(genome.genes, false)
  440. local neuron2 = randomNeuron(genome.genes, true)
  441.  
  442. local newLink = newGene()
  443. if neuron1 <= Inputs and neuron2 <= Inputs then
  444. --Both input nodes
  445. return
  446. end
  447. if neuron2 <= Inputs then
  448. -- Swap output and input
  449. local temp = neuron1
  450. neuron1 = neuron2
  451. neuron2 = temp
  452. end
  453.  
  454. newLink.into = neuron1
  455. newLink.out = neuron2
  456. if forceBias then
  457. newLink.into = Inputs
  458. end
  459.  
  460. if containsLink(genome.genes, newLink) then
  461. return
  462. end
  463. newLink.innovation = newInnovation()
  464. newLink.weight = math.random()*4-2
  465.  
  466. table.insert(genome.genes, newLink)
  467. end
  468.  
  469. function nodeMutate(genome)
  470. if #genome.genes == 0 then
  471. return
  472. end
  473.  
  474. genome.maxneuron = genome.maxneuron + 1
  475.  
  476. local gene = genome.genes[math.random(1,#genome.genes)]
  477. if not gene.enabled then
  478. return
  479. end
  480. gene.enabled = false
  481.  
  482. local gene1 = copyGene(gene)
  483. gene1.out = genome.maxneuron
  484. gene1.weight = 1.0
  485. gene1.innovation = newInnovation()
  486. gene1.enabled = true
  487. table.insert(genome.genes, gene1)
  488.  
  489. local gene2 = copyGene(gene)
  490. gene2.into = genome.maxneuron
  491. gene2.innovation = newInnovation()
  492. gene2.enabled = true
  493. table.insert(genome.genes, gene2)
  494. end
  495.  
  496. function enableDisableMutate(genome, enable)
  497. local candidates = {}
  498. for _,gene in pairs(genome.genes) do
  499. if gene.enabled == not enable then
  500. table.insert(candidates, gene)
  501. end
  502. end
  503.  
  504. if #candidates == 0 then
  505. return
  506. end
  507.  
  508. local gene = candidates[math.random(1,#candidates)]
  509. gene.enabled = not gene.enabled
  510. end
  511.  
  512. function mutate(genome)
  513. for mutation,rate in pairs(genome.mutationRates) do
  514. if math.random(1,2) == 1 then
  515. genome.mutationRates[mutation] = 0.95*rate
  516. else
  517. genome.mutationRates[mutation] = 1.05263*rate
  518. end
  519. end
  520.  
  521. if math.random() < genome.mutationRates["connections"] then
  522. pointMutate(genome)
  523. end
  524.  
  525. local p = genome.mutationRates["link"]
  526. while p > 0 do
  527. if math.random() < p then
  528. linkMutate(genome, false)
  529. end
  530. p = p - 1
  531. end
  532.  
  533. p = genome.mutationRates["bias"]
  534. while p > 0 do
  535. if math.random() < p then
  536. linkMutate(genome, true)
  537. end
  538. p = p - 1
  539. end
  540.  
  541. p = genome.mutationRates["node"]
  542. while p > 0 do
  543. if math.random() < p then
  544. nodeMutate(genome)
  545. end
  546. p = p - 1
  547. end
  548.  
  549. p = genome.mutationRates["enable"]
  550. while p > 0 do
  551. if math.random() < p then
  552. enableDisableMutate(genome, true)
  553. end
  554. p = p - 1
  555. end
  556.  
  557. p = genome.mutationRates["disable"]
  558. while p > 0 do
  559. if math.random() < p then
  560. enableDisableMutate(genome, false)
  561. end
  562. p = p - 1
  563. end
  564. end
  565.  
  566. function disjoint(genes1, genes2)
  567. local i1 = {}
  568. for i = 1,#genes1 do
  569. local gene = genes1[i]
  570. i1[gene.innovation] = true
  571. end
  572.  
  573. local i2 = {}
  574. for i = 1,#genes2 do
  575. local gene = genes2[i]
  576. i2[gene.innovation] = true
  577. end
  578.  
  579. local disjointGenes = 0
  580. for i = 1,#genes1 do
  581. local gene = genes1[i]
  582. if not i2[gene.innovation] then
  583. disjointGenes = disjointGenes+1
  584. end
  585. end
  586.  
  587. for i = 1,#genes2 do
  588. local gene = genes2[i]
  589. if not i1[gene.innovation] then
  590. disjointGenes = disjointGenes+1
  591. end
  592. end
  593.  
  594. local n = math.max(#genes1, #genes2)
  595.  
  596. return disjointGenes / n
  597. end
  598.  
  599. function weights(genes1, genes2)
  600. local i2 = {}
  601. for i = 1,#genes2 do
  602. local gene = genes2[i]
  603. i2[gene.innovation] = gene
  604. end
  605.  
  606. local sum = 0
  607. local coincident = 0
  608. for i = 1,#genes1 do
  609. local gene = genes1[i]
  610. if i2[gene.innovation] ~= nil then
  611. local gene2 = i2[gene.innovation]
  612. sum = sum + math.abs(gene.weight - gene2.weight)
  613. coincident = coincident + 1
  614. end
  615. end
  616.  
  617. return sum / coincident
  618. end
  619.  
  620. function sameSpecies(genome1, genome2)
  621. local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  622. local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  623. return dd + dw < DeltaThreshold
  624. end
  625.  
  626. function rankGlobally()
  627. local global = {}
  628. for s = 1,#pool.species do
  629. local species = pool.species[s]
  630. for g = 1,#species.genomes do
  631. table.insert(global, species.genomes[g])
  632. end
  633. end
  634. table.sort(global, function (a,b)
  635. return (a.fitness < b.fitness)
  636. end)
  637.  
  638. for g=1,#global do
  639. global[g].globalRank = g
  640. end
  641. end
  642.  
  643. function calculateAverageFitness(species)
  644. local total = 0
  645.  
  646. for g=1,#species.genomes do
  647. local genome = species.genomes[g]
  648. total = total + genome.globalRank
  649. end
  650.  
  651. species.averageFitness = total / #species.genomes
  652. end
  653.  
  654. function totalAverageFitness()
  655. local total = 0
  656. for s = 1,#pool.species do
  657. local species = pool.species[s]
  658. total = total + species.averageFitness
  659. end
  660.  
  661. return total
  662. end
  663.  
  664. function cullSpecies(cutToOne)
  665. for s = 1,#pool.species do
  666. local species = pool.species[s]
  667.  
  668. table.sort(species.genomes, function (a,b)
  669. return (a.fitness > b.fitness)
  670. end)
  671.  
  672. local remaining = math.ceil(#species.genomes/2)
  673. if cutToOne then
  674. remaining = 1
  675. end
  676. while #species.genomes > remaining do
  677. table.remove(species.genomes)
  678. end
  679. end
  680. end
  681.  
  682. function breedChild(species)
  683. local child = {}
  684. if math.random() < CrossoverChance then
  685. g1 = species.genomes[math.random(1, #species.genomes)]
  686. g2 = species.genomes[math.random(1, #species.genomes)]
  687. child = crossover(g1, g2)
  688. else
  689. g = species.genomes[math.random(1, #species.genomes)]
  690. child = copyGenome(g)
  691. end
  692.  
  693. mutate(child)
  694.  
  695. return child
  696. end
  697.  
  698. function secondbestspecies() --added this to help prevent the formation of monospecies when MarI/O encounters a major bottleneck
  699. local bestfitness = 0
  700. local secondbestfitness = 0
  701. for s = 1,#pool.species do
  702. local species = pool.species[s]
  703.  
  704. if species.topFitness > bestfitness then
  705. secondbestfitness = bestfitness
  706. bestfitness = species.topFitness
  707. elseif species.topFitness > secondbestfitness then
  708. secondbestfitness = species.topFitness
  709. end
  710. end
  711. return secondbestfitness
  712. end
  713.  
  714.  
  715. function removeStaleSpecies()
  716. local survived = {}
  717. local secondbest = secondbestspecies()
  718. emu.print("Max Fitness: ".. pool.maxFitness ..". second best: ".. secondbest)
  719. for s = 1,#pool.species do
  720. local species = pool.species[s]
  721.  
  722. table.sort(species.genomes, function (a,b)
  723. return (a.fitness > b.fitness)
  724. end)
  725.  
  726. if species.genomes[1].fitness > species.topFitness then
  727. species.topFitness = species.genomes[1].fitness
  728. species.staleness = 0
  729. else
  730. species.staleness = species.staleness + 1
  731. end
  732. if species.staleness < StaleSpecies or species.topFitness >= secondbest then --originally species.topFitness >= pool.maxFitness
  733. table.insert(survived, species)
  734. end
  735. end
  736.  
  737. pool.species = survived
  738. end
  739.  
  740. function removeWeakSpecies()
  741. local survived = {}
  742.  
  743. local sum = totalAverageFitness()
  744. for s = 1,#pool.species do
  745. local species = pool.species[s]
  746. breed = math.floor(species.averageFitness / sum * Population)
  747. if breed >= 1 then
  748. table.insert(survived, species)
  749. end
  750. end
  751.  
  752. pool.species = survived
  753. end
  754.  
  755.  
  756. function addToSpecies(child)
  757. local foundSpecies = false
  758. for s=1,#pool.species do
  759. local species = pool.species[s]
  760. if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  761. table.insert(species.genomes, child)
  762. foundSpecies = true
  763. end
  764. end
  765.  
  766. if not foundSpecies then
  767. local childSpecies = newSpecies()
  768. table.insert(childSpecies.genomes, child)
  769. table.insert(pool.species, childSpecies)
  770. end
  771. end
  772.  
  773. function newGeneration()
  774.  
  775. cullSpecies(false) -- Cull the bottom half of each species
  776. rankGlobally()
  777. removeStaleSpecies()
  778. rankGlobally()
  779. for s = 1,#pool.species do
  780. local species = pool.species[s]
  781. calculateAverageFitness(species)
  782. end
  783. removeWeakSpecies()
  784. local sum = totalAverageFitness()
  785. local children = {}
  786. for s = 1,#pool.species do
  787. local species = pool.species[s]
  788. breed = math.floor(species.averageFitness / sum * Population) - 1
  789. for i=1,breed do
  790. table.insert(children, breedChild(species))
  791. end
  792. end
  793. cullSpecies(true) -- Cull all but the top member of each species
  794. while #children + #pool.species < Population do
  795. local species = pool.species[math.random(1, #pool.species)]
  796. table.insert(children, breedChild(species))
  797. end
  798. for c=1,#children do
  799. local child = children[c]
  800. addToSpecies(child)
  801. end
  802.  
  803. pool.generation = pool.generation + 1
  804. pool.maxcounter = 0 --reset max fitness counter
  805.  
  806. writeFile("backups/backup." .. pool.generation .. "." .. SAVE_LOAD_FILE)
  807. writelatestgen() --tracker for the latest generation
  808. end
  809.  
  810. function initializePool()
  811. pool = newPool()
  812.  
  813. for i=1,Population do
  814. basic = basicGenome()
  815. addToSpecies(basic)
  816. end
  817.  
  818. initializeRun()
  819. end
  820.  
  821. function clearJoypad()
  822. controller = {}
  823. for b = 1,#ButtonNames do
  824. controller[ButtonNames[b]] = false
  825. end
  826. joypad.set(1, controller)
  827. end
  828.  
  829. function initializeRun()
  830.  
  831. savestate.load(SavestateObj);
  832. rightmost = 0
  833. pool.currentFrame = 0
  834. timeout = TimeoutConstant
  835. clearJoypad()
  836. currentTracker = memory.readbyte(0x071A) --sets the current tracker to the current screen
  837.  
  838. local species = pool.species[pool.currentSpecies]
  839. local genome = species.genomes[pool.currentGenome]
  840. generateNetwork(genome)
  841. evaluateCurrent()
  842. end
  843.  
  844. function evaluateCurrent()
  845. local species = pool.species[pool.currentSpecies]
  846. local genome = species.genomes[pool.currentGenome]
  847.  
  848. inputs = getInputs()
  849. controller = evaluateNetwork(genome.network, inputs)
  850.  
  851. if controller["left"] and controller["right"] then
  852. controller["left"] = false
  853. controller["right"] = false
  854. end
  855. if controller["up"] and controller["down"] then
  856. controller["up"] = false
  857. controller["down"] = false
  858. end
  859.  
  860. joypad.set(1, controller)
  861. end
  862.  
  863. if pool == nil then
  864. initializePool()
  865. end
  866.  
  867.  
  868. function nextGenome()
  869. pool.currentGenome = pool.currentGenome + 1
  870. if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
  871. pool.currentGenome = 1
  872. pool.currentSpecies = pool.currentSpecies+1
  873. if pool.currentSpecies > #pool.species then
  874. newGeneration()
  875. pool.currentSpecies = 1
  876. end
  877. end
  878. end
  879.  
  880. function fitnessAlreadyMeasured()
  881. local species = pool.species[pool.currentSpecies]
  882. local genome = species.genomes[pool.currentGenome]
  883.  
  884. return genome.fitness ~= 0
  885. end
  886.  
  887. function displayGenome(genome)
  888. gui.opacity(0.53)
  889. if not genome then return end
  890. local network = genome.network
  891. local cells = {}
  892. local i = 1
  893. local cell = {}
  894. for dy=-BoxRadius,BoxRadius do
  895. for dx=-BoxRadius,BoxRadius do
  896. cell = {}
  897. cell.x = 50+5*dx
  898. cell.y = 70+5*dy
  899. cell.value = network.neurons[i].value
  900. cells[i] = cell
  901. i = i + 1
  902. end
  903. end
  904. local biasCell = {}
  905. biasCell.x = 80
  906. biasCell.y = 110
  907. biasCell.value = network.neurons[Inputs].value
  908. cells[Inputs] = biasCell
  909.  
  910. gui.drawbox(215,32,245,82,toRGBA(0x80808080),toRGBA(0x00000000))
  911.  
  912. gui.opacity(0.85)
  913. for o = 1,Outputs do
  914. cell = {}
  915. cell.x = 220
  916. cell.y = 30 + 8 * o
  917. cell.value = network.neurons[MaxNodes + o].value
  918. cells[MaxNodes+o] = cell
  919. local color
  920. if cell.value > 0 then
  921. color = 0xFF0000FF
  922. else
  923. color = 0xFF777777-- original color = 0xFF000000
  924. end
  925. gui.drawtext(225, 26+8*o, ButtonNames[o], toRGBA(color), 0x0) --moved slightly for better alignment
  926. end
  927.  
  928. for n,neuron in pairs(network.neurons) do
  929. cell = {}
  930. if n > Inputs and n <= MaxNodes then
  931. cell.x = 140
  932. cell.y = 40
  933. cell.value = neuron.value
  934. cells[n] = cell
  935. end
  936. end
  937.  
  938. for n=1,4 do
  939. for _,gene in pairs(genome.genes) do
  940. if gene.enabled then
  941. local c1 = cells[gene.into]
  942. local c2 = cells[gene.out]
  943. if gene.into > Inputs and gene.into <= MaxNodes then
  944. c1.x = 0.75*c1.x + 0.25*c2.x
  945. if c1.x >= c2.x then
  946. c1.x = c1.x - 40
  947. end
  948. if c1.x < 90 then
  949. c1.x = 90
  950. end
  951.  
  952. if c1.x > 220 then
  953. c1.x = 220
  954. end
  955. c1.y = 0.75*c1.y + 0.25*c2.y
  956.  
  957. end
  958. if gene.out > Inputs and gene.out <= MaxNodes then
  959. c2.x = 0.25*c1.x + 0.75*c2.x
  960. if c1.x >= c2.x then
  961. c2.x = c2.x + 40
  962. end
  963. if c2.x < 90 then
  964. c2.x = 90
  965. end
  966. if c2.x > 220 then
  967. c2.x = 220
  968. end
  969. c2.y = 0.25*c1.y + 0.75*c2.y
  970. end
  971. end
  972. end
  973. end
  974.  
  975. gui.drawbox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2, toRGBA(0x80808080), toRGBA(0xFF000000))
  976. for n,cell in pairs(cells) do
  977. if n > Inputs or cell.value ~= 0 then
  978. local color = math.floor((cell.value+1)/2*256)
  979. if color > 255 then color = 255 end
  980. if color < 0 then color = 0 end
  981. local opacity = 0xFF000000
  982. if cell.value == 0 then
  983. opacity = 0x50000000
  984. end
  985. color = opacity + color*0x10000 + color*0x100 + color
  986. gui.drawbox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,toRGBA(color), toRGBA(opacity))
  987. end
  988. end
  989. for _,gene in pairs(genome.genes) do
  990. if gene.enabled then
  991. local c1 = cells[gene.into]
  992. local c2 = cells[gene.out]
  993. local opacity = 0xF0000000
  994. if pool.generation > 50 then
  995. opacity = 0x80000000
  996. end
  997. if c1.value == 0 then
  998. if pool.generation <10 then --slowly reducing the visibility of non active nodes
  999. opacity = 0x80000000
  1000. elseif pool.generation < 15 then
  1001. opacity = 0x60000000
  1002. elseif pool.generation < 25 then
  1003. opacity = 0x40000000
  1004. end
  1005. end
  1006. local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
  1007. if gene.weight > 0 then
  1008. color = opacity + 0x8000 + 0x10000*color
  1009. else
  1010. color = opacity + 0x800000 + 0x100*color
  1011. end
  1012. if c1.value ~= 0 or pool.generation < 50 then -- remove non active nodes from generation 50 onwards
  1013. gui.drawline(c1.x+1, c1.y, c2.x-3, c2.y, toRGBA(color))
  1014. end
  1015. end
  1016. end
  1017.  
  1018. gui.drawbox(49,71,51,78,toRGBA(0x80FF0000),toRGBA(0x00000000))
  1019. if mutation_rates_disp then
  1020. local pos = 120
  1021. for mutation,rate in pairs(genome.mutationRates) do
  1022. gui.drawtext(16, pos, mutation .. ": " .. rate, toRGBA(0xFF000000), 0x0)
  1023. pos = pos + 8
  1024. end
  1025. end
  1026. end
  1027.  
  1028. function writeFile(filename)
  1029. local file = io.open(filename, "w")
  1030. file:write(pool.generation .. "\n")
  1031. file:write(pool.maxFitness .. "\n")
  1032. file:write(#pool.species .. "\n")
  1033. for n,species in pairs(pool.species) do
  1034. file:write(species.topFitness .. "\n")
  1035. file:write(species.staleness .. "\n")
  1036. file:write(#species.genomes .. "\n")
  1037. for m,genome in pairs(species.genomes) do
  1038. file:write(genome.fitness .. "\n")
  1039. file:write(genome.maxneuron .. "\n")
  1040. for mutation,rate in pairs(genome.mutationRates) do
  1041. file:write(mutation .. "\n")
  1042. file:write(rate .. "\n")
  1043. end
  1044. file:write("done\n")
  1045.  
  1046. file:write(#genome.genes .. "\n")
  1047. for l,gene in pairs(genome.genes) do
  1048. file:write(gene.into .. " ")
  1049. file:write(gene.out .. " ")
  1050. file:write(gene.weight .. " ")
  1051. file:write(gene.innovation .. " ")
  1052. if(gene.enabled) then
  1053. file:write("1\n")
  1054. else
  1055. file:write("0\n")
  1056. end
  1057. end
  1058. end
  1059. end
  1060. file:close()
  1061. end
  1062.  
  1063. function savePool() --used to save when a new max fitness is reached
  1064. local filename = SAVE_LOAD_FILE
  1065. writeFile(filename)
  1066. emu.print("saved pool due to new max fitness") --suggestion by DeltaLeeds
  1067. end
  1068.  
  -- Replace the global `pool` with the pool stored in `filename` (the format
  -- written by writeFile), then advance to the first genome whose fitness has
  -- not been measured yet and start a run with it.
  -- Raises an error if the file cannot be opened, or if the file ends before
  -- a genome's mutation-rate list reaches its "done" terminator (without this
  -- check the read loop would spin forever on nil).
  function loadFile(filename)
    local file = assert(io.open(filename, "r"))

    -- Read one line and drop a trailing "\r", so pools saved with Windows
    -- line endings still match "done" and yield clean mutation names.
    local function readLine()
      local text = file:read("*line")
      if text then
        text = text:gsub("\r$", "")
      end
      return text
    end

    pool = newPool()
    pool.generation = file:read("*number")
    pool.maxFitness = file:read("*number")
    local numSpecies = file:read("*number")
    for s = 1, numSpecies do
      local species = newSpecies()
      table.insert(pool.species, species)
      species.topFitness = file:read("*number")
      species.staleness = file:read("*number")
      local numGenomes = file:read("*number")
      for g = 1, numGenomes do
        local genome = newGenome()
        table.insert(species.genomes, genome)
        genome.fitness = file:read("*number")
        genome.maxneuron = file:read("*number")
        readLine() -- discard the (empty) remainder of the maxneuron line

        -- Mutation rates are "<name>" / "<rate>" line pairs ended by "done".
        -- Reading whole lines avoids trying to parse a name as a number.
        local mutation = readLine()
        while mutation ~= "done" do
          if mutation == nil then
            file:close()
            error("loadFile: '" .. filename .. "' ended before the mutation rates of species "
              .. s .. ", genome " .. g .. " were terminated by 'done'")
          end
          genome.mutationRates[mutation] = tonumber(readLine())
          mutation = readLine()
        end

        local numGenes = file:read("*number")
        for _ = 1, numGenes do
          local gene = newGene()
          table.insert(genome.genes, gene)
          local enabled
          gene.into, gene.out, gene.weight, gene.innovation, enabled =
            file:read("*number", "*number", "*number", "*number", "*number")
          -- Only an explicit 0 disables a gene.
          gene.enabled = (enabled ~= 0)
        end
      end
    end
    file:close()

    while fitnessAlreadyMeasured() do
      nextGenome()
    end
    initializeRun()
    pool.currentFrame = pool.currentFrame + 1
  end
  1116.  
  1117.  
  -- Jump to the genome with the highest recorded fitness in the pool and
  -- restart the run with it; pool.maxFitness is set to that fitness.
  -- Only genomes with a fitness above 0 are candidates. If there is none, the
  -- pool is left untouched: assigning nil to currentSpecies/currentGenome
  -- would make the main loop index a nil species on the next frame.
  function playTop()
    local bestFitness = 0
    local bestSpecies, bestGenome
    for s, species in ipairs(pool.species) do
      for g, genome in ipairs(species.genomes) do
        if genome.fitness > bestFitness then
          bestFitness = genome.fitness
          bestSpecies, bestGenome = s, g
        end
      end
    end

    if bestSpecies == nil then
      print("playTop: no genome with a positive fitness to show yet")
      return
    end

    pool.currentSpecies = bestSpecies
    pool.currentGenome = bestGenome
    pool.maxFitness = bestFitness
    initializeRun()
    pool.currentFrame = pool.currentFrame + 1
  end
  1138.  
  -- Record the current generation number (pool.generation) in the tracker
  -- file that readlatestgen uses to find the newest backup.
  -- path: optional tracker location, defaults to 'backups/latestgen'.
  -- Raises an error naming the file if it cannot be opened (for example when
  -- the backups directory does not exist) instead of indexing nil.
  function writelatestgen(path)
    local file = assert(io.open(path or 'backups/latestgen', "w"))
    file:write(pool.generation)
    file:close()
  end
  1144.  
  -- Resolve the file name of the newest backup pool. The tracker file holds
  -- a generation number; the result is
  -- 'backups/backup.<generation>.' .. SAVE_LOAD_FILE, which is also printed.
  -- path: optional tracker location, defaults to 'backups/latestgen'.
  -- Raises a descriptive error when the tracker is missing or empty instead
  -- of indexing or concatenating nil.
  function readlatestgen(path)
    path = path or 'backups/latestgen'
    local file = assert(io.open(path, "r"))
    local latestgen = file:read("*line")
    file:close()
    -- Trim surrounding whitespace (e.g. a stray "\r") before building the name.
    latestgen = latestgen and latestgen:match("^%s*(.-)%s*$")
    if latestgen == nil or latestgen == "" then
      error("readlatestgen: '" .. path .. "' does not contain a generation number")
    end
    local name = 'backups/backup.' .. latestgen .. '.' .. SAVE_LOAD_FILE
    print('loaded: ' .. name)
    return name
  end
  1153.  
  -- Poll the keyboard once per frame and handle the script's hotkeys:
  --   N  toggle the neural-network display (neural_net_disp)
  --   M  toggle the mutation-rate display (mutation_rates_disp)
  --   L  load the most recently saved generation (readlatestgen + loadFile)
  --   T  replay the genome with the highest fitness (playTop)
  -- Unless `nopopup` is set, L and T ask for confirmation with input.popup.
  -- The global `keyflag` latches after any hotkey fires and is cleared only
  -- once every hotkey is released, so holding a key triggers it exactly once.
  function keyboardinput()
    local keyboard = input.get()

    if keyboard['N'] and keyflag == false then
      neural_net_disp = not neural_net_disp
      keyflag = true
    end

    if keyboard['M'] and keyflag == false then
      mutation_rates_disp = not mutation_rates_disp
      keyflag = true
    end

    if keyboard['L'] and keyflag == false then
      -- Declared outside the `if` so the popup answer is still in scope when
      -- it is checked; as a block-local it was always read back as nil.
      local loadyn
      if not nopopup then
        loadyn = input.popup('Are you sure you want to load the last saved generation?')
      end
      if loadyn == 'yes' or nopopup then
        local name = readlatestgen()
        loadFile(name)
      end
      keyflag = true
    end

    if keyboard['T'] and keyflag == false then
      local loadyn
      if not nopopup then
        loadyn = input.popup('Are you sure you want to show the topFitness genome?')
      end
      if loadyn == 'yes' or nopopup then
        playTop()
      end
      -- Latch like the other hotkeys so holding T does not restart the top
      -- genome (and re-open the popup) on every frame.
      keyflag = true
    end

    if not (keyboard['N'] or keyboard['L'] or keyboard['M'] or keyboard['T']) and keyflag == true then
      keyflag = false
    end
  end
  1194.  
  -- Start-up: when a pool file is configured in `savedpool` (nil by default),
  -- resume from it; then dump the current pool to temp.pool.
  if savedpool ~= nil and savedpool ~= false then
    loadFile(savedpool)
  end
  writeFile("temp.pool")
  1200.  
  1201.  
  1202.  
  -- Main loop: runs once per emulated frame (each pass ends with
  -- emu.frameadvance()). It polls the hotkeys, draws the HUD, sends the
  -- controller state to pad 1, and tracks Mario's progress, with special
  -- handling for pipes, holes and looping stages. When the current run times
  -- out, it records the genome's fitness and moves on to the next genome
  -- that has not been measured yet.
  1203. while true do
  1204. keyboardinput()
  1205. local backgroundColor = toRGBA(0xD0FFFFFF)
  1206. gui.drawbox(0, 0, 200, 30, backgroundColor, backgroundColor)
  1207.  
  -- writeFile does not store maxcounter, so a pool restored from disk may lack
  -- it (unless newPool provides it, which is not visible here); default to 0.
  1208. if pool.maxcounter == nil then
  1209. pool.maxcounter = 0
  1210. emu.print("caught missing variable")
  1211. end
  1212.  
  1213. local species = pool.species[pool.currentSpecies]
  1214. local genome = species.genomes[pool.currentGenome]
  1215.  
  1216. if neural_net_disp then
  1217. displayGenome(genome)
  1218. end
  -- The network is evaluated only on every 5th frame, but `controller` is
  -- sent to pad 1 on every frame (presumably evaluateCurrent updates
  -- `controller`; it is defined elsewhere, so confirm).
  1219. if pool.currentFrame%5 == 0 then
  1220. evaluateCurrent()
  1221. end
  1222.  
  1223. joypad.set(1, controller)
  1224.  
  -- Presumably refreshes marioX/marioY, mariostate, currentscreen/nextscreen
  -- and CurrentWorld/CurrentLevel from RAM (defined elsewhere; confirm).
  1225. getPositions()
  1226.  
  1227. if mariostate == 2 or mariostate == 3 then --detects when mario enters a pipe
  1228. marioPipeEnter = true
  1229. end
  -- Loop detection runs only for these world/level combinations. It follows
  -- the screen tracker; a jump backwards that no pipe explains sets `loop`,
  -- which blocks rightmost from advancing further down.
  -- NOTE(review): `and` binds tighter than `or`, so the first condition below
  -- parses as `currentTracker == currentscreen or (currentTracker == nextscreen
  -- and not mariopipe)`; confirm that grouping is intended.
  1230. if (CurrentWorld == 3 and CurrentLevel == 4) or (CurrentWorld == 6 and CurrentLevel == 4) or (CurrentWorld == 7 and CurrentLevel == 3) then
  1231. if currentTracker == currentscreen or currentTracker == nextscreen and not mariopipe then --follows the progress of mario
  1232. currentTracker = nextscreen
  1233. elseif currentTracker > nextscreen and not mariopipe and not marioPipeEnter and mariostate ~= 7 then --if mario suddenly goes back trigger loop detection
  1234. loop = true
  1235. elseif currentscreen == 0 and nextscreen == 0 and mariopipe and mariostate ~= 7 then -- if mario enters a piperoom negate loop detection
  1236. loop = false
  1237. currentTracker = nextscreen
  1238. elseif nextscreen >= (currentTracker - 1) and mariopipe and mariostate ~= 7 then -- if mario goes forward in the level by entering a pipe negate loop detection
  1239. loop = false
  1240. currentTracker = nextscreen
  1241. pipeexit = true
  1242. else --if nothing is true then we are in a loop
  1243. if not pipeexit then --prevents the extra one frame flicker of the tracker
  1244. loop = true
  1245. else
  1246. pipeexit = false
  1247. end
  1248. end
  1249. end
  1250. if marioPipeEnter == true and mariostate == 7 then --set mariopipe when exiting the pipe
  1251. mariopipe = true
  1252. end
  -- After a pipe exit, `offset` shifts marioX so that progress continues
  -- from the rightmost value reached before the pipe (or 0 if Mario came out
  -- at or beyond it).
  1253. if mariopipe == true and mariostate ~= 7 and rightmost > marioX and not mariohole then --calculate offset
  1254. offset = rightmost - marioX
  1255. mariopipe = false
  1256. marioPipeEnter = false
  1257. elseif mariopipe == true and mariostate ~= 7 and rightmost <= marioX and not mariohole then --if the exit coordinates are higher than rightmost after exiting a pipe set offset to 0
  1258. offset = 0
  1259. mariopipe = false
  1260. marioPipeEnter = false
  1261. end
  1262.  
  -- NOTE(review): 0xB5 appears to be used as the high byte of the vertical
  -- position; a page multiplier would normally be 256, not 255. Confirm.
  1263. if marioY + memory.readbyte(0xB5)*255 > 512 then --added the fallen into a hole variable that will activate when mario goes below screen
  1264. mariohole = true
  1265. end
  -- New rightmost progress (not in a hole, not looping, not entering a pipe)
  -- resets the idle timeout. If more than 20 frames passed since the last
  -- progress, 13/20 of that idle time is credited back as timebonus.
  1266. if marioX + offset > rightmost and mariohole == false and loop == false and marioPipeEnter == false then
  1267. rightmost = marioX + offset
  1268. if TimeoutConstant - timeout > 20 then --this is required for 4-4 so walking left will not give a big penalty
  1269. timebonus = timebonus + (TimeoutConstant - timeout) * 13/20
  1270. end
  1271. timeout = TimeoutConstant
  1272. end
  -- While entering a pipe, the +1 cancels the per-frame decrement below and
  -- timebonus grows by 2/3 per frame.
  1273. if marioPipeEnter == true then --freeze fitness and timer when mario enters a pipe
  1274. timeout = timeout + 1
  1275. timebonus = timebonus + 2/3
  1276. end
  1277.  
  1278. timeout = timeout - 1
  1279.  
  1280.  
  -- The run ends once the idle timeout, extended by currentFrame / 4, is used up.
  1281. local timeoutBonus = pool.currentFrame / 4
  1282. if timeout + timeoutBonus <= 0 then
  -- Fitness = progress - half a point per frame + 20 + timebonus. A result of
  -- exactly 0 is remapped to -1 because a fitness of 0 is what marks a genome
  -- as not yet measured (see the `genome.fitness ~= 0` count further down).
  1283. local fitness = rightmost - pool.currentFrame / 2 +20 + timebonus
  1284. if fitness == 0 then
  1285. fitness = -1
  1286. end
  1287.  
  1288. genome.fitness = fitness
  1289.  
  1290.  
  -- A new best saves the pool. maxcounter counts genomes that finish within
  -- 30 of the best; the banner shows it as a percentage of Population.
  1291. if fitness > pool.maxFitness then
  1292. pool.maxFitness = fitness
  1293. savePool()
  1294. end
  1295.  
  1296. if fitness > pool.maxFitness - 30 then
  1297. pool.maxcounter = pool.maxcounter + 1
  1298. end
  1299.  
  -- Restart the search at species 1 / genome 1 and call nextGenome until an
  -- unmeasured genome is reached, resetting the per-run tracking globals on
  -- the way. initializeRun then starts the next run.
  1300. emu.print("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  1301. pool.currentSpecies = 1
  1302. pool.currentGenome = 1
  1303. while fitnessAlreadyMeasured() do
  1304. offset = 0 --offset in x coordinates for pipe
  1305. timebonus = 0 --used to freeze fitness
  1306. mariohole = false --reset the fallen into hole variable
  1307. mariopipe = false --is mario exiting a pipe or just teleporting
  1308. loop = false --is this a loop?
  1309. marioPipeEnter = false --has mario entered a pipe
  1310. currentTracker = 1 --tracker for the loop
  1311. nextGenome()
  1312. end
  1313. initializeRun()
  1314. end
  1315.  
  -- HUD banner. The percentage is the share of genomes in the pool with a
  -- non-zero (i.e. measured) fitness; the banner also shows a live fitness
  -- estimate and the max fitness.
  1316. local measured = 0
  1317. local total = 0
  1318. for _,species in pairs(pool.species) do
  1319. for _,genome in pairs(species.genomes) do
  1320. total = total + 1
  1321. if genome.fitness ~= 0 then
  1322. measured = measured + 1
  1323. end
  1324. end
  1325. end
  1326. gui.opacity(1) --made the top banner opaque. you can change this if you want a slightly transparent banner
  1327. gui.drawtext(5, 11, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", toRGBA(0xFF000000), 0x0)
  1328. gui.drawtext(5, 20, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 + - (timeout + timeoutBonus)*2/3 +20 + timebonus), toRGBA(0xFF000000), 0x0) --the (timeout + timeoutBonus)*2/3 makes sure that the displayed fitness remains stable for the viewers pleasure
  1329. gui.drawtext(80, 20, "Max Fitness:" .. math.floor(pool.maxFitness).. " ".. "(" .. math.ceil(pool.maxcounter/Population*100) .. "%)", toRGBA(0xFF000000), 0x0)
  1330.  
  1331. pool.currentFrame = pool.currentFrame + 1
  1332.  
  1333. emu.frameadvance();
  1334. end
Add Comment
Please, Sign In to add comment