Advertisement
Guest User

Untitled

a guest
Jun 14th, 2015
301
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 41.16 KB | None | 0 0
  1. -- MarI/O by SethBling
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  4.  
  5. if gameinfo.getromname() == "Super Mario World (USA)" then
  6. Filename = "DP1.state"
  7. ButtonNames = {
  8. "A",
  9. "B",
  10. "X",
  11. "Y",
  12. "Up",
  13. "Down",
  14. "Left",
  15. "Right",
  16. }
  17. elseif gameinfo.getromname() == "Super Mario Bros." then
  18. Filename = "SMB1-1.state"
  19. ButtonNames = {
  20. "A",
  21. "B",
  22. "Up",
  23. "Down",
  24. "Left",
  25. "Right",
  26. }
  27. end
  28.  
  29. BoxRadius = 6
  30. InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  31.  
  32. Inputs = InputSize+1
  33. Outputs = #ButtonNames
  34.  
  35. Population = 300
  36. DeltaDisjoint = 2.0
  37. DeltaWeights = 0.4
  38. DeltaThreshold = 1.0
  39.  
  40. StaleSpecies = 15
  41.  
  42. MutateConnectionsChance = 0.25
  43. PerturbChance = 0.90
  44. CrossoverChance = 0.75
  45. LinkMutationChance = 2.0
  46. NodeMutationChance = 0.50
  47. BiasMutationChance = 0.40
  48. StepSize = 0.1
  49. DisableMutationChance = 0.4
  50. EnableMutationChance = 0.2
  51.  
  52. TimeoutConstant = 20
  53.  
  54. MaxNodes = 1000000
  55.  
  56. function getPositions()
  57. if gameinfo.getromname() == "Super Mario World (USA)" then
  58. marioX = memory.read_s16_le(0x94)
  59. marioY = memory.read_s16_le(0x96)
  60.  
  61. local layer1x = memory.read_s16_le(0x1A);
  62. local layer1y = memory.read_s16_le(0x1C);
  63.  
  64. screenX = marioX-layer1x
  65. screenY = marioY-layer1y
  66. elseif gameinfo.getromname() == "Super Mario Bros." then
  67. marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
  68. marioY = memory.readbyte(0x03B8)+16
  69.  
  70. screenX = memory.readbyte(0x03AD)
  71. screenY = memory.readbyte(0x03B8)
  72. end
  73. end
  74.  
  75. function getTile(dx, dy)
  76. if gameinfo.getromname() == "Super Mario World (USA)" then
  77. x = math.floor((marioX+dx+8)/16)
  78. y = math.floor((marioY+dy)/16)
  79.  
  80. return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  81. elseif gameinfo.getromname() == "Super Mario Bros." then
  82. local x = marioX + dx + 8
  83. local y = marioY + dy - 16
  84. local page = math.floor(x/256)%2
  85.  
  86. local subx = math.floor((x%256)/16)
  87. local suby = math.floor((y - 32)/16)
  88. local addr = 0x500 + page*13*16+suby*16+subx
  89.  
  90. if suby >= 13 or suby < 0 then
  91. return 0
  92. end
  93.  
  94. if memory.readbyte(addr) ~= 0 then
  95. return 1
  96. else
  97. return 0
  98. end
  99. end
  100. end
  101.  
  102. function getSprites()
  103. if gameinfo.getromname() == "Super Mario World (USA)" then
  104. local sprites = {}
  105. for slot=0,11 do
  106. local status = memory.readbyte(0x14C8+slot)
  107. if status ~= 0 then
  108. spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
  109. spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
  110. sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
  111. end
  112. end
  113.  
  114. return sprites
  115. elseif gameinfo.getromname() == "Super Mario Bros." then
  116. local sprites = {}
  117. for slot=0,4 do
  118. local enemy = memory.readbyte(0xF+slot)
  119. if enemy ~= 0 then
  120. local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
  121. local ey = memory.readbyte(0xCF + slot)+24
  122. sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
  123. end
  124. end
  125.  
  126. return sprites
  127. end
  128. end
  129.  
  130. function getExtendedSprites()
  131. if gameinfo.getromname() == "Super Mario World (USA)" then
  132. local extended = {}
  133. for slot=0,11 do
  134. local number = memory.readbyte(0x170B+slot)
  135. if number ~= 0 then
  136. spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
  137. spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
  138. extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
  139. end
  140. end
  141.  
  142. return extended
  143. elseif gameinfo.getromname() == "Super Mario Bros." then
  144. return {}
  145. end
  146. end
  147.  
  148. function getInputs()
  149. getPositions()
  150.  
  151. sprites = getSprites()
  152. extended = getExtendedSprites()
  153.  
  154. local inputs = {}
  155.  
  156. for dy=-BoxRadius*16,BoxRadius*16,16 do
  157. for dx=-BoxRadius*16,BoxRadius*16,16 do
  158. inputs[#inputs+1] = 0
  159.  
  160. tile = getTile(dx, dy)
  161. if tile == 1 and marioY+dy < 0x1B0 then
  162. inputs[#inputs] = 1
  163. end
  164.  
  165. for i = 1,#sprites do
  166. distx = math.abs(sprites[i]["x"] - (marioX+dx))
  167. disty = math.abs(sprites[i]["y"] - (marioY+dy))
  168. if distx <= 8 and disty <= 8 then
  169. inputs[#inputs] = -1
  170. end
  171. end
  172.  
  173. for i = 1,#extended do
  174. distx = math.abs(extended[i]["x"] - (marioX+dx))
  175. disty = math.abs(extended[i]["y"] - (marioY+dy))
  176. if distx < 8 and disty < 8 then
  177. inputs[#inputs] = -1
  178. end
  179. end
  180. end
  181. end
  182.  
  183. --mariovx = memory.read_s8(0x7B)
  184. --mariovy = memory.read_s8(0x7D)
  185.  
  186. return inputs
  187. end
  188.  
  189. function sigmoid(x)
  190. return 2/(1+math.exp(-4.9*x))-1
  191. end
  192.  
  193. function newInnovation()
  194. pool.innovation = pool.innovation + 1
  195. return pool.innovation
  196. end
  197.  
  198. function newPool()
  199. local pool = {}
  200. pool.species = {}
  201. pool.generation = 0
  202. pool.innovation = Outputs
  203. pool.currentSpecies = 1
  204. pool.currentGenome = 1
  205. pool.currentFrame = 0
  206. pool.maxFitness = 0
  207.  
  208. return pool
  209. end
  210.  
  211. function newSpecies()
  212. local species = {}
  213. species.topFitness = 0
  214. species.staleness = 0
  215. species.genomes = {}
  216. species.averageFitness = 0
  217.  
  218. return species
  219. end
  220.  
  221. function newGenome()
  222. local genome = {}
  223. genome.genes = {}
  224. genome.fitness = 0
  225. genome.adjustedFitness = 0
  226. genome.network = {}
  227. genome.maxneuron = 0
  228. genome.globalRank = 0
  229. genome.mutationRates = {}
  230. genome.mutationRates["connections"] = MutateConnectionsChance
  231. genome.mutationRates["link"] = LinkMutationChance
  232. genome.mutationRates["bias"] = BiasMutationChance
  233. genome.mutationRates["node"] = NodeMutationChance
  234. genome.mutationRates["enable"] = EnableMutationChance
  235. genome.mutationRates["disable"] = DisableMutationChance
  236. genome.mutationRates["step"] = StepSize
  237.  
  238. return genome
  239. end
  240.  
  241. function copyGenome(genome)
  242. local genome2 = newGenome()
  243. for g=1,#genome.genes do
  244. table.insert(genome2.genes, copyGene(genome.genes[g]))
  245. end
  246. genome2.maxneuron = genome.maxneuron
  247. genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  248. genome2.mutationRates["link"] = genome.mutationRates["link"]
  249. genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  250. genome2.mutationRates["node"] = genome.mutationRates["node"]
  251. genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  252. genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  253.  
  254. return genome2
  255. end
  256.  
  257. function basicGenome()
  258. local genome = newGenome()
  259. local innovation = 1
  260.  
  261. genome.maxneuron = Inputs
  262. mutate(genome)
  263.  
  264. return genome
  265. end
  266.  
  267. function newGene()
  268. local gene = {}
  269. gene.into = 0
  270. gene.out = 0
  271. gene.weight = 0.0
  272. gene.enabled = true
  273. gene.innovation = 0
  274.  
  275. return gene
  276. end
  277.  
  278. function copyGene(gene)
  279. local gene2 = newGene()
  280. gene2.into = gene.into
  281. gene2.out = gene.out
  282. gene2.weight = gene.weight
  283. gene2.enabled = gene.enabled
  284. gene2.innovation = gene.innovation
  285.  
  286. return gene2
  287. end
  288.  
  289. function newNeuron()
  290. local neuron = {}
  291. neuron.incoming = {}
  292. neuron.value = 0.0
  293.  
  294. return neuron
  295. end
  296.  
  297. function generateNetwork(genome)
  298. local network = {}
  299. network.neurons = {}
  300.  
  301. for i=1,Inputs do
  302. network.neurons[i] = newNeuron()
  303. end
  304.  
  305. for o=1,Outputs do
  306. network.neurons[MaxNodes+o] = newNeuron()
  307. end
  308.  
  309. table.sort(genome.genes, function (a,b)
  310. return (a.out < b.out)
  311. end)
  312. for i=1,#genome.genes do
  313. local gene = genome.genes[i]
  314. if gene.enabled then
  315. if network.neurons[gene.out] == nil then
  316. network.neurons[gene.out] = newNeuron()
  317. end
  318. local neuron = network.neurons[gene.out]
  319. table.insert(neuron.incoming, gene)
  320. if network.neurons[gene.into] == nil then
  321. network.neurons[gene.into] = newNeuron()
  322. end
  323. end
  324. end
  325.  
  326. genome.network = network
  327. end
  328.  
  329. function evaluateNetwork(network, inputs)
  330. table.insert(inputs, 1)
  331. if #inputs ~= Inputs then
  332. console.writeline("Incorrect number of neural network inputs.")
  333. return {}
  334. end
  335.  
  336. for i=1,Inputs do
  337. network.neurons[i].value = inputs[i]
  338. end
  339.  
  340. for _,neuron in pairs(network.neurons) do
  341. local sum = 0
  342. for j = 1,#neuron.incoming do
  343. local incoming = neuron.incoming[j]
  344. local other = network.neurons[incoming.into]
  345. sum = sum + incoming.weight * other.value
  346. end
  347.  
  348. if #neuron.incoming > 0 then
  349. neuron.value = sigmoid(sum)
  350. end
  351. end
  352.  
  353. local outputs = {}
  354. for o=1,Outputs do
  355. local button = "P1 " .. ButtonNames[o]
  356. if network.neurons[MaxNodes+o].value > 0 then
  357. outputs[button] = true
  358. else
  359. outputs[button] = false
  360. end
  361. end
  362.  
  363. return outputs
  364. end
  365.  
  366. function crossover(g1, g2)
  367. -- Make sure g1 is the higher fitness genome
  368. if g2.fitness > g1.fitness then
  369. tempg = g1
  370. g1 = g2
  371. g2 = tempg
  372. end
  373.  
  374. local child = newGenome()
  375.  
  376. local innovations2 = {}
  377. for i=1,#g2.genes do
  378. local gene = g2.genes[i]
  379. innovations2[gene.innovation] = gene
  380. end
  381.  
  382. for i=1,#g1.genes do
  383. local gene1 = g1.genes[i]
  384. local gene2 = innovations2[gene1.innovation]
  385. if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  386. table.insert(child.genes, copyGene(gene2))
  387. else
  388. table.insert(child.genes, copyGene(gene1))
  389. end
  390. end
  391.  
  392. child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  393.  
  394. for mutation,rate in pairs(g1.mutationRates) do
  395. child.mutationRates[mutation] = rate
  396. end
  397.  
  398. return child
  399. end
  400.  
  401. function randomNeuron(genes, nonInput)
  402. local neurons = {}
  403. if not nonInput then
  404. for i=1,Inputs do
  405. neurons[i] = true
  406. end
  407. end
  408. for o=1,Outputs do
  409. neurons[MaxNodes+o] = true
  410. end
  411. for i=1,#genes do
  412. if (not nonInput) or genes[i].into > Inputs then
  413. neurons[genes[i].into] = true
  414. end
  415. if (not nonInput) or genes[i].out > Inputs then
  416. neurons[genes[i].out] = true
  417. end
  418. end
  419.  
  420. local count = 0
  421. for _,_ in pairs(neurons) do
  422. count = count + 1
  423. end
  424. local n = math.random(1, count)
  425.  
  426. for k,v in pairs(neurons) do
  427. n = n-1
  428. if n == 0 then
  429. return k
  430. end
  431. end
  432.  
  433. return 0
  434. end
  435.  
  436. function containsLink(genes, link)
  437. for i=1,#genes do
  438. local gene = genes[i]
  439. if gene.into == link.into and gene.out == link.out then
  440. return true
  441. end
  442. end
  443. end
  444.  
  445. function pointMutate(genome)
  446. local step = genome.mutationRates["step"]
  447.  
  448. for i=1,#genome.genes do
  449. local gene = genome.genes[i]
  450. if math.random() < PerturbChance then
  451. gene.weight = gene.weight + math.random() * step*2 - step
  452. else
  453. gene.weight = math.random()*4-2
  454. end
  455. end
  456. end
  457.  
  458. function linkMutate(genome, forceBias)
  459. local neuron1 = randomNeuron(genome.genes, false)
  460. local neuron2 = randomNeuron(genome.genes, true)
  461.  
  462. local newLink = newGene()
  463. if neuron1 <= Inputs and neuron2 <= Inputs then
  464. --Both input nodes
  465. return
  466. end
  467. if neuron2 <= Inputs then
  468. -- Swap output and input
  469. local temp = neuron1
  470. neuron1 = neuron2
  471. neuron2 = temp
  472. end
  473.  
  474. newLink.into = neuron1
  475. newLink.out = neuron2
  476. if forceBias then
  477. newLink.into = Inputs
  478. end
  479.  
  480. if containsLink(genome.genes, newLink) then
  481. return
  482. end
  483. newLink.innovation = newInnovation()
  484. newLink.weight = math.random()*4-2
  485.  
  486. table.insert(genome.genes, newLink)
  487. end
  488.  
  489. function nodeMutate(genome)
  490. if #genome.genes == 0 then
  491. return
  492. end
  493.  
  494. genome.maxneuron = genome.maxneuron + 1
  495.  
  496. local gene = genome.genes[math.random(1,#genome.genes)]
  497. if not gene.enabled then
  498. return
  499. end
  500. gene.enabled = false
  501.  
  502. local gene1 = copyGene(gene)
  503. gene1.out = genome.maxneuron
  504. gene1.weight = 1.0
  505. gene1.innovation = newInnovation()
  506. gene1.enabled = true
  507. table.insert(genome.genes, gene1)
  508.  
  509. local gene2 = copyGene(gene)
  510. gene2.into = genome.maxneuron
  511. gene2.innovation = newInnovation()
  512. gene2.enabled = true
  513. table.insert(genome.genes, gene2)
  514. end
  515.  
  516. function enableDisableMutate(genome, enable)
  517. local candidates = {}
  518. for _,gene in pairs(genome.genes) do
  519. if gene.enabled == not enable then
  520. table.insert(candidates, gene)
  521. end
  522. end
  523.  
  524. if #candidates == 0 then
  525. return
  526. end
  527.  
  528. local gene = candidates[math.random(1,#candidates)]
  529. gene.enabled = not gene.enabled
  530. end
  531.  
  532. function mutate(genome)
  533. for mutation,rate in pairs(genome.mutationRates) do
  534. if math.random(1,2) == 1 then
  535. genome.mutationRates[mutation] = 0.95*rate
  536. else
  537. genome.mutationRates[mutation] = 1.05263*rate
  538. end
  539. end
  540.  
  541. if math.random() < genome.mutationRates["connections"] then
  542. pointMutate(genome)
  543. end
  544.  
  545. local p = genome.mutationRates["link"]
  546. while p > 0 do
  547. if math.random() < p then
  548. linkMutate(genome, false)
  549. end
  550. p = p - 1
  551. end
  552.  
  553. p = genome.mutationRates["bias"]
  554. while p > 0 do
  555. if math.random() < p then
  556. linkMutate(genome, true)
  557. end
  558. p = p - 1
  559. end
  560.  
  561. p = genome.mutationRates["node"]
  562. while p > 0 do
  563. if math.random() < p then
  564. nodeMutate(genome)
  565. end
  566. p = p - 1
  567. end
  568.  
  569. p = genome.mutationRates["enable"]
  570. while p > 0 do
  571. if math.random() < p then
  572. enableDisableMutate(genome, true)
  573. end
  574. p = p - 1
  575. end
  576.  
  577. p = genome.mutationRates["disable"]
  578. while p > 0 do
  579. if math.random() < p then
  580. enableDisableMutate(genome, false)
  581. end
  582. p = p - 1
  583. end
  584. end
  585.  
  586. function disjoint(genes1, genes2)
  587. local i1 = {}
  588. for i = 1,#genes1 do
  589. local gene = genes1[i]
  590. i1[gene.innovation] = true
  591. end
  592.  
  593. local i2 = {}
  594. for i = 1,#genes2 do
  595. local gene = genes2[i]
  596. i2[gene.innovation] = true
  597. end
  598.  
  599. local disjointGenes = 0
  600. for i = 1,#genes1 do
  601. local gene = genes1[i]
  602. if not i2[gene.innovation] then
  603. disjointGenes = disjointGenes+1
  604. end
  605. end
  606.  
  607. for i = 1,#genes2 do
  608. local gene = genes2[i]
  609. if not i1[gene.innovation] then
  610. disjointGenes = disjointGenes+1
  611. end
  612. end
  613.  
  614. local n = math.max(#genes1, #genes2)
  615.  
  616. return disjointGenes / n
  617. end
  618.  
  619. function weights(genes1, genes2)
  620. local i2 = {}
  621. for i = 1,#genes2 do
  622. local gene = genes2[i]
  623. i2[gene.innovation] = gene
  624. end
  625.  
  626. local sum = 0
  627. local coincident = 0
  628. for i = 1,#genes1 do
  629. local gene = genes1[i]
  630. if i2[gene.innovation] ~= nil then
  631. local gene2 = i2[gene.innovation]
  632. sum = sum + math.abs(gene.weight - gene2.weight)
  633. coincident = coincident + 1
  634. end
  635. end
  636.  
  637. return sum / coincident
  638. end
  639.  
  640. function sameSpecies(genome1, genome2)
  641. local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  642. local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  643. return dd + dw < DeltaThreshold
  644. end
  645.  
  646. function rankGlobally()
  647. local global = {}
  648. for s = 1,#pool.species do
  649. local species = pool.species[s]
  650. for g = 1,#species.genomes do
  651. table.insert(global, species.genomes[g])
  652. end
  653. end
  654. table.sort(global, function (a,b)
  655. return (a.fitness < b.fitness)
  656. end)
  657.  
  658. for g=1,#global do
  659. global[g].globalRank = g
  660. end
  661. end
  662.  
  663. function calculateAverageFitness(species)
  664. local total = 0
  665.  
  666. for g=1,#species.genomes do
  667. local genome = species.genomes[g]
  668. total = total + genome.globalRank
  669. end
  670.  
  671. species.averageFitness = total / #species.genomes
  672. end
  673.  
  674. function totalAverageFitness()
  675. local total = 0
  676. for s = 1,#pool.species do
  677. local species = pool.species[s]
  678. total = total + species.averageFitness
  679. end
  680.  
  681. return total
  682. end
  683.  
  684. function cullSpecies(cutToOne)
  685. for s = 1,#pool.species do
  686. local species = pool.species[s]
  687.  
  688. table.sort(species.genomes, function (a,b)
  689. return (a.fitness > b.fitness)
  690. end)
  691.  
  692. local remaining = math.ceil(#species.genomes/2)
  693. if cutToOne then
  694. remaining = 1
  695. end
  696. while #species.genomes > remaining do
  697. table.remove(species.genomes)
  698. end
  699. end
  700. end
  701.  
  702. function breedChild(species)
  703. local child = {}
  704. if math.random() < CrossoverChance then
  705. g1 = species.genomes[math.random(1, #species.genomes)]
  706. g2 = species.genomes[math.random(1, #species.genomes)]
  707. child = crossover(g1, g2)
  708. else
  709. g = species.genomes[math.random(1, #species.genomes)]
  710. child = copyGenome(g)
  711. end
  712.  
  713. mutate(child)
  714.  
  715. return child
  716. end
  717.  
  718. function removeStaleSpecies()
  719. local survived = {}
  720.  
  721. for s = 1,#pool.species do
  722. local species = pool.species[s]
  723.  
  724. table.sort(species.genomes, function (a,b)
  725. return (a.fitness > b.fitness)
  726. end)
  727.  
  728. if species.genomes[1].fitness > species.topFitness then
  729. species.topFitness = species.genomes[1].fitness
  730. species.staleness = 0
  731. else
  732. species.staleness = species.staleness + 1
  733. end
  734. if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness then
  735. table.insert(survived, species)
  736. end
  737. end
  738.  
  739. pool.species = survived
  740. end
  741.  
  742. function removeWeakSpecies()
  743. local survived = {}
  744.  
  745. local sum = totalAverageFitness()
  746. for s = 1,#pool.species do
  747. local species = pool.species[s]
  748. breed = math.floor(species.averageFitness / sum * Population)
  749. if breed >= 1 then
  750. table.insert(survived, species)
  751. end
  752. end
  753.  
  754. pool.species = survived
  755. end
  756.  
  757.  
  758. function addToSpecies(child)
  759. local foundSpecies = false
  760. for s=1,#pool.species do
  761. local species = pool.species[s]
  762. if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  763. table.insert(species.genomes, child)
  764. foundSpecies = true
  765. end
  766. end
  767.  
  768. if not foundSpecies then
  769. local childSpecies = newSpecies()
  770. table.insert(childSpecies.genomes, child)
  771. table.insert(pool.species, childSpecies)
  772. end
  773. end
  774.  
  775. function newGeneration()
  776. cullSpecies(false) -- Cull the bottom half of each species
  777. rankGlobally()
  778. removeStaleSpecies()
  779. rankGlobally()
  780. for s = 1,#pool.species do
  781. local species = pool.species[s]
  782. calculateAverageFitness(species)
  783. end
  784. removeWeakSpecies()
  785. local sum = totalAverageFitness()
  786. local children = {}
  787. for s = 1,#pool.species do
  788. local species = pool.species[s]
  789. breed = math.floor(species.averageFitness / sum * Population) - 1
  790. for i=1,breed do
  791. table.insert(children, breedChild(species))
  792. end
  793. end
  794. cullSpecies(true) -- Cull all but the top member of each species
  795. while #children + #pool.species < Population do
  796. local species = pool.species[math.random(1, #pool.species)]
  797. table.insert(children, breedChild(species))
  798. end
  799. for c=1,#children do
  800. local child = children[c]
  801. addToSpecies(child)
  802. end
  803. pool.generation = pool.generation + 1
  804.  
  805. writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  806. end
  807.  
  808. function initializePool()
  809. pool = newPool()
  810.  
  811. for i=1,Population do
  812. basic = basicGenome()
  813. addToSpecies(basic)
  814. end
  815.  
  816. initializeRun()
  817. end
  818.  
  819. function clearJoypad()
  820. controller = {}
  821. for b = 1,#ButtonNames do
  822. controller["P1 " .. ButtonNames[b]] = false
  823. end
  824. joypad.set(controller)
  825. end
  826.  
  827. function initializeRun()
  828. savestate.load(Filename);
  829. rightmost = 0
  830. pool.currentFrame = 0
  831. timeout = TimeoutConstant
  832. clearJoypad()
  833.  
  834. local species = pool.species[pool.currentSpecies]
  835. local genome = species.genomes[pool.currentGenome]
  836. generateNetwork(genome)
  837. evaluateCurrent()
  838. end
  839.  
  840. function evaluateCurrent()
  841. local species = pool.species[pool.currentSpecies]
  842. local genome = species.genomes[pool.currentGenome]
  843.  
  844. inputs = getInputs()
  845. controller = evaluateNetwork(genome.network, inputs)
  846.  
  847. if controller["P1 Left"] and controller["P1 Right"] then
  848. controller["P1 Left"] = false
  849. controller["P1 Right"] = false
  850. end
  851. if controller["P1 Up"] and controller["P1 Down"] then
  852. controller["P1 Up"] = false
  853. controller["P1 Down"] = false
  854. end
  855.  
  856. joypad.set(controller)
  857. end
  858.  
  859. if pool == nil then
  860. initializePool()
  861. end
  862.  
  863.  
  864. function nextGenome()
  865. pool.currentGenome = pool.currentGenome + 1
  866. if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
  867. pool.currentGenome = 1
  868. pool.currentSpecies = pool.currentSpecies+1
  869. if pool.currentSpecies > #pool.species then
  870. newGeneration()
  871. pool.currentSpecies = 1
  872. end
  873. end
  874. end
  875.  
  876. function fitnessAlreadyMeasured()
  877. local species = pool.species[pool.currentSpecies]
  878. local genome = species.genomes[pool.currentGenome]
  879.  
  880. return genome.fitness ~= 0
  881. end
  882.  
  883. function displayGenome(genome)
  884. local network = genome.network
  885. local cells = {}
  886. local i = 1
  887. local cell = {}
  888. for dy=-BoxRadius,BoxRadius do
  889. for dx=-BoxRadius,BoxRadius do
  890. cell = {}
  891. cell.x = 50+5*dx
  892. cell.y = 70+5*dy
  893. cell.value = network.neurons[i].value
  894. cells[i] = cell
  895. i = i + 1
  896. end
  897. end
  898. local biasCell = {}
  899. biasCell.x = 80
  900. biasCell.y = 110
  901. biasCell.value = network.neurons[Inputs].value
  902. cells[Inputs] = biasCell
  903.  
  904. for o = 1,Outputs do
  905. cell = {}
  906. cell.x = 220
  907. cell.y = 30 + 8 * o
  908. cell.value = network.neurons[MaxNodes + o].value
  909. cells[MaxNodes+o] = cell
  910. local color
  911. if cell.value > 0 then
  912. color = 0xFF0000FF
  913. else
  914. color = 0xFF000000
  915. end
  916. gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
  917. end
  918.  
  919. for n,neuron in pairs(network.neurons) do
  920. cell = {}
  921. if n > Inputs and n <= MaxNodes then
  922. cell.x = 140
  923. cell.y = 40
  924. cell.value = neuron.value
  925. cells[n] = cell
  926. end
  927. end
  928.  
  929. for n=1,4 do
  930. for _,gene in pairs(genome.genes) do
  931. if gene.enabled then
  932. local c1 = cells[gene.into]
  933. local c2 = cells[gene.out]
  934. if gene.into > Inputs and gene.into <= MaxNodes then
  935. c1.x = 0.75*c1.x + 0.25*c2.x
  936. if c1.x >= c2.x then
  937. c1.x = c1.x - 40
  938. end
  939. if c1.x < 90 then
  940. c1.x = 90
  941. end
  942.  
  943. if c1.x > 220 then
  944. c1.x = 220
  945. end
  946. c1.y = 0.75*c1.y + 0.25*c2.y
  947.  
  948. end
  949. if gene.out > Inputs and gene.out <= MaxNodes then
  950. c2.x = 0.25*c1.x + 0.75*c2.x
  951. if c1.x >= c2.x then
  952. c2.x = c2.x + 40
  953. end
  954. if c2.x < 90 then
  955. c2.x = 90
  956. end
  957. if c2.x > 220 then
  958. c2.x = 220
  959. end
  960. c2.y = 0.25*c1.y + 0.75*c2.y
  961. end
  962. end
  963. end
  964. end
  965.  
  966. gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
  967. for n,cell in pairs(cells) do
  968. if n > Inputs or cell.value ~= 0 then
  969. local color = math.floor((cell.value+1)/2*256)
  970. if color > 255 then color = 255 end
  971. if color < 0 then color = 0 end
  972. local opacity = 0xFF000000
  973. if cell.value == 0 then
  974. opacity = 0x50000000
  975. end
  976. color = opacity + color*0x10000 + color*0x100 + color
  977. gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color)
  978. end
  979. end
  980. for _,gene in pairs(genome.genes) do
  981. if gene.enabled then
  982. local c1 = cells[gene.into]
  983. local c2 = cells[gene.out]
  984. local opacity = 0xA0000000
  985. if c1.value == 0 then
  986. opacity = 0x20000000
  987. end
  988.  
  989. local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
  990. if gene.weight > 0 then
  991. color = opacity + 0x8000 + 0x10000*color
  992. else
  993. color = opacity + 0x800000 + 0x100*color
  994. end
  995. gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
  996. end
  997. end
  998.  
  999. gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)
  1000.  
  1001. if forms.ischecked(showMutationRates) then
  1002. local pos = 100
  1003. for mutation,rate in pairs(genome.mutationRates) do
  1004. gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
  1005. pos = pos + 8
  1006. end
  1007. end
  1008. end
  1009.  
  1010. function writeFile(filename)
  1011. local file = io.open(filename, "w")
  1012. file:write(pool.generation .. "\n")
  1013. file:write(pool.maxFitness .. "\n")
  1014. file:write(#pool.species .. "\n")
  1015. for n,species in pairs(pool.species) do
  1016. file:write(species.topFitness .. "\n")
  1017. file:write(species.staleness .. "\n")
  1018. file:write(#species.genomes .. "\n")
  1019. for m,genome in pairs(species.genomes) do
  1020. file:write(genome.fitness .. "\n")
  1021. file:write(genome.maxneuron .. "\n")
  1022. for mutation,rate in pairs(genome.mutationRates) do
  1023. file:write(mutation .. "\n")
  1024. file:write(rate .. "\n")
  1025. end
  1026. file:write("done\n")
  1027.  
  1028. file:write(#genome.genes .. "\n")
  1029. for l,gene in pairs(genome.genes) do
  1030. file:write(gene.into .. " ")
  1031. file:write(gene.out .. " ")
  1032. file:write(gene.weight .. " ")
  1033. file:write(gene.innovation .. " ")
  1034. if(gene.enabled) then
  1035. file:write("1\n")
  1036. else
  1037. file:write("0\n")
  1038. end
  1039. end
  1040. end
  1041. end
  1042. file:close()
  1043. end
  1044.  
  1045. function savePool()
  1046. local filename = forms.gettext(saveLoadFile)
  1047. writeFile(filename)
  1048. end
  1049.  
  1050. function loadFile(filename)
  1051. local file = io.open(filename, "r")
  1052. pool = newPool()
  1053. pool.generation = file:read("*number")
  1054. pool.maxFitness = file:read("*number")
  1055. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1056. local numSpecies = file:read("*number")
  1057. for s=1,numSpecies do
  1058. local species = newSpecies()
  1059. table.insert(pool.species, species)
  1060. species.topFitness = file:read("*number")
  1061. species.staleness = file:read("*number")
  1062. local numGenomes = file:read("*number")
  1063. for g=1,numGenomes do
  1064. local genome = newGenome()
  1065. table.insert(species.genomes, genome)
  1066. genome.fitness = file:read("*number")
  1067. genome.maxneuron = file:read("*number")
  1068. local line = file:read("*line")
  1069. while line ~= "done" do
  1070. genome.mutationRates[line] = file:read("*number")
  1071. line = file:read("*line")
  1072. end
  1073. local numGenes = file:read("*number")
  1074. for n=1,numGenes do
  1075. local gene = newGene()
  1076. table.insert(genome.genes, gene)
  1077. local enabled
  1078. gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
  1079. if enabled == 0 then
  1080. gene.enabled = false
  1081. else
  1082. gene.enabled = true
  1083. end
  1084.  
  1085. end
  1086. end
  1087. end
  1088. file:close()
  1089.  
  1090. while fitnessAlreadyMeasured() do
  1091. nextGenome()
  1092. end
  1093. initializeRun()
  1094. pool.currentFrame = pool.currentFrame + 1
  1095. end
  1096.  
  1097. function loadPool()
  1098. local filename = forms.gettext(saveLoadFile)
  1099. loadFile(filename)
  1100. end
  1101.  
  1102. function playTop()
  1103. local maxfitness = 0
  1104. local maxs, maxg
  1105. for s,species in pairs(pool.species) do
  1106. for g,genome in pairs(species.genomes) do
  1107. if genome.fitness > maxfitness then
  1108. maxfitness = genome.fitness
  1109. maxs = s
  1110. maxg = g
  1111. end
  1112. end
  1113. end
  1114.  
  1115. pool.currentSpecies = maxs
  1116. pool.currentGenome = maxg
  1117. pool.maxFitness = maxfitness
  1118. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1119. initializeRun()
  1120. pool.currentFrame = pool.currentFrame + 1
  1121. return
  1122. end
  1123.  
  1124. function onExit()
  1125. forms.destroy(form)
  1126. end
  1127.  
  1128. writeFile("temp.pool")
  1129.  
  1130. event.onexit(onExit)
  1131.  
  1132. form = forms.newform(200, 260, "Fitness")
  1133. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  1134. showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  1135. showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  1136. restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  1137. saveButton = forms.button(form, "Save", savePool, 5, 102)
  1138. loadButton = forms.button(form, "Load", loadPool, 80, 102)
  1139. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  1140. saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  1141. playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  1142. hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  1143.  
  1144. while true do
  1145. local backgroundColor = 0xD0FFFFFF
  1146. if not forms.ischecked(hideBanner) then
  1147. gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
  1148. end
  1149.  
  1150. local species = pool.species[pool.currentSpecies]
  1151. local genome = species.genomes[pool.currentGenome]
  1152.  
  1153. if forms.ischecked(showNetwork) then
  1154. displayGenome(genome)
  1155. end
  1156.  
  1157. if pool.currentFrame%5 == 0 then
  1158. evaluateCurrent()
  1159. end
  1160.  
  1161. joypad.set(controller)
  1162.  
  1163. getPositions()
  1164. if marioX > rightmost then
  1165. rightmost = marioX
  1166. timeout = TimeoutConstant
  1167. end
  1168.  
  1169. timeout = timeout - 1
  1170.  
  1171. local score=memory.read_s32_le(0x0F34) * 5
  1172.  
  1173. local timeoutBonus = pool.currentFrame / 4
  1174. if timeout + timeoutBonus <= 0 then
  1175. local fitness = rightmost - pool.currentFrame / 2
  1176. fitness = fitness + score
  1177. if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 then
  1178. fitness = fitness + 1000
  1179. end
  1180. if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
  1181. fitness = fitness + 1000
  1182. end
  1183. if fitness == 0 then
  1184. fitness = -1
  1185. end
  1186. genome.fitness = fitness
  1187.  
  1188. if fitness > pool.maxFitness then
  1189. pool.maxFitness = fitness
  1190. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1191. writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  1192. end
  1193.  
  1194. console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  1195. pool.currentSpecies = 1
  1196. pool.currentGenome = 1
  1197. while fitnessAlreadyMeasured() do
  1198. nextGenome()
  1199. end
  1200. savePool()
  1201. initializeRun()
  1202. end
  1203.  
  1204. local measured = 0
  1205. local total = 0
  1206. for _,species in pairs(pool.species) do
  1207. for _,genome in pairs(species.genomes) do
  1208. total = total + 1
  1209. if genome.fitness ~= 0 then
  1210. measured = measured + 1
  1211. end
  1212. end
  1213. end
  1214. if not forms.ischecked(hideBanner) then
  1215. gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 11)
  1216. gui.drawText(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3+score), 0xFF000000, 11)
  1217. gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11)
  1218. end
  1219.  
  1220. pool.currentFrame = pool.currentFrame + 1
  1221.  
  1222. emu.frameadvance();
  1223. end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement