Advertisement
Code-Akisame

MarI/O pastebin for FCEUX

Mar 21st, 2018
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 35.37 KB | None | 0 0
  1. -- MarI/O by SethBling with some very minor tweaks by Akisame
  2. -- The tweaks by Akisame and Anonymous to fix the pit bug are in line 48, 1205-1208 and 1248
  3. -- Further tweaks by Akisame can be found in line 47-52, 64, 82, 221, 821, 1179, 1182-1184, 1201-1217, 1224, 1244-1246, 1252-1255, 1272-1275
  4. -- The loading bug fix by Anonymous can be found in line 1096-1108
  5. -- These tweaks change the font size of the banner and add a counter that tracks how many genomes reached maxFitness - 30
  6. -- They also fix the pipe fitness bug (fitness freezing when entering a pipe) as well as the fitness drop bug (hard fitness drop after pipe).
  7. -- Feel free to use this code, but please do not redistribute it.
  8. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  9. -- For SMW, make sure you have a save state named "DP1.state" at the beginning of a level,
  10. -- and put a copy in both the Lua folder and the root directory of BizHawk.
  11.  
  12.  
  -- Pool file to resume from; leave nil to start with a fresh pool.
  LOAD_FROM_FILE = nil
  --LOAD_FROM_FILE = "backups/backup.5.SMB1-1.state.pool"

  -- HUD options. Use "false" to hide elements (might improve performance).
  SHOW_NETWORK = true
  SHOW_MUTATION_RATES = true

  -- Load the level from FCEUX savestate slot
  SAVESTATE_SLOT = 1

  SAVE_LOAD_FILE = "SMB1-1.state.pool"
  RECORD_LOAD_FILE = "SMB1-1.rec"

  -- newGeneration() writes its per-generation backups into this directory.
  os.execute("mkdir backups")

  SavestateObj = savestate.object(SAVESTATE_SLOT)

  -- One network output is created per entry (see Outputs below).
  ButtonNames = { "A", "B", "up", "down", "left", "right" }

  -- The input grid spans BoxRadius cells in every direction around Mario.
  BoxRadius = 6
  InputSize = (2*BoxRadius + 1) * (2*BoxRadius + 1)

  -- One extra input on top of the grid (evaluateNetwork appends a constant 1).
  Inputs = InputSize + 1
  Outputs = #ButtonNames

  -- Population size and speciation coefficients.
  Population = 300
  DeltaDisjoint = 2.0
  DeltaWeights = 0.4
  DeltaThreshold = 1.0

  StaleSpecies = 25 --increased slightly to allow for longer evolution
  mariohole = false --flag for mario falling into a hole
  mariopipe = false --flag for mario exiting a pipe
  offset = 0 --offset value for when mario exits a pipe
  timebonus = 0 --timebonus to prevent the hard drops in fitness when exiting a pipe
  maxcounter = 0

  -- Initial mutation rates; every genome carries its own drifting copy.
  MutateConnectionsChance = 0.25
  PerturbChance = 0.90
  CrossoverChance = 0.75
  LinkMutationChance = 2.0
  NodeMutationChance = 0.50
  BiasMutationChance = 0.40
  StepSize = 0.1
  DisableMutationChance = 0.4
  EnableMutationChance = 0.2

  TimeoutConstant = 90 --set to a higher value but see for yourself what is acceptable

  -- Output neurons are stored at indices MaxNodes+1 .. MaxNodes+Outputs.
  MaxNodes = 1000000
  77.  
  -- Refreshes the global position/state variables from emulator memory.
  -- Writes marioX, marioY, mariostate, screenX and screenY; returns nothing.
  function getPositions()
    -- X is assembled from two bytes: high byte at 0x6D, low byte at 0x86.
    local highByte = memory.readbyte(0x6D)
    local lowByte = memory.readbyte(0x86)
    marioX = highByte * 0x100 + lowByte
    marioY = memory.readbyte(0x03B8) + 16
    -- Mario's state (entering/exiting pipes, dying, picking up a mushroom etc.)
    mariostate = memory.readbyte(0x0E)

    screenX = memory.readbyte(0x03AD)
    screenY = memory.readbyte(0x03B8)
  end
  86.  
  -- Returns 1 when the tile at offset (dx, dy) from Mario is non-empty, else 0.
  -- Reads the globals marioX/marioY, so getPositions() must have run first.
  -- Offsets that land outside the 13 tile rows always yield 0.
  function getTile(dx, dy)
    local worldX = marioX + dx + 8
    local worldY = marioY + dy - 16

    local row = math.floor((worldY - 32) / 16)
    if row < 0 or row >= 13 then
      return 0
    end

    local column = math.floor((worldX % 256) / 16)
    -- Two 13x16 tile pages are stored back to back starting at 0x500.
    local page = math.floor(worldX / 256) % 2
    local address = 0x500 + page*13*16 + row*16 + column

    return (memory.readbyte(address) ~= 0) and 1 or 0
  end
  106.  
  -- Collects the positions of the active entries in the five sprite slots.
  -- A slot counts as active when its byte at 0xF+slot is non-zero.
  -- Returns an array of {x=..., y=...} tables.
  function getSprites()
    local found = {}
    for slot = 0, 4 do
      if memory.readbyte(0xF + slot) ~= 0 then
        local spriteX = memory.readbyte(0x6E + slot) * 0x100 + memory.readbyte(0x87 + slot)
        local spriteY = memory.readbyte(0xCF + slot) + 24
        table.insert(found, { x = spriteX, y = spriteY })
      end
    end
    return found
  end
  120.  
  -- Builds the flat input grid for the network, scanned row by row around Mario.
  -- Each cell is 0 (empty), 1 (solid tile) or -1 (a sprite within 8 pixels);
  -- a sprite overrides a tile. Also updates the globals sprites, tile, distx
  -- and disty, exactly like the original implementation.
  function getInputs()
    getPositions()

    sprites = getSprites()

    local inputs = {}
    local reach = BoxRadius * 16

    for dy = -reach, reach, 16 do
      for dx = -reach, reach, 16 do
        local value = 0

        tile = getTile(dx, dy)
        if tile == 1 and marioY + dy < 0x1B0 then
          value = 1
        end

        for _, sprite in ipairs(sprites) do
          distx = math.abs(sprite["x"] - (marioX + dx))
          disty = math.abs(sprite["y"] - (marioY + dy + 16))
          if distx <= 8 and disty <= 8 then
            value = -1
          end
        end

        inputs[#inputs + 1] = value
      end
    end

    --mariovx = memory.read_s8(0x7B)
    --mariovy = memory.read_s8(0x7D)

    return inputs
  end
  152.  
  -- Steepened sigmoid rescaled to the open interval (-1, 1); sigmoid(0) == 0.
  function sigmoid(x)
    local denominator = 1 + math.exp(-4.9 * x)
    return 2 / denominator - 1
  end
  156.  
  -- Advances the pool-wide innovation counter and returns the new value.
  function newInnovation()
    local nextId = pool.innovation + 1
    pool.innovation = nextId
    return nextId
  end
  161.  
  -- Creates an empty pool record. It does not touch the global `pool`;
  -- the caller decides where to store the result.
  function newPool()
    return {
      species = {},
      generation = 0,
      innovation = Outputs,  -- innovation numbering starts after the outputs
      currentSpecies = 1,
      currentGenome = 1,
      currentFrame = 0,
      maxFitness = 0,
      maxcounter = 0,        -- counter for %max species reached
    }
  end
  175.  
  -- Creates an empty species record with zeroed fitness statistics.
  function newSpecies()
    return {
      topFitness = 0,
      staleness = 0,
      genomes = {},
      averageFitness = 0,
    }
  end
  185.  
  -- Returns the genome selected by pool.currentSpecies / pool.currentGenome.
  function getCurrentGenome()
    return pool.species[pool.currentSpecies].genomes[pool.currentGenome]
  end
  190.  
  -- Converts a 32-bit ARGB colour into RGBA byte order by shifting the colour
  -- left by one byte and adding the old top byte (alpha) at the bottom.
  function toRGBA(ARGB)
    local shiftedColour = bit.lshift(ARGB, 8)
    local alpha = bit.rshift(ARGB, 24)
    return shiftedColour + alpha
  end
  194.  
  195.  
  -- Creates an empty genome whose mutation rates start at the global defaults.
  function newGenome()
    return {
      genes = {},
      fitness = 0,
      adjustedFitness = 0,
      network = {},
      maxneuron = 0,
      globalRank = 0,
      mutationRates = {
        connections = MutateConnectionsChance,
        link = LinkMutationChance,
        bias = BiasMutationChance,
        node = NodeMutationChance,
        enable = EnableMutationChance,
        disable = DisableMutationChance,
        step = StepSize,
      },
    }
  end
  215.  
  -- Returns a copy of `genome`: genes are duplicated through copyGene, and
  -- maxneuron plus every mutation rate are carried over. Fitness, rank and the
  -- network are left at the fresh values supplied by newGenome().
  --
  -- All rates are copied in one loop. The original listed six of the seven
  -- rates by hand and left out "step", so a cloned genome silently fell back
  -- to the default StepSize while crossover() inherited it correctly.
  function copyGenome(genome)
    local genome2 = newGenome()
    for g = 1, #genome.genes do
      table.insert(genome2.genes, copyGene(genome.genes[g]))
    end
    genome2.maxneuron = genome.maxneuron
    for mutation, rate in pairs(genome.mutationRates) do
      genome2.mutationRates[mutation] = rate
    end

    return genome2
  end
  231.  
  -- Creates a starting genome: no genes, maxneuron set to the number of
  -- inputs, then one round of mutation applied.
  function basicGenome()
    local starter = newGenome()
    starter.maxneuron = Inputs
    mutate(starter)
    return starter
  end
  241.  
  -- Creates an enabled, zero-weight gene that is not yet wired to any neuron.
  function newGene()
    return {
      into = 0,
      out = 0,
      weight = 0.0,
      enabled = true,
      innovation = 0,
    }
  end
  252.  
  -- Returns a new gene carrying the same five fields as `gene`.
  function copyGene(gene)
    local duplicate = newGene()
    for _, field in ipairs({ "into", "out", "weight", "enabled", "innovation" }) do
      duplicate[field] = gene[field]
    end
    return duplicate
  end
  263.  
  -- Creates a neuron with no incoming links and a value of 0.
  function newNeuron()
    return { incoming = {}, value = 0.0 }
  end
  271.  
  -- Builds genome.network from the genome's enabled genes.
  -- Neurons 1..Inputs and MaxNodes+1..MaxNodes+Outputs always exist; any other
  -- neuron is created when an enabled gene refers to it. Each enabled gene is
  -- attached to the `incoming` list of its `out` neuron.
  -- Side effect: genome.genes is sorted in place by `out`.
  function generateNetwork(genome)
    local neurons = {}

    for i = 1, Inputs do
      neurons[i] = newNeuron()
    end
    for o = 1, Outputs do
      neurons[MaxNodes + o] = newNeuron()
    end

    table.sort(genome.genes, function(a, b) return a.out < b.out end)

    for _, gene in ipairs(genome.genes) do
      if gene.enabled then
        local target = neurons[gene.out]
        if target == nil then
          target = newNeuron()
          neurons[gene.out] = target
        end
        table.insert(target.incoming, gene)

        if neurons[gene.into] == nil then
          neurons[gene.into] = newNeuron()
        end
      end
    end

    genome.network = { neurons = neurons }
  end
  303.  
  -- Feeds `inputs` through `network` and returns a table mapping each button
  -- name to true/false (true when the matching output neuron is positive).
  -- A constant 1 is appended to `inputs` first, so the caller's table is
  -- modified. If the resulting length is not Inputs, a message is printed and
  -- an empty table is returned.
  function evaluateNetwork(network, inputs)
    table.insert(inputs, 1)
    if #inputs ~= Inputs then
      emu.print("Incorrect number of neural network inputs.")
      return {}
    end

    local neurons = network.neurons
    for index = 1, Inputs do
      neurons[index].value = inputs[index]
    end

    -- Neurons without incoming links (the inputs) keep the value set above.
    for _, neuron in pairs(neurons) do
      local links = neuron.incoming
      if #links > 0 then
        local total = 0
        for _, link in ipairs(links) do
          total = total + link.weight * neurons[link.into].value
        end
        neuron.value = sigmoid(total)
      end
    end

    local pressed = {}
    for o = 1, Outputs do
      pressed[ButtonNames[o]] = neurons[MaxNodes + o].value > 0
    end
    return pressed
  end
  340.  
  -- Breeds a child from two parent genomes.
  -- The child's gene list follows the fitter parent. For every gene that also
  -- exists in the other parent (same innovation number) there is a 50% chance
  -- of taking the other parent's copy, provided that copy is enabled.
  -- maxneuron is the larger of the two parents'; mutation rates come from the
  -- fitter parent.
  --
  -- The parents are swapped with a multiple assignment. The original used a
  -- temporary without `local`, which leaked a global named `tempg`.
  function crossover(g1, g2)
    -- Make sure g1 is the higher fitness genome
    if g2.fitness > g1.fitness then
      g1, g2 = g2, g1
    end

    local child = newGenome()

    local innovations2 = {}
    for i = 1, #g2.genes do
      local gene = g2.genes[i]
      innovations2[gene.innovation] = gene
    end

    for i = 1, #g1.genes do
      local gene1 = g1.genes[i]
      local gene2 = innovations2[gene1.innovation]
      if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
        table.insert(child.genes, copyGene(gene2))
      else
        table.insert(child.genes, copyGene(gene1))
      end
    end

    child.maxneuron = math.max(g1.maxneuron, g2.maxneuron)

    for mutation, rate in pairs(g1.mutationRates) do
      child.mutationRates[mutation] = rate
    end

    return child
  end
  375.  
  -- Picks a random neuron id from the output neurons plus every neuron
  -- referenced by `genes`. Input neurons (ids 1..Inputs) are included only
  -- when `nonInput` is false/nil.
  function randomNeuron(genes, nonInput)
    local allowInputs = not nonInput
    local candidates = {}

    if allowInputs then
      for i = 1, Inputs do
        candidates[i] = true
      end
    end
    for o = 1, Outputs do
      candidates[MaxNodes + o] = true
    end
    for i = 1, #genes do
      local gene = genes[i]
      if allowInputs or gene.into > Inputs then
        candidates[gene.into] = true
      end
      if allowInputs or gene.out > Inputs then
        candidates[gene.out] = true
      end
    end

    -- List the ids in pairs() order, then draw one position from that list.
    local ids = {}
    for id in pairs(candidates) do
      ids[#ids + 1] = id
    end

    return ids[math.random(1, #ids)]
  end
  410.  
  -- Returns true when `genes` already holds a gene with the same into/out pair
  -- as `link`; returns nothing (nil) otherwise.
  function containsLink(genes, link)
    for _, existing in ipairs(genes) do
      local sameEndpoints = existing.into == link.into and existing.out == link.out
      if sameEndpoints then
        return true
      end
    end
  end
  419.  
  -- Changes the weight of every gene in `genome`. With probability
  -- PerturbChance the weight is shifted by a random amount in [-step, step);
  -- otherwise it is replaced by a fresh random weight in [-2, 2).
  function pointMutate(genome)
    local step = genome.mutationRates["step"]

    for _, gene in ipairs(genome.genes) do
      local perturb = math.random() < PerturbChance
      if perturb then
        gene.weight = gene.weight + math.random() * step*2 - step
      else
        gene.weight = math.random()*4-2
      end
    end
  end
  432.  
  -- Tries to add one new link gene between two randomly chosen neurons.
  -- Nothing is added when both picks are inputs or when the link already
  -- exists. With `forceBias` set, the link always starts at neuron `Inputs`.
  function linkMutate(genome, forceBias)
    local source = randomNeuron(genome.genes, false)
    local target = randomNeuron(genome.genes, true)

    local sourceIsInput = source <= Inputs
    local targetIsInput = target <= Inputs
    if sourceIsInput and targetIsInput then
      --Both input nodes
      return
    end
    if targetIsInput then
      -- Swap output and input
      source, target = target, source
    end

    local newLink = newGene()
    newLink.into = forceBias and Inputs or source
    newLink.out = target

    if containsLink(genome.genes, newLink) then
      return
    end

    newLink.innovation = newInnovation()
    newLink.weight = math.random()*4-2

    table.insert(genome.genes, newLink)
  end
  463.  
  -- Splits a randomly chosen gene by inserting a new neuron in the middle.
  -- The chosen gene is disabled and replaced by two enabled genes:
  -- into -> new neuron (weight 1.0) and new neuron -> out (original weight).
  -- Does nothing for an empty genome. When the chosen gene is already
  -- disabled, only maxneuron is advanced, matching the original behaviour.
  function nodeMutate(genome)
    local genes = genome.genes
    if #genes == 0 then
      return
    end

    genome.maxneuron = genome.maxneuron + 1

    local split = genes[math.random(1, #genes)]
    if not split.enabled then
      return
    end
    split.enabled = false

    local incoming = copyGene(split)
    incoming.out = genome.maxneuron
    incoming.weight = 1.0
    incoming.innovation = newInnovation()
    incoming.enabled = true
    table.insert(genes, incoming)

    local outgoing = copyGene(split)
    outgoing.into = genome.maxneuron
    outgoing.innovation = newInnovation()
    outgoing.enabled = true
    table.insert(genes, outgoing)
  end
  490.  
  -- Flips the enabled flag of one randomly chosen gene.
  -- With enable == true a disabled gene is picked and switched on; with
  -- enable == false an enabled gene is picked and switched off. Does nothing
  -- when no gene qualifies.
  function enableDisableMutate(genome, enable)
    local wanted = not enable
    local candidates = {}
    for _, gene in pairs(genome.genes) do
      if gene.enabled == wanted then
        candidates[#candidates + 1] = gene
      end
    end

    if #candidates == 0 then
      return
    end

    local chosen = candidates[math.random(1, #candidates)]
    chosen.enabled = not chosen.enabled
  end
  506.  
  -- Applies one round of mutation to `genome`.
  -- First every mutation rate is scaled by 0.95 or 1.05263 (coin flip each).
  -- Then point mutation runs with probability rates.connections, followed by
  -- the link, bias, node, enable and disable mutations in that order.
  function mutate(genome)
    local rates = genome.mutationRates

    for mutation, rate in pairs(rates) do
      if math.random(1, 2) == 1 then
        rates[mutation] = 0.95 * rate
      else
        rates[mutation] = 1.05263 * rate
      end
    end

    if math.random() < rates["connections"] then
      pointMutate(genome)
    end

    -- A rate above 1 gives several attempts: one draw is made against the
    -- rate, then against rate-1, and so on while the remainder is positive.
    local function attempt(rateName, mutation)
      local remaining = rates[rateName]
      while remaining > 0 do
        if math.random() < remaining then
          mutation()
        end
        remaining = remaining - 1
      end
    end

    attempt("link", function() linkMutate(genome, false) end)
    attempt("bias", function() linkMutate(genome, true) end)
    attempt("node", function() nodeMutate(genome) end)
    attempt("enable", function() enableDisableMutate(genome, true) end)
    attempt("disable", function() enableDisableMutate(genome, false) end)
  end
  560.  
  -- Returns the fraction of genes that appear in only one of the two gene
  -- lists (matched by innovation number), relative to the longer list.
  -- Two empty lists give 0. The original divided 0 by 0 there and returned
  -- NaN, which made sameSpecies() report two empty genomes as different.
  function disjoint(genes1, genes2)
    local n = math.max(#genes1, #genes2)
    if n == 0 then
      return 0
    end

    local i1 = {}
    for i = 1, #genes1 do
      i1[genes1[i].innovation] = true
    end

    local i2 = {}
    for i = 1, #genes2 do
      i2[genes2[i].innovation] = true
    end

    local disjointGenes = 0
    for i = 1, #genes1 do
      if not i2[genes1[i].innovation] then
        disjointGenes = disjointGenes + 1
      end
    end
    for i = 1, #genes2 do
      if not i1[genes2[i].innovation] then
        disjointGenes = disjointGenes + 1
      end
    end

    return disjointGenes / n
  end
  593.  
  -- Returns the mean absolute weight difference over the genes shared by both
  -- lists (matched by innovation number).
  -- With no shared genes the result is 0. The original divided 0 by 0 there
  -- and returned NaN, which made every comparison in sameSpecies() false.
  function weights(genes1, genes2)
    local i2 = {}
    for i = 1, #genes2 do
      local gene = genes2[i]
      i2[gene.innovation] = gene
    end

    local sum = 0
    local coincident = 0
    for i = 1, #genes1 do
      local gene = genes1[i]
      local gene2 = i2[gene.innovation]
      if gene2 ~= nil then
        sum = sum + math.abs(gene.weight - gene2.weight)
        coincident = coincident + 1
      end
    end

    if coincident == 0 then
      return 0
    end

    return sum / coincident
  end
  614.  
  -- Returns true when the weighted sum of the disjoint-gene term and the
  -- weight-difference term of the two genomes is below DeltaThreshold.
  function sameSpecies(genome1, genome2)
    local genes1, genes2 = genome1.genes, genome2.genes
    local disjointTerm = DeltaDisjoint * disjoint(genes1, genes2)
    local weightTerm = DeltaWeights * weights(genes1, genes2)
    return disjointTerm + weightTerm < DeltaThreshold
  end
  620.  
  -- Assigns globalRank to every genome in the pool: rank 1 goes to the lowest
  -- fitness and the highest rank to the best genome.
  function rankGlobally()
    local everyone = {}
    for _, species in ipairs(pool.species) do
      for _, genome in ipairs(species.genomes) do
        everyone[#everyone + 1] = genome
      end
    end

    table.sort(everyone, function(a, b) return a.fitness < b.fitness end)

    for rank, genome in ipairs(everyone) do
      genome.globalRank = rank
    end
  end
  637.  
  -- Stores the mean globalRank of the species' genomes in
  -- species.averageFitness (the average is over ranks, not raw fitness).
  function calculateAverageFitness(species)
    local genomes = species.genomes
    local rankSum = 0
    for _, genome in ipairs(genomes) do
      rankSum = rankSum + genome.globalRank
    end
    species.averageFitness = rankSum / #genomes
  end
  648.  
  -- Returns the sum of averageFitness over all species in the pool.
  function totalAverageFitness()
    local sum = 0
    for _, species in ipairs(pool.species) do
      sum = sum + species.averageFitness
    end
    return sum
  end
  658.  
  -- Sorts every species by descending fitness and drops the weakest genomes.
  -- With cutToOne only the best genome survives; otherwise the top half
  -- (rounded up) is kept.
  function cullSpecies(cutToOne)
    for _, species in ipairs(pool.species) do
      local genomes = species.genomes

      table.sort(genomes, function(a, b) return a.fitness > b.fitness end)

      local keep = cutToOne and 1 or math.ceil(#genomes / 2)
      for _ = #genomes, keep + 1, -1 do
        table.remove(genomes)
      end
    end
  end
  676.  
  -- Produces one mutated child from `species`.
  -- With probability CrossoverChance two random members are crossed;
  -- otherwise one random member is copied. The child is mutated before it is
  -- returned.
  --
  -- The parents are held in locals. The original assigned them without
  -- `local`, leaking the globals g1, g2 and g.
  function breedChild(species)
    local genomes = species.genomes
    local child
    if math.random() < CrossoverChance then
      local g1 = genomes[math.random(1, #genomes)]
      local g2 = genomes[math.random(1, #genomes)]
      child = crossover(g1, g2)
    else
      local g = genomes[math.random(1, #genomes)]
      child = copyGenome(g)
    end

    mutate(child)

    return child
  end
  692.  
  -- Updates each species' topFitness/staleness and drops the stale ones.
  -- A species survives while its staleness is below StaleSpecies, or whenever
  -- its topFitness is at least pool.maxFitness.
  -- Side effect: every species' genomes are sorted by descending fitness.
  function removeStaleSpecies()
    local survivors = {}

    for _, species in ipairs(pool.species) do
      table.sort(species.genomes, function(a, b) return a.fitness > b.fitness end)

      local best = species.genomes[1].fitness
      if best > species.topFitness then
        species.topFitness = best
        species.staleness = 0
      else
        species.staleness = species.staleness + 1
      end

      local stale = species.staleness >= StaleSpecies
      if not stale or species.topFitness >= pool.maxFitness then
        survivors[#survivors + 1] = species
      end
    end

    pool.species = survivors
  end
  716.  
  -- Keeps only the species whose share of the total average fitness earns at
  -- least one slot in the next generation.
  --
  -- The slot count is a local. The original assigned it without `local`,
  -- leaking a global named `breed`.
  function removeWeakSpecies()
    local survived = {}

    local sum = totalAverageFitness()
    for s = 1, #pool.species do
      local species = pool.species[s]
      local breed = math.floor(species.averageFitness / sum * Population)
      if breed >= 1 then
        table.insert(survived, species)
      end
    end

    pool.species = survived
  end
  731.  
  732.  
  -- Files `child` under the first species whose representative (its first
  -- genome) is compatible with it. When none matches, a new species holding
  -- only the child is appended to the pool.
  function addToSpecies(child)
    for _, species in ipairs(pool.species) do
      if sameSpecies(child, species.genomes[1]) then
        table.insert(species.genomes, child)
        return
      end
    end

    local created = newSpecies()
    table.insert(created.genomes, child)
    table.insert(pool.species, created)
  end
  749.  
  -- Replaces the current population with the next generation:
  --   1. cull the bottom half of every species, remove stale and weak species;
  --   2. breed children in proportion to each species' average fitness;
  --   3. cut every species down to its best genome and top the population up
  --      with children of randomly chosen species;
  --   4. file all children into species, bump the generation counter, reset
  --      pool.maxcounter and write a backup into the backups directory.
  --
  -- The per-species child count is a local. The original assigned it without
  -- `local`, leaking a global named `breed`.
  function newGeneration()
    cullSpecies(false) -- Cull the bottom half of each species
    rankGlobally()
    removeStaleSpecies()
    rankGlobally()
    for s = 1, #pool.species do
      local species = pool.species[s]
      calculateAverageFitness(species)
    end
    removeWeakSpecies()
    local sum = totalAverageFitness()
    local children = {}
    for s = 1, #pool.species do
      local species = pool.species[s]
      -- One slot is reserved for the surviving champion of the species.
      local breed = math.floor(species.averageFitness / sum * Population) - 1
      for i = 1, breed do
        table.insert(children, breedChild(species))
      end
    end
    cullSpecies(true) -- Cull all but the top member of each species
    while #children + #pool.species < Population do
      local species = pool.species[math.random(1, #pool.species)]
      table.insert(children, breedChild(species))
    end
    for c = 1, #children do
      local child = children[c]
      addToSpecies(child)
    end

    pool.generation = pool.generation + 1
    pool.maxcounter = 0 --reset max fitness counter

    writeFile("backups/backup." .. pool.generation .. "." .. SAVE_LOAD_FILE)

    --writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  end
  787.  
  -- Replaces the global pool with a fresh one holding Population basic
  -- genomes sorted into species, then starts the first run.
  --
  -- Each new genome is held in a local. The original assigned it without
  -- `local`, leaking a global named `basic`.
  function initializePool()
    pool = newPool()

    for i = 1, Population do
      local basic = basicGenome()
      addToSpecies(basic)
    end

    initializeRun()
  end
  798.  
  -- Resets the global `controller` so every button is released, and sends
  -- that state to joypad 1.
  function clearJoypad()
    controller = {}
    for _, button in ipairs(ButtonNames) do
      controller[button] = false
    end
    joypad.set(1, controller)
  end
  806.  
  -- Prepares the emulator and globals for evaluating the current genome:
  -- reloads the savestate, resets rightmost, pool.currentFrame and timeout,
  -- releases all buttons, builds the genome's network and evaluates it once.
  function initializeRun()
    savestate.load(SavestateObj)
    rightmost = 0
    pool.currentFrame = 0
    timeout = TimeoutConstant
    clearJoypad()

    local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
    generateNetwork(current)
    evaluateCurrent()
  end
  820.  
  -- Reads the current inputs, runs the current genome's network and sends the
  -- resulting button state to joypad 1. Updates the globals `inputs` and
  -- `controller`.
  --
  -- Opposing directions cancel each other out. The keys checked here must be
  -- the names listed in ButtonNames ("left", "right", "up", "down"), because
  -- evaluateNetwork() keys its result by those names. The original tested
  -- "P1 Left"/"P1 Right"/"P1 Up"/"P1 Down", which never exist in this table,
  -- so the cancellation never happened.
  function evaluateCurrent()
    local genome = getCurrentGenome()

    inputs = getInputs()
    controller = evaluateNetwork(genome.network, inputs)

    if controller["left"] and controller["right"] then
      controller["left"] = false
      controller["right"] = false
    end
    if controller["up"] and controller["down"] then
      controller["up"] = false
      controller["down"] = false
    end

    joypad.set(1, controller)
  end
  852.  
  -- Create the initial pool when the script starts without one.
  if pool == nil then initializePool() end
  856.  
  857.  
  -- Advances the pool cursor to the next genome. It moves on to the next
  -- species when the current one is exhausted and starts a new generation
  -- after the last species.
  function nextGenome()
    pool.currentGenome = pool.currentGenome + 1
    if pool.currentGenome <= #pool.species[pool.currentSpecies].genomes then
      return
    end

    pool.currentGenome = 1
    pool.currentSpecies = pool.currentSpecies + 1
    if pool.currentSpecies > #pool.species then
      newGeneration()
      pool.currentSpecies = 1
    end
  end
  869.  
  -- Returns true when the genome under the pool cursor already has a
  -- non-zero fitness.
  function fitnessAlreadyMeasured()
    local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
    return current.fitness ~= 0
  end
  876.  
  -- Draws the genome's network on screen: the input grid, the bias cell, the
  -- output buttons, hidden neurons, the links between them and (optionally)
  -- the mutation rates. Does nothing when `genome` is nil.
  --
  -- All drawing goes through the lowercase gui functions already used
  -- elsewhere in this script (gui.drawbox / gui.drawline / gui.drawtext).
  -- The original mixed in gui.drawBox, gui.drawLine and gui.drawText, which
  -- this script never defines, so the first such call raised an error.
  -- Colours are built as ARGB and converted with toRGBA() at the call site,
  -- and box colours are passed as (fill, outline).
  -- NOTE(review): the (fill, outline) order follows the FCEUX gui.box
  -- documentation - confirm against the emulator build in use.
  function displayGenome(genome)
    if not genome then return end
    local network = genome.network
    local cells = {}
    local i = 1
    local cell = {}

    -- Input grid: one cell per input neuron, 5 pixels apart.
    for dy = -BoxRadius, BoxRadius do
      for dx = -BoxRadius, BoxRadius do
        cell = {}
        cell.x = 50 + 5*dx
        cell.y = 70 + 5*dy
        cell.value = network.neurons[i].value
        cells[i] = cell
        i = i + 1
      end
    end

    -- The last input neuron (index Inputs) gets a fixed spot below the grid.
    local biasCell = {}
    biasCell.x = 80
    biasCell.y = 110
    biasCell.value = network.neurons[Inputs].value
    cells[Inputs] = biasCell

    -- Output neurons, labelled with their button names.
    for o = 1, Outputs do
      cell = {}
      cell.x = 220
      cell.y = 30 + 8 * o
      cell.value = network.neurons[MaxNodes + o].value
      cells[MaxNodes + o] = cell
      local color
      if cell.value > 0 then
        color = 0xFF0000FF
      else
        color = 0xFF000000
      end
      gui.drawtext(223, 24 + 8*o, ButtonNames[o], toRGBA(color), 0x0)
    end

    -- Hidden neurons all start at the same spot and are spread out below.
    for n, neuron in pairs(network.neurons) do
      cell = {}
      if n > Inputs and n <= MaxNodes then
        cell.x = 140
        cell.y = 40
        cell.value = neuron.value
        cells[n] = cell
      end
    end

    -- Four relaxation passes pull each hidden neuron towards its neighbours.
    for n = 1, 4 do
      for _, gene in pairs(genome.genes) do
        if gene.enabled then
          local c1 = cells[gene.into]
          local c2 = cells[gene.out]
          if gene.into > Inputs and gene.into <= MaxNodes then
            c1.x = 0.75*c1.x + 0.25*c2.x
            if c1.x >= c2.x then
              c1.x = c1.x - 40
            end
            if c1.x < 90 then
              c1.x = 90
            end

            if c1.x > 220 then
              c1.x = 220
            end
            c1.y = 0.75*c1.y + 0.25*c2.y

          end
          if gene.out > Inputs and gene.out <= MaxNodes then
            c2.x = 0.25*c1.x + 0.75*c2.x
            if c1.x >= c2.x then
              c2.x = c2.x + 40
            end
            if c2.x < 90 then
              c2.x = 90
            end
            if c2.x > 220 then
              c2.x = 220
            end
            c2.y = 0.25*c1.y + 0.75*c2.y
          end
        end
      end
    end

    -- Frame around the input grid.
    gui.drawbox(50 - BoxRadius*5 - 3, 70 - BoxRadius*5 - 3, 50 + BoxRadius*5 + 2, 70 + BoxRadius*5 + 2,
      toRGBA(0x80808080), toRGBA(0xFF000000))

    -- Neuron cells: grey level follows the value, zero values are dimmed.
    for n, cell in pairs(cells) do
      if n > Inputs or cell.value ~= 0 then
        local color = math.floor((cell.value + 1)/2*256)
        if color > 255 then color = 255 end
        if color < 0 then color = 0 end
        local opacity = 0xFF000000
        if cell.value == 0 then
          opacity = 0x50000000
        end
        color = opacity + color*0x10000 + color*0x100 + color
        gui.drawbox(cell.x - 2, cell.y - 2, cell.x + 2, cell.y + 2, toRGBA(color), toRGBA(opacity))
      end
    end

    -- Links: positive and negative weights use different colour channels,
    -- and links fed by a zero-valued neuron are drawn fainter.
    for _, gene in pairs(genome.genes) do
      if gene.enabled then
        local c1 = cells[gene.into]
        local c2 = cells[gene.out]
        local opacity = 0xA0000000
        if c1.value == 0 then
          opacity = 0x20000000
        end

        local color = 0x80 - math.floor(math.abs(sigmoid(gene.weight))*0x80)
        if gene.weight > 0 then
          color = opacity + 0x8000 + 0x10000*color
        else
          color = opacity + 0x800000 + 0x100*color
        end
        gui.drawline(c1.x + 1, c1.y, c2.x - 3, c2.y, toRGBA(color))
      end
    end

    -- Small outlined marker at the centre of the input grid.
    gui.drawbox(49, 71, 51, 78, toRGBA(0x00000000), toRGBA(0x80FF0000))

    if SHOW_MUTATION_RATES then
      local pos = 120
      for mutation, rate in pairs(genome.mutationRates) do
        gui.drawtext(16, pos, mutation .. ": " .. rate, toRGBA(0xFF000000), 0x0)
        pos = pos + 8
      end
    end
  end
  1004.  
  -- Serialises the global pool to `filename` as plain text, one value per
  -- line: generation, maxFitness, species count, then for each species its
  -- topFitness, staleness and genomes. Each genome lists its fitness,
  -- maxneuron, the mutation rates as name/value line pairs terminated by
  -- "done", its gene count and one "into out weight innovation enabled" row
  -- per gene.
  --
  -- Raises an error naming the file when it cannot be opened. The original
  -- did not check the io.open result and failed with an unrelated
  -- "attempt to index" message instead.
  function writeFile(filename)
    local file, openError = io.open(filename, "w")
    if not file then
      error("writeFile: cannot open '" .. tostring(filename) .. "' for writing: " .. tostring(openError))
    end

    file:write(pool.generation .. "\n")
    file:write(pool.maxFitness .. "\n")
    file:write(#pool.species .. "\n")
    for n, species in pairs(pool.species) do
      file:write(species.topFitness .. "\n")
      file:write(species.staleness .. "\n")
      file:write(#species.genomes .. "\n")
      for m, genome in pairs(species.genomes) do
        file:write(genome.fitness .. "\n")
        file:write(genome.maxneuron .. "\n")
        for mutation, rate in pairs(genome.mutationRates) do
          file:write(mutation .. "\n")
          file:write(rate .. "\n")
        end
        file:write("done\n")

        file:write(#genome.genes .. "\n")
        for l, gene in pairs(genome.genes) do
          file:write(gene.into .. " ")
          file:write(gene.out .. " ")
          file:write(gene.weight .. " ")
          file:write(gene.innovation .. " ")
          if gene.enabled then
            file:write("1\n")
          else
            file:write("0\n")
          end
        end
      end
    end
    file:close()
  end
  1039.  
  -- Writes the current pool to the file named by SAVE_LOAD_FILE.
  function savePool()
    --local filename = forms.gettext(saveLoadFile)
    writeFile(SAVE_LOAD_FILE)
  end
  1045.  
  -- Replaces the global pool with the one stored in `filename` (the format
  -- written by writeFile), skips genomes whose fitness is already measured and
  -- starts a run for the first unmeasured one.
  --
  -- Fixes compared with the original:
  --   * a file that cannot be opened raises an error naming the file, instead
  --     of failing with an unrelated "attempt to index" message;
  --   * gene rows are read with a helper that skips blank lines. A number
  --     read leaves the rest of its line unread, so the first plain line read
  --     after the gene count returned an empty string: the first gene came
  --     out with nil fields and every later row was read one line late.
  function loadFile(filename)
    local file, openError = io.open(filename, "r")
    if not file then
      error("loadFile: cannot open '" .. tostring(filename) .. "' for reading: " .. tostring(openError))
    end

    -- Returns the next line containing a non-blank character.
    local function readGeneLine()
      local text = file:read("*line")
      while text ~= nil and not text:find("%S") do
        text = file:read("*line")
      end
      if text == nil then
        file:close()
        error("loadFile: unexpected end of file in '" .. tostring(filename) .. "'")
      end
      return text
    end

    pool = newPool()
    pool.generation = file:read("*number")
    pool.maxFitness = file:read("*number")
    --forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
    local numSpecies = file:read("*number")
    for s = 1, numSpecies do
      local species = newSpecies()
      table.insert(pool.species, species)
      species.topFitness = file:read("*number")
      species.staleness = file:read("*number")
      local numGenomes = file:read("*number")
      for g = 1, numGenomes do
        local genome = newGenome()
        table.insert(species.genomes, genome)
        genome.fitness = file:read("*number")
        genome.maxneuron = file:read("*number")
        -- Mutation rates are name/value line pairs ending with "done". The
        -- first line read here is the unread tail of the maxneuron line.
        local line = file:read("*line")
        while line ~= "done" do
          genome.mutationRates[line] = file:read("*number")
          line = file:read("*line")
        end
        local numGenes = file:read("*number")
        for n = 1, numGenes do
          local gene = newGene()
          table.insert(genome.genes, gene)

          -- Row layout: into out weight innovation enabled(0/1)
          local fields = {}
          for token in readGeneLine():gmatch("%S+") do
            table.insert(fields, tonumber(token))
          end
          gene.into = fields[1]
          gene.out = fields[2]
          gene.weight = fields[3]
          gene.innovation = fields[4]
          local enabled = fields[5]

          if enabled == 0 then
            gene.enabled = false
          else
            gene.enabled = true
          end
        end
      end
    end
    file:close()

    while fitnessAlreadyMeasured() do
      nextGenome()
    end
    initializeRun()
    pool.currentFrame = pool.currentFrame + 1
  end
  1106.  
  -- Loads the pool from the file named by LOAD_FROM_FILE.
  function loadPool()
    --local filename = forms.gettext(saveLoadFile)
    --loadFile(filename)
    local source = LOAD_FROM_FILE
    loadFile(source)
  end
  1112.  
  -- Points the pool cursor at the genome with the highest positive fitness,
  -- records that fitness as pool.maxFitness and starts a run with it.
  function playTop()
    local bestFitness = 0
    local bestSpecies, bestGenome

    for speciesIndex, species in pairs(pool.species) do
      for genomeIndex, genome in pairs(species.genomes) do
        if genome.fitness > bestFitness then
          bestFitness = genome.fitness
          bestSpecies, bestGenome = speciesIndex, genomeIndex
        end
      end
    end

    pool.currentSpecies = bestSpecies
    pool.currentGenome = bestGenome
    pool.maxFitness = bestFitness
    --forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
    initializeRun()
    pool.currentFrame = pool.currentFrame + 1
  end
  1134.  
  1135.  
  1136. --function onExit()
  1137. -- forms.destroy(form)
  1138. --end
  1139.  
  1140.  
  1141.  
  -- Startup: resume from the configured pool file when LOAD_FROM_FILE is set,
  -- then hand the pool to writeFile under the name "temp.pool".
  if LOAD_FROM_FILE then loadPool() end

  writeFile("temp.pool")
  1147.  
  1148. --event.onexit(onExit)
  1149.  
  1150. --form = forms.newform(200, 260, "Fitness")
  1151. --maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  1152. --showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  1153. --showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  1154. --restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  1155. --saveButton = forms.button(form, "Save", savePool, 5, 102)
  1156. --loadButton = forms.button(form, "Load", loadPool, 80, 102)
  1157. --saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  1158. --saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  1159. --playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  1160. --hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  1161.  
  1162.  
  -- Fitness of the run in progress, as used both for the final score of a
  -- genome and for the live "Fitness:" readout in the banner.
  local function runFitness()
    return rightmost - pool.currentFrame / 2 + timebonus - 39.5
  end

  -- Main loop: one iteration per emulated frame.
  --
  -- Each frame it draws the banner, evaluates the current genome every 5th
  -- frame, applies the controller state, tracks Mario's progress, and, once
  -- the timeout expires, scores the genome and moves on to the next
  -- unmeasured one.
  while true do
    local bannerColor = toRGBA(0xD0FFFFFF)
    local textColor = toRGBA(0xFF000000)
    gui.drawbox(0, 0, 200, 30, bannerColor, bannerColor)

    -- Pools that lack the counter field get it initialised here.
    if pool.maxcounter == nil then
      pool.maxcounter = 0
      emu.print("caught missing variable")
    end

    local genome = getCurrentGenome()

    if pool.currentFrame % 5 == 0 then
      evaluateCurrent()
    end

    joypad.set(1, controller)

    getPositions()

    -- Pipe handling: remember that state 7 was seen; once it ends and Mario's
    -- X is behind the best position so far, store the difference in `offset`
    -- so that progress keeps counting from where it left off.
    if mariostate == 7 then
      mariopipe = true
    end
    if mariopipe == true and mariostate ~= 7 and rightmost > marioX then
      offset = rightmost - marioX
      mariopipe = false
    end

    -- Pit handling: flag the run once the combined vertical position exceeds
    -- 512 (below the screen, or far above it).
    -- NOTE(review): 0xB5 is combined with a multiplier of 255; if it is a
    -- high byte, 256 may have been intended. Kept as-is because the 512
    -- threshold was chosen against this formula -- TODO confirm.
    if marioY + memory.readbyte(0xB5)*255 > 512 then
      mariohole = true
    end
    if marioX + offset > rightmost and not mariohole then
      -- New progress: extend the best position and refill the timeout.
      rightmost = marioX + offset
      timeout = TimeoutConstant
    elseif marioX + offset <= rightmost or mariohole then
      -- No progress this frame: accrue half a point of time bonus.
      timebonus = timebonus + 1/2
    end

    timeout = timeout - 1

    local timeoutBonus = pool.currentFrame / 4
    if timeout + timeoutBonus <= 0 then
      local fitness = runFitness()
      if fitness == 0 then
        -- 0 means "not measured yet" in the progress count below, so a real
        -- score of exactly 0 is stored as -1.
        fitness = -1
      end

      genome.fitness = fitness

      if fitness > pool.maxFitness then
        pool.maxFitness = fitness
      end

      -- Count genomes that came within 30 points of the best fitness.
      if fitness > pool.maxFitness - 30 then
        pool.maxcounter = pool.maxcounter + 1
      end

      emu.print("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
      pool.currentSpecies = 1
      pool.currentGenome = 1

      -- Reset the per-run trackers unconditionally. They used to be reset
      -- inside the loop below, which skipped the reset whenever the first
      -- genome was still unmeasured and let a stale hole/pipe flag, offset or
      -- time bonus leak into the next run.
      mariohole = false
      mariopipe = false
      offset = 0
      timebonus = 0

      while fitnessAlreadyMeasured() do
        nextGenome()
      end
      initializeRun()
    end

    -- Share of genomes in the pool that already have a (non-zero) fitness.
    local measured, total = 0, 0
    for _, species in pairs(pool.species) do
      for _, member in pairs(species.genomes) do
        total = total + 1
        if member.fitness ~= 0 then
          measured = measured + 1
        end
      end
    end

    gui.drawtext(0, 9, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", textColor, 0x0)
    gui.drawtext(0, 18, "Fitness: " .. math.floor(runFitness()), textColor, 0x0)
    -- The percentage is the share of the population that reached maxFitness - 30.
    gui.drawtext(75, 18, "Max Fitness:" .. math.floor(pool.maxFitness) .. " " .. "(" .. math.ceil(pool.maxcounter/Population*100) .. "%)", textColor, 0x0)

    pool.currentFrame = pool.currentFrame + 1

    emu.frameadvance()
  end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement