Advertisement
Guest User

MarI/O changed script

a guest
Feb 23rd, 2018
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 31.29 KB | None | 0 0
  1. -- MarI/O by SethBling, changed by Akisame
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  4. -- For SMW, make sure you have a save state named "DP1.state" at the beginning of a level,
  5. -- and put a copy in both the Lua folder and the root directory of BizHawk.
  6.  
  7. if gameinfo.getromname() == "Super Mario World (USA)" then
  8. Filename = "DP1.state"
  9. ButtonNames = {
  10. "A",
  11. "B",
  12. "X",
  13. "Y",
  14. "Up",
  15. "Down",
  16. "Left",
  17. "Right",
  18. }
  19. elseif gameinfo.getromname() == "Super Mario Bros." then
  20. Filename = "SMB1-1.state"
  21. ButtonNames = {
  22. "A",
  23. "B",
  24. "Up",
  25. "Down",
  26. "Left",
  27. "Right",
  28. }
  29. end
  30.  
  31. BoxRadius = 6
  32. InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  33.  
  34. Inputs = InputSize+1
  35. Outputs = #ButtonNames
  36.  
  37. Population = 300
  38. DeltaDisjoint = 2.0
  39. DeltaWeights = 0.4
  40. DeltaThreshold = 1.0
  41.  
  42. StaleSpecies = 25 --increased slightly
  43.  
  44. MutateConnectionsChance = 0.25
  45. PerturbChance = 0.90
  46. CrossoverChance = 0.75
  47. LinkMutationChance = 2.0
  48. NodeMutationChance = 0.50
  49. BiasMutationChance = 0.40
  50. StepSize = 0.1
  51. DisableMutationChance = 0.4
  52. EnableMutationChance = 0.2
  53.  
  54. TimeoutConstant = 90 --set to a higher value than the original but see for yourself what is acceptable
  55.  
  56. MaxNodes = 1000000
  57.  
  58. function getPositions()
  59. if gameinfo.getromname() == "Super Mario World (USA)" then
  60. marioX = memory.read_s16_le(0x94)
  61. marioY = memory.read_s16_le(0x96)
  62.  
  63.  
  64. local layer1x = memory.read_s16_le(0x1A);
  65. local layer1y = memory.read_s16_le(0x1C);
  66.  
  67. screenX = marioX-layer1x
  68. screenY = marioY-layer1y
  69. elseif gameinfo.getromname() == "Super Mario Bros." then
  70. marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
  71. marioY = memory.readbyte(0x03B8)+16
  72. --this is a new addition
  73. marioYextended = memory.readbyte(0x0B5)
  74.  
  75. screenX = memory.readbyte(0x03AD)
  76. screenY = memory.readbyte(0x03B8)
  77. end
  78. end
  79.  
  80. function getTile(dx, dy)
  81. if gameinfo.getromname() == "Super Mario World (USA)" then
  82. x = math.floor((marioX+dx+8)/16)
  83. y = math.floor((marioY+dy)/16)
  84.  
  85. return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  86. elseif gameinfo.getromname() == "Super Mario Bros." then
  87. local x = marioX + dx + 8
  88. local y = marioY + dy - 16
  89. local page = math.floor(x/256)%2
  90.  
  91. local subx = math.floor((x%256)/16)
  92. local suby = math.floor((y - 32)/16)
  93. local addr = 0x500 + page*13*16+suby*16+subx
  94.  
  95. if suby >= 13 or suby < 0 then
  96. return 0
  97. end
  98.  
  99. if memory.readbyte(addr) ~= 0 then
  100. return 1
  101. else
  102. return 0
  103. end
  104. end
  105. end
  106.  
  107. function getSprites()
  108. if gameinfo.getromname() == "Super Mario World (USA)" then
  109. local sprites = {}
  110. for slot=0,11 do
  111. local status = memory.readbyte(0x14C8+slot)
  112. if status ~= 0 then
  113. spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
  114. spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
  115. sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
  116. end
  117. end
  118.  
  119. return sprites
  120. elseif gameinfo.getromname() == "Super Mario Bros." then
  121. local sprites = {}
  122. for slot=0,4 do
  123. local enemy = memory.readbyte(0xF+slot)
  124. if enemy ~= 0 then
  125. local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
  126. local ey = memory.readbyte(0xCF + slot)+24
  127. sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
  128. end
  129. end
  130.  
  131. return sprites
  132. end
  133. end
  134.  
  135. function getExtendedSprites()
  136. if gameinfo.getromname() == "Super Mario World (USA)" then
  137. local extended = {}
  138. for slot=0,11 do
  139. local number = memory.readbyte(0x170B+slot)
  140. if number ~= 0 then
  141. spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
  142. spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
  143. extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
  144. end
  145. end
  146.  
  147. return extended
  148. elseif gameinfo.getromname() == "Super Mario Bros." then
  149. return {}
  150. end
  151. end
  152.  
  153. function getInputs()
  154. getPositions()
  155.  
  156. sprites = getSprites()
  157. extended = getExtendedSprites()
  158.  
  159. local inputs = {}
  160.  
  161. for dy=-BoxRadius*16,BoxRadius*16,16 do
  162. for dx=-BoxRadius*16,BoxRadius*16,16 do
  163. inputs[#inputs+1] = 0
  164.  
  165. tile = getTile(dx, dy)
  166. if tile == 1 and marioY+dy < 0x1B0 then
  167. inputs[#inputs] = 1
  168. end
  169.  
  170. for i = 1,#sprites do
  171. distx = math.abs(sprites[i]["x"] - (marioX+dx))
  172. disty = math.abs(sprites[i]["y"] - (marioY+dy))
  173. if distx <= 8 and disty <= 8 then
  174. inputs[#inputs] = -1
  175. end
  176. end
  177.  
  178. for i = 1,#extended do
  179. distx = math.abs(extended[i]["x"] - (marioX+dx))
  180. disty = math.abs(extended[i]["y"] - (marioY+dy))
  181. if distx < 8 and disty < 8 then
  182. inputs[#inputs] = -1
  183. end
  184. end
  185. end
  186. end
  187.  
  188. --mariovx = memory.read_s8(0x7B)
  189. --mariovy = memory.read_s8(0x7D)
  190.  
  191. return inputs
  192. end
  193.  
  194. function sigmoid(x)
  195. return 2/(1+math.exp(-4.9*x))-1
  196. end
  197.  
  198. function newInnovation()
  199. pool.innovation = pool.innovation + 1
  200. return pool.innovation
  201. end
  202.  
  203. function newPool()
  204. local pool = {}
  205. pool.species = {}
  206. pool.generation = 0
  207. pool.innovation = Outputs
  208. pool.currentSpecies = 1
  209. pool.currentGenome = 1
  210. pool.currentFrame = 0
  211. pool.maxFitness = 0
  212. pool.species2 = {}
  213. return pool
  214. end
  215.  
  216. function newSpecies()
  217. local species = {}
  218. species.topFitness = 0
  219. species.staleness = 0
  220. species.genomes = {}
  221. species.averageFitness = 0
  222.  
  223. return species
  224. end
  225.  
  226. function newGenome()
  227. local genome = {}
  228. genome.genes = {}
  229. genome.fitness = 0
  230. genome.adjustedFitness = 0
  231. genome.network = {}
  232. genome.maxneuron = 0
  233. genome.globalRank = 0
  234. genome.mutationRates = {}
  235. genome.mutationRates["connections"] = MutateConnectionsChance
  236. genome.mutationRates["link"] = LinkMutationChance
  237. genome.mutationRates["bias"] = BiasMutationChance
  238. genome.mutationRates["node"] = NodeMutationChance
  239. genome.mutationRates["enable"] = EnableMutationChance
  240. genome.mutationRates["disable"] = DisableMutationChance
  241. genome.mutationRates["step"] = StepSize
  242.  
  243. return genome
  244. end
  245.  
  246. function copyGenome(genome)
  247. local genome2 = newGenome()
  248. for g=1,#genome.genes do
  249. table.insert(genome2.genes, copyGene(genome.genes[g]))
  250. end
  251. genome2.maxneuron = genome.maxneuron
  252. genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  253. genome2.mutationRates["link"] = genome.mutationRates["link"]
  254. genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  255. genome2.mutationRates["node"] = genome.mutationRates["node"]
  256. genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  257. genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  258.  
  259. return genome2
  260. end
  261.  
  262. function basicGenome()
  263. local genome = newGenome()
  264. local innovation = 1
  265.  
  266. genome.maxneuron = Inputs
  267. mutate(genome)
  268.  
  269. return genome
  270. end
  271.  
  272. function newGene()
  273. local gene = {}
  274. gene.into = 0
  275. gene.out = 0
  276. gene.weight = 0.0
  277. gene.enabled = true
  278. gene.innovation = 0
  279.  
  280. return gene
  281. end
  282.  
  283. function copyGene(gene)
  284. local gene2 = newGene()
  285. gene2.into = gene.into
  286. gene2.out = gene.out
  287. gene2.weight = gene.weight
  288. gene2.enabled = gene.enabled
  289. gene2.innovation = gene.innovation
  290.  
  291. return gene2
  292. end
  293.  
  294. function newNeuron()
  295. local neuron = {}
  296. neuron.incoming = {}
  297. neuron.value = 0.0
  298.  
  299. return neuron
  300. end
  301.  
  302. function generateNetwork(genome)
  303. local network = {}
  304. network.neurons = {}
  305.  
  306. for i=1,Inputs do
  307. network.neurons[i] = newNeuron()
  308. end
  309.  
  310. for o=1,Outputs do
  311. network.neurons[MaxNodes+o] = newNeuron()
  312. end
  313.  
  314. table.sort(genome.genes, function (a,b)
  315. return (a.out < b.out)
  316. end)
  317. for i=1,#genome.genes do
  318. local gene = genome.genes[i]
  319. if gene.enabled then
  320. if network.neurons[gene.out] == nil then
  321. network.neurons[gene.out] = newNeuron()
  322. end
  323. local neuron = network.neurons[gene.out]
  324. table.insert(neuron.incoming, gene)
  325. if network.neurons[gene.into] == nil then
  326. network.neurons[gene.into] = newNeuron()
  327. end
  328. end
  329. end
  330.  
  331. genome.network = network
  332. end
  333.  
  334. function evaluateNetwork(network, inputs)
  335. table.insert(inputs, 1)
  336. if #inputs ~= Inputs then
  337. console.writeline("Incorrect number of neural network inputs.")
  338. return {}
  339. end
  340.  
  341. for i=1,Inputs do
  342. network.neurons[i].value = inputs[i]
  343. end
  344.  
  345. for _,neuron in pairs(network.neurons) do
  346. local sum = 0
  347. for j = 1,#neuron.incoming do
  348. local incoming = neuron.incoming[j]
  349. local other = network.neurons[incoming.into]
  350. sum = sum + incoming.weight * other.value
  351. end
  352.  
  353. if #neuron.incoming > 0 then
  354. neuron.value = sigmoid(sum)
  355. end
  356. end
  357.  
  358. local outputs = {}
  359. for o=1,Outputs do
  360. local button = "P1 " .. ButtonNames[o]
  361. if network.neurons[MaxNodes+o].value > 0 then
  362. outputs[button] = true
  363. else
  364. outputs[button] = false
  365. end
  366. end
  367.  
  368. return outputs
  369. end
  370.  
  371. function crossover(g1, g2)
  372. -- Make sure g1 is the higher fitness genome
  373. if g2.fitness > g1.fitness then
  374. tempg = g1
  375. g1 = g2
  376. g2 = tempg
  377. end
  378.  
  379. local child = newGenome()
  380.  
  381. local innovations2 = {}
  382. for i=1,#g2.genes do
  383. local gene = g2.genes[i]
  384. innovations2[gene.innovation] = gene
  385. end
  386.  
  387. for i=1,#g1.genes do
  388. local gene1 = g1.genes[i]
  389. local gene2 = innovations2[gene1.innovation]
  390. if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  391. table.insert(child.genes, copyGene(gene2))
  392. else
  393. table.insert(child.genes, copyGene(gene1))
  394. end
  395. end
  396.  
  397. child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  398.  
  399. for mutation,rate in pairs(g1.mutationRates) do
  400. child.mutationRates[mutation] = rate
  401. end
  402.  
  403. return child
  404. end
  405.  
  406. function randomNeuron(genes, nonInput)
  407. local neurons = {}
  408. if not nonInput then
  409. for i=1,Inputs do
  410. neurons[i] = true
  411. end
  412. end
  413. for o=1,Outputs do
  414. neurons[MaxNodes+o] = true
  415. end
  416. for i=1,#genes do
  417. if (not nonInput) or genes[i].into > Inputs then
  418. neurons[genes[i].into] = true
  419. end
  420. if (not nonInput) or genes[i].out > Inputs then
  421. neurons[genes[i].out] = true
  422. end
  423. end
  424.  
  425. local count = 0
  426. for _,_ in pairs(neurons) do
  427. count = count + 1
  428. end
  429. local n = math.random(1, count)
  430.  
  431. for k,v in pairs(neurons) do
  432. n = n-1
  433. if n == 0 then
  434. return k
  435. end
  436. end
  437.  
  438. return 0
  439. end
  440.  
  441. function containsLink(genes, link)
  442. for i=1,#genes do
  443. local gene = genes[i]
  444. if gene.into == link.into and gene.out == link.out then
  445. return true
  446. end
  447. end
  448. end
  449.  
  450. function pointMutate(genome)
  451. local step = genome.mutationRates["step"]
  452.  
  453. for i=1,#genome.genes do
  454. local gene = genome.genes[i]
  455. if math.random() < PerturbChance then
  456. gene.weight = gene.weight + math.random() * step*2 - step
  457. else
  458. gene.weight = math.random()*4-2
  459. end
  460. end
  461. end
  462.  
  463. function linkMutate(genome, forceBias)
  464. local neuron1 = randomNeuron(genome.genes, false)
  465. local neuron2 = randomNeuron(genome.genes, true)
  466.  
  467. local newLink = newGene()
  468. if neuron1 <= Inputs and neuron2 <= Inputs then
  469. --Both input nodes
  470. return
  471. end
  472. if neuron2 <= Inputs then
  473. -- Swap output and input
  474. local temp = neuron1
  475. neuron1 = neuron2
  476. neuron2 = temp
  477. end
  478.  
  479. newLink.into = neuron1
  480. newLink.out = neuron2
  481. if forceBias then
  482. newLink.into = Inputs
  483. end
  484.  
  485. if containsLink(genome.genes, newLink) then
  486. return
  487. end
  488. newLink.innovation = newInnovation()
  489. newLink.weight = math.random()*4-2
  490.  
  491. table.insert(genome.genes, newLink)
  492. end
  493.  
  494. function nodeMutate(genome)
  495. if #genome.genes == 0 then
  496. return
  497. end
  498.  
  499. genome.maxneuron = genome.maxneuron + 1
  500.  
  501. local gene = genome.genes[math.random(1,#genome.genes)]
  502. if not gene.enabled then
  503. return
  504. end
  505. gene.enabled = false
  506.  
  507. local gene1 = copyGene(gene)
  508. gene1.out = genome.maxneuron
  509. gene1.weight = 1.0
  510. gene1.innovation = newInnovation()
  511. gene1.enabled = true
  512. table.insert(genome.genes, gene1)
  513.  
  514. local gene2 = copyGene(gene)
  515. gene2.into = genome.maxneuron
  516. gene2.innovation = newInnovation()
  517. gene2.enabled = true
  518. table.insert(genome.genes, gene2)
  519. end
  520.  
  521. function enableDisableMutate(genome, enable)
  522. local candidates = {}
  523. for _,gene in pairs(genome.genes) do
  524. if gene.enabled == not enable then
  525. table.insert(candidates, gene)
  526. end
  527. end
  528.  
  529. if #candidates == 0 then
  530. return
  531. end
  532.  
  533. local gene = candidates[math.random(1,#candidates)]
  534. gene.enabled = not gene.enabled
  535. end
  536.  
  537. function mutate(genome)
  538. for mutation,rate in pairs(genome.mutationRates) do
  539. if math.random(1,2) == 1 then
  540. genome.mutationRates[mutation] = 0.95*rate
  541. else
  542. genome.mutationRates[mutation] = 1.05263*rate
  543. end
  544. end
  545.  
  546. if math.random() < genome.mutationRates["connections"] then
  547. pointMutate(genome)
  548. end
  549.  
  550. local p = genome.mutationRates["link"]
  551. while p > 0 do
  552. if math.random() < p then
  553. linkMutate(genome, false)
  554. end
  555. p = p - 1
  556. end
  557.  
  558. p = genome.mutationRates["bias"]
  559. while p > 0 do
  560. if math.random() < p then
  561. linkMutate(genome, true)
  562. end
  563. p = p - 1
  564. end
  565.  
  566. p = genome.mutationRates["node"]
  567. while p > 0 do
  568. if math.random() < p then
  569. nodeMutate(genome)
  570. end
  571. p = p - 1
  572. end
  573.  
  574. p = genome.mutationRates["enable"]
  575. while p > 0 do
  576. if math.random() < p then
  577. enableDisableMutate(genome, true)
  578. end
  579. p = p - 1
  580. end
  581.  
  582. p = genome.mutationRates["disable"]
  583. while p > 0 do
  584. if math.random() < p then
  585. enableDisableMutate(genome, false)
  586. end
  587. p = p - 1
  588. end
  589. end
  590.  
  591. function disjoint(genes1, genes2)
  592. local i1 = {}
  593. for i = 1,#genes1 do
  594. local gene = genes1[i]
  595. i1[gene.innovation] = true
  596. end
  597.  
  598. local i2 = {}
  599. for i = 1,#genes2 do
  600. local gene = genes2[i]
  601. i2[gene.innovation] = true
  602. end
  603.  
  604. local disjointGenes = 0
  605. for i = 1,#genes1 do
  606. local gene = genes1[i]
  607. if not i2[gene.innovation] then
  608. disjointGenes = disjointGenes+1
  609. end
  610. end
  611.  
  612. for i = 1,#genes2 do
  613. local gene = genes2[i]
  614. if not i1[gene.innovation] then
  615. disjointGenes = disjointGenes+1
  616. end
  617. end
  618.  
  619. local n = math.max(#genes1, #genes2)
  620.  
  621. return disjointGenes / n
  622. end
  623.  
  624. function weights(genes1, genes2)
  625. local i2 = {}
  626. for i = 1,#genes2 do
  627. local gene = genes2[i]
  628. i2[gene.innovation] = gene
  629. end
  630.  
  631. local sum = 0
  632. local coincident = 0
  633. for i = 1,#genes1 do
  634. local gene = genes1[i]
  635. if i2[gene.innovation] ~= nil then
  636. local gene2 = i2[gene.innovation]
  637. sum = sum + math.abs(gene.weight - gene2.weight)
  638. coincident = coincident + 1
  639. end
  640. end
  641.  
  642. return sum / coincident
  643. end
  644.  
  645. function sameSpecies(genome1, genome2)
  646. local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  647. local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  648. return dd + dw < DeltaThreshold
  649. end
  650.  
  651. function rankGlobally()
  652. local global = {}
  653. for s = 1,#pool.species do
  654. local species = pool.species[s]
  655. for g = 1,#species.genomes do
  656. table.insert(global, species.genomes[g])
  657. end
  658. end
  659. table.sort(global, function (a,b)
  660. return (a.fitness < b.fitness)
  661. end)
  662.  
  663. for g=1,#global do
  664. global[g].globalRank = g
  665. end
  666. end
  667.  
  668. function calculateAverageFitness(species)
  669. local total = 0
  670.  
  671. for g=1,#species.genomes do
  672. local genome = species.genomes[g]
  673. total = total + genome.globalRank
  674. end
  675.  
  676. species.averageFitness = total / #species.genomes
  677. end
  678.  
  679. function totalAverageFitness()
  680. local total = 0
  681. for s = 1,#pool.species do
  682. local species = pool.species[s]
  683. total = total + species.averageFitness
  684. end
  685.  
  686. return total
  687. end
  688.  
  689. --new random function
  690. function generatechance()
  691. local chance = math.random(1, 1000) / 1000 -- 0.001 to 1 inclusive.
  692. return chance
  693. end
  694.  
  695. function cullSpecies(cutToOne)
  696. for s = 1,#pool.species do
  697. local species = pool.species[s]
  698.  
  699. table.sort(species.genomes, function (a,b)
  700. return (a.fitness > b.fitness)
  701. end)
  702.  
  703. local remaining = math.ceil(#species.genomes) --original /2
  704. if cutToOne then
  705. remaining = 1
  706. --console.writeline("cull to one")
  707. end
  708. while #species.genomes > remaining do
  709. table.remove(species.genomes)
  710. end
  711. if cutToOne then
  712. table.insert(pool.species2, species) --I was planning to use this later on.
  713. end
  714. end
  715. --new culling system here
  716. if not cutToOne then
  717. local chance = 0
  718. local pmax
  719. local nmaxfitness = 0
  720. local fitchance
  721. local fitmaxchance
  722. local killedgenome = {}
  723. local killedspecies = {}
  724. local survived = {}
  725. local survivedspecies = {}
  726.  
  727. for s,species in pairs(pool.species) do
  728. for g,genome in pairs(species.genomes) do
  729. if pool.maxFitness < 1 then
  730. console.writeline("well. this sucks. Your species have no chance of survival.")
  731. return
  732. else
  733. if pool.maxFitness > 500 then -- Culling based on chance as well as difference from max fitness.
  734. nmaxfitness = pool.maxFitness - 450
  735. fitchance = genome.fitness - nmaxfitness
  736. pmax = fitchance / 450
  737. else
  738. pmax = genome.fitness / pool.maxFitness
  739. end
  740. chance = generatechance()
  741. if pmax > chance then
  742. table.insert(survived, species)
  743. end
  744. end
  745. end
  746. end
  747. pool.species = survived
  748. end
  749. end
  750.  
  751. function breedChild(species)
  752. local child = {}
  753. if math.random() < CrossoverChance then
  754. g1 = species.genomes[math.random(1, #species.genomes)] --I might change this a bit in the future
  755. g2 = species.genomes[math.random(1, #species.genomes)]
  756. child = crossover(g1, g2)
  757. else
  758. g = species.genomes[math.random(1, #species.genomes)]
  759. child = copyGenome(g)
  760. end
  761.  
  762. mutate(child)
  763.  
  764. return child
  765. end
  766.  
  767. function removeStaleSpecies()
  768. local survived = {}
  769.  
  770. for s = 1,#pool.species do
  771. local species = pool.species[s]
  772.  
  773. table.sort(species.genomes, function (a,b)
  774. return (a.fitness > b.fitness)
  775. end)
  776.  
  777. if species.genomes[1].fitness > species.topFitness then
  778. species.topFitness = species.genomes[1].fitness
  779. species.staleness = 0
  780. else
  781. species.staleness = species.staleness + 1
  782. end
  783. if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness - 45 then --prevented removal of all species but one when mario is stuck.
  784. table.insert(survived, species)
  785. end
  786. end
  787.  
  788. pool.species = survived
  789. end
  790.  
  791. function removeWeakSpecies()
  792. local survived = {}
  793.  
  794. local sum = totalAverageFitness()
  795. for s = 1,#pool.species do
  796. local species = pool.species[s]
  797. breed = math.floor(species.averageFitness / sum * Population)
  798. if breed >= 1 then
  799. table.insert(survived, species)
  800. end
  801.  
  802. end
  803.  
  804. pool.species = survived
  805. end
  806.  
  807.  
  808. function addToSpecies(child)
  809. local foundSpecies = false
  810. if #pool.species > 0 then
  811. for s=1,#pool.species do
  812. local species = pool.species[s]
  813. if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  814. table.insert(species.genomes, child)
  815. foundSpecies = true
  816. end
  817. end
  818. end
  819. if not foundSpecies then
  820. local childSpecies = newSpecies()
  821. table.insert(childSpecies.genomes, child)
  822. table.insert(pool.species, childSpecies)
  823. end
  824. end
  825.  
  826. function newGeneration()
  827. cullSpecies(false) -- make a selection of surviving species for breeding
  828. rankGlobally()
  829. removeStaleSpecies()
  830. rankGlobally()
  831. for s = 1,#pool.species do
  832. local species = pool.species[s]
  833. calculateAverageFitness(species)
  834. end
  835. removeWeakSpecies()
  836. local sum = totalAverageFitness()
  837. local children = {}
  838. for s = 1,#pool.species do
  839. local species = pool.species[s]
  840. breed = math.floor(species.averageFitness / sum * Population) - 4
  841. for i=1,breed do
  842. table.insert(children, breedChild(species))
  843. end
  844. end
  845. cullSpecies(true) -- Cull all but the top member of each species
  846.  
  847.  
  848. local counter = 1
  849. local g
  850. local champion
  851. local speciessort = {}
  852. for s = 1,#pool.species do
  853. local tempspecies = pool.species[s]
  854. for g = 1,#tempspecies.genomes do
  855. table.insert(speciessort, tempspecies.genomes[g])
  856. end
  857. end
  858. table.sort(speciessort, function (a,b)
  859. return (a.fitness < b.fitness)
  860. end)
  861. while #children > Population - 15 do --I do this to ensure that the top 15 genomes are included
  862. table.remove(children)
  863. end
  864. while #children < Population do
  865. local species = pool.species[math.random(1, #pool.species)]
  866. if counter < #speciessort then
  867. g = speciessort[counter]
  868. champion = copyGenome(g)
  869. table.insert(children, champion)
  870. counter = counter + 1
  871. else
  872. table.insert(children, breedChild(species))
  873. end
  874. end
  875. pool.species = {}
  876. for c=1,#children do
  877. local child = children[c]
  878. addToSpecies(child)
  879. end
  880. pool.generation = pool.generation + 1
  881.  
  882. writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  883. end
  884.  
  885.  
  886. function initializePool()
  887. pool = newPool()
  888.  
  889. for i=1,Population do
  890. basic = basicGenome()
  891. addToSpecies(basic)
  892. end
  893.  
  894. initializeRun()
  895. end
  896.  
  897. function clearJoypad()
  898. controller = {}
  899. for b = 1,#ButtonNames do
  900. controller["P1 " .. ButtonNames[b]] = false
  901. end
  902. joypad.set(controller)
  903. end
  904.  
  905. function initializeRun()
  906. savestate.load(Filename);
  907. rightmost = 0
  908. pool.currentFrame = 0
  909. timeout = TimeoutConstant
  910. clearJoypad()
  911.  
  912. local species = pool.species[pool.currentSpecies]
  913. local genome = species.genomes[pool.currentGenome]
  914. generateNetwork(genome)
  915. evaluateCurrent()
  916. end
  917.  
  918. function evaluateCurrent()
  919. local species = pool.species[pool.currentSpecies]
  920. local genome = species.genomes[pool.currentGenome]
  921.  
  922. inputs = getInputs()
  923. controller = evaluateNetwork(genome.network, inputs)
  924.  
  925. if controller["P1 Left"] and controller["P1 Right"] then
  926. controller["P1 Left"] = false
  927. controller["P1 Right"] = false
  928. end
  929. if controller["P1 Up"] and controller["P1 Down"] then
  930. controller["P1 Up"] = false
  931. controller["P1 Down"] = false
  932. end
  933.  
  934. joypad.set(controller)
  935. end
  936.  
  937. if pool == nil then
  938. initializePool()
  939. end
  940.  
  941.  
  942. function nextGenome()
  943. pool.currentGenome = pool.currentGenome + 1
  944.  
  945. if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
  946. pool.currentGenome = 1
  947. pool.currentSpecies = pool.currentSpecies+1
  948. if pool.currentSpecies > #pool.species then
  949. newGeneration()
  950. pool.currentSpecies = 1
  951. end
  952. end
  953. end
  954.  
  955. function fitnessAlreadyMeasured()
  956. local species = pool.species[pool.currentSpecies]
  957. local genome = species.genomes[pool.currentGenome]
  958.  
  959. return genome.fitness ~= 0
  960. end
  961.  
--- Draw an overlay of the genome's network on the emulator screen.
-- Layout:
--   * input neurons: a (2*BoxRadius+1)-square grid with 5 px spacing,
--     centred on (50, 70);
--   * bias input (id Inputs): drawn at (80, 110);
--   * output neurons: a column at x = 220, labelled with the button names;
--   * hidden neurons: start at (140, 40), then four relaxation passes
--     move them along the enabled genes.
-- Reads neuron values from genome.network (built by generateNetwork), so
-- that network must already exist. NOTE(review): presumably called after
-- at least one evaluation; confirm the caller's ordering.
-- Also prints the genome's mutation rates when the showMutationRates
-- checkbox is ticked.
function displayGenome(genome)
  local network = genome.network
  local cells = {}
  local i = 1
  local cell = {}
  -- Input grid cells, filled row by row (dy outer, dx inner), the same
  -- order in which getInputs() builds its input vector.
  for dy=-BoxRadius,BoxRadius do
    for dx=-BoxRadius,BoxRadius do
      cell = {}
      cell.x = 50+5*dx
      cell.y = 70+5*dy
      cell.value = network.neurons[i].value
      cells[i] = cell
      i = i + 1
    end
  end
  -- The last input neuron is the bias; draw it apart from the grid.
  local biasCell = {}
  biasCell.x = 80
  biasCell.y = 110
  biasCell.value = network.neurons[Inputs].value
  cells[Inputs] = biasCell

  -- Output cells and button labels. The label colour differs when the
  -- output value is > 0, i.e. when evaluateNetwork would press the button.
  for o = 1,Outputs do
    cell = {}
    cell.x = 220
    cell.y = 30 + 8 * o
    cell.value = network.neurons[MaxNodes + o].value
    cells[MaxNodes+o] = cell
    local color
    if cell.value > 0 then
      color = 0xFF0000FF
    else
      color = 0xFF000000
    end
    gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
  end

  -- Hidden neurons (ids strictly between Inputs and MaxNodes+1) all start
  -- at the same point.
  for n,neuron in pairs(network.neurons) do
    cell = {}
    if n > Inputs and n <= MaxNodes then
      cell.x = 140
      cell.y = 40
      cell.value = neuron.value
      cells[n] = cell
    end
  end

  -- Four relaxation passes over the enabled genes. Each hidden endpoint is
  -- pulled 25% toward the other endpoint. A source that is not left of its
  -- target is pushed 40 px left (and a hidden target 40 px right). Hidden
  -- x is clamped to 90..220, the band between the inputs and the outputs.
  for n=1,4 do
    for _,gene in pairs(genome.genes) do
      if gene.enabled then
        local c1 = cells[gene.into]
        local c2 = cells[gene.out]
        if gene.into > Inputs and gene.into <= MaxNodes then
          c1.x = 0.75*c1.x + 0.25*c2.x
          if c1.x >= c2.x then
            c1.x = c1.x - 40
          end
          if c1.x < 90 then
            c1.x = 90
          end

          if c1.x > 220 then
            c1.x = 220
          end
          c1.y = 0.75*c1.y + 0.25*c2.y

        end
        if gene.out > Inputs and gene.out <= MaxNodes then
          c2.x = 0.25*c1.x + 0.75*c2.x
          if c1.x >= c2.x then
            c2.x = c2.x + 40
          end
          if c2.x < 90 then
            c2.x = 90
          end
          if c2.x > 220 then
            c2.x = 220
          end
          c2.y = 0.25*c1.y + 0.75*c2.y
        end
      end
    end
  end

  -- Translucent grey background behind the input grid.
  gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
  -- Draw each cell as a 5x5 greyscale box whose brightness follows the
  -- value (-1 -> dark, 1 -> light). Input cells are drawn only when
  -- non-zero. Zero-valued cells get a fainter alpha. Colours appear to be
  -- 0xAARRGGBB; NOTE(review): confirm against the BizHawk gui API.
  for n,cell in pairs(cells) do
    if n > Inputs or cell.value ~= 0 then
      local color = math.floor((cell.value+1)/2*256)
      if color > 255 then color = 255 end
      if color < 0 then color = 0 end
      local opacity = 0xFF000000
      if cell.value == 0 then
        opacity = 0x50000000
      end
      color = opacity + color*0x10000 + color*0x100 + color
      gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color)
    end
  end
  -- Draw each enabled gene as a line. Positive weights use the 0x008000
  -- base and negative weights the 0x800000 base, with stronger weights
  -- more saturated. Lines from zero-valued sources are fainter.
  for _,gene in pairs(genome.genes) do
    if gene.enabled then
      local c1 = cells[gene.into]
      local c2 = cells[gene.out]
      local opacity = 0xA0000000
      if c1.value == 0 then
        opacity = 0x20000000
      end

      local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
      if gene.weight > 0 then
        color = opacity + 0x8000 + 0x10000*color
      else
        color = opacity + 0x800000 + 0x100*color
      end
      gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
    end
  end

  -- Small translucent box near the grid centre (presumably marks Mario's
  -- own cell).
  gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)

  -- Optional listing of this genome's mutation rates, one line per rate.
  if forms.ischecked(showMutationRates) then
    local pos = 100
    for mutation,rate in pairs(genome.mutationRates) do
      gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
      pos = pos + 8
    end
  end
end
  1088.  
  1089. function writeFile(filename)
  1090. local file = io.open(filename, "w")
  1091. file:write(pool.generation .. "\n")
  1092. file:write(pool.maxFitness .. "\n")
  1093. file:write(#pool.species .. "\n")
  1094. for n,species in pairs(pool.species) do
  1095. file:write(species.topFitness .. "\n")
  1096. file:write(species.staleness .. "\n")
  1097. file:write(#species.genomes .. "\n")
  1098. for m,genome in pairs(species.genomes) do
  1099. file:write(genome.fitness .. "\n")
  1100. file:write(genome.maxneuron .. "\n")
  1101. for mutation,rate in pairs(genome.mutationRates) do
  1102. file:write(mutation .. "\n")
  1103. file:write(rate .. "\n")
  1104. end
  1105. file:write("done\n")
  1106.  
  1107. file:write(#genome.genes .. "\n")
  1108. for l,gene in pairs(genome.genes) do
  1109. file:write(gene.into .. " ")
  1110. file:write(gene.out .. " ")
  1111. file:write(gene.weight .. " ")
  1112. file:write(gene.innovation .. " ")
  1113. if(gene.enabled) then
  1114. file:write("1\n")
  1115. else
  1116. file:write("0\n")
  1117. end
  1118. end
  1119. end
  1120. end
  1121. file:close()
  1122. end
  1123.  
  1124. function savePool()
  1125. local filename = forms.gettext(saveLoadFile)
  1126. writeFile(filename)
  1127. end
  1128.  
  1129. function loadFile(filename)
  1130. local file = io.open(filename, "r")
  1131. pool = newPool()
  1132. pool.generation = file:read("*number")
  1133. pool.maxFitness = file:read("*number")
  1134. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1135. local numSpecies = file:read("*number")
  1136. for s=1,numSpecies do
  1137. local species = newSpecies()
  1138. table.insert(pool.species, species)
  1139. species.topFitness = file:read("*number")
  1140. species.staleness = file:read("*number")
  1141. local numGenomes = file:read("*number")
  1142. for g=1,numGenomes do
  1143. local genome = newGenome()
  1144. table.insert(species.genomes, genome)
  1145. genome.fitness = file:read("*number")
  1146. genome.maxneuron = file:read("*number")
  1147. local line = file:read("*line")
  1148. while line ~= "done" do
  1149. genome.mutationRates[line] = file:read("*number")
  1150. line = file:read("*line")
  1151. end
  1152. local numGenes = file:read("*number")
  1153. for n=1,numGenes do
  1154. local gene = newGene()
  1155. table.insert(genome.genes, gene)
  1156. local enabled
  1157. gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
  1158. if enabled == 0 then
  1159. gene.enabled = false
  1160. else
  1161. gene.enabled = true
  1162. end
  1163.  
  1164. end
  1165. end
  1166. end
  1167. file:close()
  1168. pool.repeatcounter = 0
  1169. while fitnessAlreadyMeasured() do
  1170. nextGenome()
  1171. timebonus = 0
  1172. end
  1173. initializeRun()
  1174. pool.currentFrame = pool.currentFrame + 1
  1175. end
  1176.  
  1177. function loadPool()
  1178. local filename = forms.gettext(saveLoadFile)
  1179. loadFile(filename)
  1180. end
  1181.  
  1182. function playTop()
  1183. local maxfitness = 0
  1184. local maxs, maxg
  1185. for s,species in pairs(pool.species) do
  1186. for g,genome in pairs(species.genomes) do
  1187. if genome.fitness > maxfitness then
  1188. maxfitness = genome.fitness
  1189. maxs = s
  1190. maxg = g
  1191. end
  1192. end
  1193. end
  1194.  
  1195. pool.currentSpecies = maxs
  1196. pool.currentGenome = maxg
  1197. pool.maxFitness = maxfitness
  1198. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1199. initializeRun()
  1200. pool.currentFrame = pool.currentFrame + 1
  1201. return
  1202. end
  1203.  
  1204. function onExit()
  1205. forms.destroy(form)
  1206. end
  1207.  
  1208. writeFile("temp.pool")
  1209.  
  1210. event.onexit(onExit)
  1211.  
  1212. form = forms.newform(200, 260, "Fitness")
  1213. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  1214. showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  1215. showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  1216. restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  1217. saveButton = forms.button(form, "Save", savePool, 5, 102)
  1218. loadButton = forms.button(form, "Load", loadPool, 80, 102)
  1219. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  1220. saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  1221. playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  1222. hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  1223. local timebonus = 0
  1224.  
  1225. while true do
  1226. local backgroundColor = 0xD0FFFFFF
  1227. if not forms.ischecked(hideBanner) then
  1228. gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
  1229. end
  1230. local species = pool.species[pool.currentSpecies]
  1231. local genome = species.genomes[pool.currentGenome]
  1232.  
  1233. if forms.ischecked(showNetwork) then
  1234. displayGenome(genome)
  1235. end
  1236.  
  1237. if pool.currentFrame%5 == 0 then
  1238. evaluateCurrent()
  1239. end
  1240.  
  1241. joypad.set(controller)
  1242.  
  1243. getPositions()
  1244. if marioX > rightmost and marioYextended < 474 then
  1245. rightmost = marioX
  1246. timeout = TimeoutConstant
  1247. elseif marioX <= rightmost and marioX > rightmost-40 then --prevents the fitness from jumping back when mario is stuck and gets out
  1248. timebonus = timebonus + 1/2
  1249. end
  1250.  
  1251. timeout = timeout - 1
  1252.  
  1253.  
  1254. local timeoutBonus = pool.currentFrame / 4
  1255. if timeout + timeoutBonus <= 0 then
  1256. local fitness = rightmost - pool.currentFrame / 2 -39.5 +timebonus
  1257. if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 then
  1258. fitness = fitness + 1000
  1259. end
  1260. if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
  1261. fitness = fitness + 1000
  1262. end
  1263. if fitness == 0 then
  1264. fitness = -1
  1265. end
  1266. genome.fitness = fitness
  1267.  
  1268. if fitness > pool.maxFitness then
  1269. pool.maxFitness = fitness
  1270. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1271. writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  1272. end
  1273.  
  1274. console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  1275. --pool.currentSpecies = 1
  1276. --pool.currentGenome = 1
  1277. while fitnessAlreadyMeasured() do
  1278. nextGenome()
  1279. timebonus = 0
  1280. end
  1281. initializeRun()
  1282. end
  1283.  
  1284. local measured = 0
  1285. local total = 0
  1286. for _,species in pairs(pool.species) do
  1287. for _,genome in pairs(species.genomes) do
  1288. total = total + 1
  1289. if genome.fitness ~= 0 then
  1290. measured = measured + 1
  1291. end
  1292. end
  1293. end
  1294. if not forms.ischecked(hideBanner) then
  1295. gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 11)
  1296. gui.drawText(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 -39.5 + timebonus), 0xFF000000, 11) -- - (timeout + timeoutBonus)*2/3 removed and replaced by timebonus
  1297. gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11)
  1298. end
  1299.  
  1300. pool.currentFrame = pool.currentFrame + 1
  1301.  
  1302. emu.frameadvance();
  1303. end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement