Advertisement
SethBling

NEATEvolve.lua

Jun 13th, 2015
656,037
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 28.73 KB | None | 0 0
  1. -- MarI/O by SethBling
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  4. -- For SMW, make sure you have a save state named "DP1.state" at the beginning of a level,
  5. -- and put a copy in both the Lua folder and the root directory of BizHawk.
  6.  
  7. if gameinfo.getromname() == "Super Mario World (USA)" then
  8. Filename = "DP1.state"
  9. ButtonNames = {
  10. "A",
  11. "B",
  12. "X",
  13. "Y",
  14. "Up",
  15. "Down",
  16. "Left",
  17. "Right",
  18. }
  19. elseif gameinfo.getromname() == "Super Mario Bros." then
  20. Filename = "SMB1-1.state"
  21. ButtonNames = {
  22. "A",
  23. "B",
  24. "Up",
  25. "Down",
  26. "Left",
  27. "Right",
  28. }
  29. end
  30.  
  31. BoxRadius = 6
  32. InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  33.  
  34. Inputs = InputSize+1
  35. Outputs = #ButtonNames
  36.  
  37. Population = 300
  38. DeltaDisjoint = 2.0
  39. DeltaWeights = 0.4
  40. DeltaThreshold = 1.0
  41.  
  42. StaleSpecies = 15
  43.  
  44. MutateConnectionsChance = 0.25
  45. PerturbChance = 0.90
  46. CrossoverChance = 0.75
  47. LinkMutationChance = 2.0
  48. NodeMutationChance = 0.50
  49. BiasMutationChance = 0.40
  50. StepSize = 0.1
  51. DisableMutationChance = 0.4
  52. EnableMutationChance = 0.2
  53.  
  54. TimeoutConstant = 20
  55.  
  56. MaxNodes = 1000000
  57.  
  58. function getPositions()
  59. if gameinfo.getromname() == "Super Mario World (USA)" then
  60. marioX = memory.read_s16_le(0x94)
  61. marioY = memory.read_s16_le(0x96)
  62.  
  63. local layer1x = memory.read_s16_le(0x1A);
  64. local layer1y = memory.read_s16_le(0x1C);
  65.  
  66. screenX = marioX-layer1x
  67. screenY = marioY-layer1y
  68. elseif gameinfo.getromname() == "Super Mario Bros." then
  69. marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
  70. marioY = memory.readbyte(0x03B8)+16
  71.  
  72. screenX = memory.readbyte(0x03AD)
  73. screenY = memory.readbyte(0x03B8)
  74. end
  75. end
  76.  
  77. function getTile(dx, dy)
  78. if gameinfo.getromname() == "Super Mario World (USA)" then
  79. x = math.floor((marioX+dx+8)/16)
  80. y = math.floor((marioY+dy)/16)
  81.  
  82. return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  83. elseif gameinfo.getromname() == "Super Mario Bros." then
  84. local x = marioX + dx + 8
  85. local y = marioY + dy - 16
  86. local page = math.floor(x/256)%2
  87.  
  88. local subx = math.floor((x%256)/16)
  89. local suby = math.floor((y - 32)/16)
  90. local addr = 0x500 + page*13*16+suby*16+subx
  91.  
  92. if suby >= 13 or suby < 0 then
  93. return 0
  94. end
  95.  
  96. if memory.readbyte(addr) ~= 0 then
  97. return 1
  98. else
  99. return 0
  100. end
  101. end
  102. end
  103.  
  104. function getSprites()
  105. if gameinfo.getromname() == "Super Mario World (USA)" then
  106. local sprites = {}
  107. for slot=0,11 do
  108. local status = memory.readbyte(0x14C8+slot)
  109. if status ~= 0 then
  110. spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
  111. spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
  112. sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
  113. end
  114. end
  115.  
  116. return sprites
  117. elseif gameinfo.getromname() == "Super Mario Bros." then
  118. local sprites = {}
  119. for slot=0,4 do
  120. local enemy = memory.readbyte(0xF+slot)
  121. if enemy ~= 0 then
  122. local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
  123. local ey = memory.readbyte(0xCF + slot)+24
  124. sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
  125. end
  126. end
  127.  
  128. return sprites
  129. end
  130. end
  131.  
  132. function getExtendedSprites()
  133. if gameinfo.getromname() == "Super Mario World (USA)" then
  134. local extended = {}
  135. for slot=0,11 do
  136. local number = memory.readbyte(0x170B+slot)
  137. if number ~= 0 then
  138. spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
  139. spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
  140. extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
  141. end
  142. end
  143.  
  144. return extended
  145. elseif gameinfo.getromname() == "Super Mario Bros." then
  146. return {}
  147. end
  148. end
  149.  
  150. function getInputs()
  151. getPositions()
  152.  
  153. sprites = getSprites()
  154. extended = getExtendedSprites()
  155.  
  156. local inputs = {}
  157.  
  158. for dy=-BoxRadius*16,BoxRadius*16,16 do
  159. for dx=-BoxRadius*16,BoxRadius*16,16 do
  160. inputs[#inputs+1] = 0
  161.  
  162. tile = getTile(dx, dy)
  163. if tile == 1 and marioY+dy < 0x1B0 then
  164. inputs[#inputs] = 1
  165. end
  166.  
  167. for i = 1,#sprites do
  168. distx = math.abs(sprites[i]["x"] - (marioX+dx))
  169. disty = math.abs(sprites[i]["y"] - (marioY+dy))
  170. if distx <= 8 and disty <= 8 then
  171. inputs[#inputs] = -1
  172. end
  173. end
  174.  
  175. for i = 1,#extended do
  176. distx = math.abs(extended[i]["x"] - (marioX+dx))
  177. disty = math.abs(extended[i]["y"] - (marioY+dy))
  178. if distx < 8 and disty < 8 then
  179. inputs[#inputs] = -1
  180. end
  181. end
  182. end
  183. end
  184.  
  185. --mariovx = memory.read_s8(0x7B)
  186. --mariovy = memory.read_s8(0x7D)
  187.  
  188. return inputs
  189. end
  190.  
  191. function sigmoid(x)
  192. return 2/(1+math.exp(-4.9*x))-1
  193. end
  194.  
  195. function newInnovation()
  196. pool.innovation = pool.innovation + 1
  197. return pool.innovation
  198. end
  199.  
  200. function newPool()
  201. local pool = {}
  202. pool.species = {}
  203. pool.generation = 0
  204. pool.innovation = Outputs
  205. pool.currentSpecies = 1
  206. pool.currentGenome = 1
  207. pool.currentFrame = 0
  208. pool.maxFitness = 0
  209.  
  210. return pool
  211. end
  212.  
  213. function newSpecies()
  214. local species = {}
  215. species.topFitness = 0
  216. species.staleness = 0
  217. species.genomes = {}
  218. species.averageFitness = 0
  219.  
  220. return species
  221. end
  222.  
  223. function newGenome()
  224. local genome = {}
  225. genome.genes = {}
  226. genome.fitness = 0
  227. genome.adjustedFitness = 0
  228. genome.network = {}
  229. genome.maxneuron = 0
  230. genome.globalRank = 0
  231. genome.mutationRates = {}
  232. genome.mutationRates["connections"] = MutateConnectionsChance
  233. genome.mutationRates["link"] = LinkMutationChance
  234. genome.mutationRates["bias"] = BiasMutationChance
  235. genome.mutationRates["node"] = NodeMutationChance
  236. genome.mutationRates["enable"] = EnableMutationChance
  237. genome.mutationRates["disable"] = DisableMutationChance
  238. genome.mutationRates["step"] = StepSize
  239.  
  240. return genome
  241. end
  242.  
  243. function copyGenome(genome)
  244. local genome2 = newGenome()
  245. for g=1,#genome.genes do
  246. table.insert(genome2.genes, copyGene(genome.genes[g]))
  247. end
  248. genome2.maxneuron = genome.maxneuron
  249. genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  250. genome2.mutationRates["link"] = genome.mutationRates["link"]
  251. genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  252. genome2.mutationRates["node"] = genome.mutationRates["node"]
  253. genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  254. genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  255.  
  256. return genome2
  257. end
  258.  
  259. function basicGenome()
  260. local genome = newGenome()
  261. local innovation = 1
  262.  
  263. genome.maxneuron = Inputs
  264. mutate(genome)
  265.  
  266. return genome
  267. end
  268.  
  269. function newGene()
  270. local gene = {}
  271. gene.into = 0
  272. gene.out = 0
  273. gene.weight = 0.0
  274. gene.enabled = true
  275. gene.innovation = 0
  276.  
  277. return gene
  278. end
  279.  
  280. function copyGene(gene)
  281. local gene2 = newGene()
  282. gene2.into = gene.into
  283. gene2.out = gene.out
  284. gene2.weight = gene.weight
  285. gene2.enabled = gene.enabled
  286. gene2.innovation = gene.innovation
  287.  
  288. return gene2
  289. end
  290.  
  291. function newNeuron()
  292. local neuron = {}
  293. neuron.incoming = {}
  294. neuron.value = 0.0
  295.  
  296. return neuron
  297. end
  298.  
  299. function generateNetwork(genome)
  300. local network = {}
  301. network.neurons = {}
  302.  
  303. for i=1,Inputs do
  304. network.neurons[i] = newNeuron()
  305. end
  306.  
  307. for o=1,Outputs do
  308. network.neurons[MaxNodes+o] = newNeuron()
  309. end
  310.  
  311. table.sort(genome.genes, function (a,b)
  312. return (a.out < b.out)
  313. end)
  314. for i=1,#genome.genes do
  315. local gene = genome.genes[i]
  316. if gene.enabled then
  317. if network.neurons[gene.out] == nil then
  318. network.neurons[gene.out] = newNeuron()
  319. end
  320. local neuron = network.neurons[gene.out]
  321. table.insert(neuron.incoming, gene)
  322. if network.neurons[gene.into] == nil then
  323. network.neurons[gene.into] = newNeuron()
  324. end
  325. end
  326. end
  327.  
  328. genome.network = network
  329. end
  330.  
  331. function evaluateNetwork(network, inputs)
  332. table.insert(inputs, 1)
  333. if #inputs ~= Inputs then
  334. console.writeline("Incorrect number of neural network inputs.")
  335. return {}
  336. end
  337.  
  338. for i=1,Inputs do
  339. network.neurons[i].value = inputs[i]
  340. end
  341.  
  342. for _,neuron in pairs(network.neurons) do
  343. local sum = 0
  344. for j = 1,#neuron.incoming do
  345. local incoming = neuron.incoming[j]
  346. local other = network.neurons[incoming.into]
  347. sum = sum + incoming.weight * other.value
  348. end
  349.  
  350. if #neuron.incoming > 0 then
  351. neuron.value = sigmoid(sum)
  352. end
  353. end
  354.  
  355. local outputs = {}
  356. for o=1,Outputs do
  357. local button = "P1 " .. ButtonNames[o]
  358. if network.neurons[MaxNodes+o].value > 0 then
  359. outputs[button] = true
  360. else
  361. outputs[button] = false
  362. end
  363. end
  364.  
  365. return outputs
  366. end
  367.  
  368. function crossover(g1, g2)
  369. -- Make sure g1 is the higher fitness genome
  370. if g2.fitness > g1.fitness then
  371. tempg = g1
  372. g1 = g2
  373. g2 = tempg
  374. end
  375.  
  376. local child = newGenome()
  377.  
  378. local innovations2 = {}
  379. for i=1,#g2.genes do
  380. local gene = g2.genes[i]
  381. innovations2[gene.innovation] = gene
  382. end
  383.  
  384. for i=1,#g1.genes do
  385. local gene1 = g1.genes[i]
  386. local gene2 = innovations2[gene1.innovation]
  387. if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  388. table.insert(child.genes, copyGene(gene2))
  389. else
  390. table.insert(child.genes, copyGene(gene1))
  391. end
  392. end
  393.  
  394. child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  395.  
  396. for mutation,rate in pairs(g1.mutationRates) do
  397. child.mutationRates[mutation] = rate
  398. end
  399.  
  400. return child
  401. end
  402.  
  403. function randomNeuron(genes, nonInput)
  404. local neurons = {}
  405. if not nonInput then
  406. for i=1,Inputs do
  407. neurons[i] = true
  408. end
  409. end
  410. for o=1,Outputs do
  411. neurons[MaxNodes+o] = true
  412. end
  413. for i=1,#genes do
  414. if (not nonInput) or genes[i].into > Inputs then
  415. neurons[genes[i].into] = true
  416. end
  417. if (not nonInput) or genes[i].out > Inputs then
  418. neurons[genes[i].out] = true
  419. end
  420. end
  421.  
  422. local count = 0
  423. for _,_ in pairs(neurons) do
  424. count = count + 1
  425. end
  426. local n = math.random(1, count)
  427.  
  428. for k,v in pairs(neurons) do
  429. n = n-1
  430. if n == 0 then
  431. return k
  432. end
  433. end
  434.  
  435. return 0
  436. end
  437.  
  438. function containsLink(genes, link)
  439. for i=1,#genes do
  440. local gene = genes[i]
  441. if gene.into == link.into and gene.out == link.out then
  442. return true
  443. end
  444. end
  445. end
  446.  
  447. function pointMutate(genome)
  448. local step = genome.mutationRates["step"]
  449.  
  450. for i=1,#genome.genes do
  451. local gene = genome.genes[i]
  452. if math.random() < PerturbChance then
  453. gene.weight = gene.weight + math.random() * step*2 - step
  454. else
  455. gene.weight = math.random()*4-2
  456. end
  457. end
  458. end
  459.  
  460. function linkMutate(genome, forceBias)
  461. local neuron1 = randomNeuron(genome.genes, false)
  462. local neuron2 = randomNeuron(genome.genes, true)
  463.  
  464. local newLink = newGene()
  465. if neuron1 <= Inputs and neuron2 <= Inputs then
  466. --Both input nodes
  467. return
  468. end
  469. if neuron2 <= Inputs then
  470. -- Swap output and input
  471. local temp = neuron1
  472. neuron1 = neuron2
  473. neuron2 = temp
  474. end
  475.  
  476. newLink.into = neuron1
  477. newLink.out = neuron2
  478. if forceBias then
  479. newLink.into = Inputs
  480. end
  481.  
  482. if containsLink(genome.genes, newLink) then
  483. return
  484. end
  485. newLink.innovation = newInnovation()
  486. newLink.weight = math.random()*4-2
  487.  
  488. table.insert(genome.genes, newLink)
  489. end
  490.  
  491. function nodeMutate(genome)
  492. if #genome.genes == 0 then
  493. return
  494. end
  495.  
  496. genome.maxneuron = genome.maxneuron + 1
  497.  
  498. local gene = genome.genes[math.random(1,#genome.genes)]
  499. if not gene.enabled then
  500. return
  501. end
  502. gene.enabled = false
  503.  
  504. local gene1 = copyGene(gene)
  505. gene1.out = genome.maxneuron
  506. gene1.weight = 1.0
  507. gene1.innovation = newInnovation()
  508. gene1.enabled = true
  509. table.insert(genome.genes, gene1)
  510.  
  511. local gene2 = copyGene(gene)
  512. gene2.into = genome.maxneuron
  513. gene2.innovation = newInnovation()
  514. gene2.enabled = true
  515. table.insert(genome.genes, gene2)
  516. end
  517.  
  518. function enableDisableMutate(genome, enable)
  519. local candidates = {}
  520. for _,gene in pairs(genome.genes) do
  521. if gene.enabled == not enable then
  522. table.insert(candidates, gene)
  523. end
  524. end
  525.  
  526. if #candidates == 0 then
  527. return
  528. end
  529.  
  530. local gene = candidates[math.random(1,#candidates)]
  531. gene.enabled = not gene.enabled
  532. end
  533.  
  534. function mutate(genome)
  535. for mutation,rate in pairs(genome.mutationRates) do
  536. if math.random(1,2) == 1 then
  537. genome.mutationRates[mutation] = 0.95*rate
  538. else
  539. genome.mutationRates[mutation] = 1.05263*rate
  540. end
  541. end
  542.  
  543. if math.random() < genome.mutationRates["connections"] then
  544. pointMutate(genome)
  545. end
  546.  
  547. local p = genome.mutationRates["link"]
  548. while p > 0 do
  549. if math.random() < p then
  550. linkMutate(genome, false)
  551. end
  552. p = p - 1
  553. end
  554.  
  555. p = genome.mutationRates["bias"]
  556. while p > 0 do
  557. if math.random() < p then
  558. linkMutate(genome, true)
  559. end
  560. p = p - 1
  561. end
  562.  
  563. p = genome.mutationRates["node"]
  564. while p > 0 do
  565. if math.random() < p then
  566. nodeMutate(genome)
  567. end
  568. p = p - 1
  569. end
  570.  
  571. p = genome.mutationRates["enable"]
  572. while p > 0 do
  573. if math.random() < p then
  574. enableDisableMutate(genome, true)
  575. end
  576. p = p - 1
  577. end
  578.  
  579. p = genome.mutationRates["disable"]
  580. while p > 0 do
  581. if math.random() < p then
  582. enableDisableMutate(genome, false)
  583. end
  584. p = p - 1
  585. end
  586. end
  587.  
  588. function disjoint(genes1, genes2)
  589. local i1 = {}
  590. for i = 1,#genes1 do
  591. local gene = genes1[i]
  592. i1[gene.innovation] = true
  593. end
  594.  
  595. local i2 = {}
  596. for i = 1,#genes2 do
  597. local gene = genes2[i]
  598. i2[gene.innovation] = true
  599. end
  600.  
  601. local disjointGenes = 0
  602. for i = 1,#genes1 do
  603. local gene = genes1[i]
  604. if not i2[gene.innovation] then
  605. disjointGenes = disjointGenes+1
  606. end
  607. end
  608.  
  609. for i = 1,#genes2 do
  610. local gene = genes2[i]
  611. if not i1[gene.innovation] then
  612. disjointGenes = disjointGenes+1
  613. end
  614. end
  615.  
  616. local n = math.max(#genes1, #genes2)
  617.  
  618. return disjointGenes / n
  619. end
  620.  
  621. function weights(genes1, genes2)
  622. local i2 = {}
  623. for i = 1,#genes2 do
  624. local gene = genes2[i]
  625. i2[gene.innovation] = gene
  626. end
  627.  
  628. local sum = 0
  629. local coincident = 0
  630. for i = 1,#genes1 do
  631. local gene = genes1[i]
  632. if i2[gene.innovation] ~= nil then
  633. local gene2 = i2[gene.innovation]
  634. sum = sum + math.abs(gene.weight - gene2.weight)
  635. coincident = coincident + 1
  636. end
  637. end
  638.  
  639. return sum / coincident
  640. end
  641.  
  642. function sameSpecies(genome1, genome2)
  643. local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  644. local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  645. return dd + dw < DeltaThreshold
  646. end
  647.  
  648. function rankGlobally()
  649. local global = {}
  650. for s = 1,#pool.species do
  651. local species = pool.species[s]
  652. for g = 1,#species.genomes do
  653. table.insert(global, species.genomes[g])
  654. end
  655. end
  656. table.sort(global, function (a,b)
  657. return (a.fitness < b.fitness)
  658. end)
  659.  
  660. for g=1,#global do
  661. global[g].globalRank = g
  662. end
  663. end
  664.  
  665. function calculateAverageFitness(species)
  666. local total = 0
  667.  
  668. for g=1,#species.genomes do
  669. local genome = species.genomes[g]
  670. total = total + genome.globalRank
  671. end
  672.  
  673. species.averageFitness = total / #species.genomes
  674. end
  675.  
  676. function totalAverageFitness()
  677. local total = 0
  678. for s = 1,#pool.species do
  679. local species = pool.species[s]
  680. total = total + species.averageFitness
  681. end
  682.  
  683. return total
  684. end
  685.  
  686. function cullSpecies(cutToOne)
  687. for s = 1,#pool.species do
  688. local species = pool.species[s]
  689.  
  690. table.sort(species.genomes, function (a,b)
  691. return (a.fitness > b.fitness)
  692. end)
  693.  
  694. local remaining = math.ceil(#species.genomes/2)
  695. if cutToOne then
  696. remaining = 1
  697. end
  698. while #species.genomes > remaining do
  699. table.remove(species.genomes)
  700. end
  701. end
  702. end
  703.  
  704. function breedChild(species)
  705. local child = {}
  706. if math.random() < CrossoverChance then
  707. g1 = species.genomes[math.random(1, #species.genomes)]
  708. g2 = species.genomes[math.random(1, #species.genomes)]
  709. child = crossover(g1, g2)
  710. else
  711. g = species.genomes[math.random(1, #species.genomes)]
  712. child = copyGenome(g)
  713. end
  714.  
  715. mutate(child)
  716.  
  717. return child
  718. end
  719.  
  720. function removeStaleSpecies()
  721. local survived = {}
  722.  
  723. for s = 1,#pool.species do
  724. local species = pool.species[s]
  725.  
  726. table.sort(species.genomes, function (a,b)
  727. return (a.fitness > b.fitness)
  728. end)
  729.  
  730. if species.genomes[1].fitness > species.topFitness then
  731. species.topFitness = species.genomes[1].fitness
  732. species.staleness = 0
  733. else
  734. species.staleness = species.staleness + 1
  735. end
  736. if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness then
  737. table.insert(survived, species)
  738. end
  739. end
  740.  
  741. pool.species = survived
  742. end
  743.  
  744. function removeWeakSpecies()
  745. local survived = {}
  746.  
  747. local sum = totalAverageFitness()
  748. for s = 1,#pool.species do
  749. local species = pool.species[s]
  750. breed = math.floor(species.averageFitness / sum * Population)
  751. if breed >= 1 then
  752. table.insert(survived, species)
  753. end
  754. end
  755.  
  756. pool.species = survived
  757. end
  758.  
  759.  
  760. function addToSpecies(child)
  761. local foundSpecies = false
  762. for s=1,#pool.species do
  763. local species = pool.species[s]
  764. if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  765. table.insert(species.genomes, child)
  766. foundSpecies = true
  767. end
  768. end
  769.  
  770. if not foundSpecies then
  771. local childSpecies = newSpecies()
  772. table.insert(childSpecies.genomes, child)
  773. table.insert(pool.species, childSpecies)
  774. end
  775. end
  776.  
  777. function newGeneration()
  778. cullSpecies(false) -- Cull the bottom half of each species
  779. rankGlobally()
  780. removeStaleSpecies()
  781. rankGlobally()
  782. for s = 1,#pool.species do
  783. local species = pool.species[s]
  784. calculateAverageFitness(species)
  785. end
  786. removeWeakSpecies()
  787. local sum = totalAverageFitness()
  788. local children = {}
  789. for s = 1,#pool.species do
  790. local species = pool.species[s]
  791. breed = math.floor(species.averageFitness / sum * Population) - 1
  792. for i=1,breed do
  793. table.insert(children, breedChild(species))
  794. end
  795. end
  796. cullSpecies(true) -- Cull all but the top member of each species
  797. while #children + #pool.species < Population do
  798. local species = pool.species[math.random(1, #pool.species)]
  799. table.insert(children, breedChild(species))
  800. end
  801. for c=1,#children do
  802. local child = children[c]
  803. addToSpecies(child)
  804. end
  805.  
  806. pool.generation = pool.generation + 1
  807.  
  808. writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  809. end
  810.  
  811. function initializePool()
  812. pool = newPool()
  813.  
  814. for i=1,Population do
  815. basic = basicGenome()
  816. addToSpecies(basic)
  817. end
  818.  
  819. initializeRun()
  820. end
  821.  
  822. function clearJoypad()
  823. controller = {}
  824. for b = 1,#ButtonNames do
  825. controller["P1 " .. ButtonNames[b]] = false
  826. end
  827. joypad.set(controller)
  828. end
  829.  
  830. function initializeRun()
  831. savestate.load(Filename);
  832. rightmost = 0
  833. pool.currentFrame = 0
  834. timeout = TimeoutConstant
  835. clearJoypad()
  836.  
  837. local species = pool.species[pool.currentSpecies]
  838. local genome = species.genomes[pool.currentGenome]
  839. generateNetwork(genome)
  840. evaluateCurrent()
  841. end
  842.  
  843. function evaluateCurrent()
  844. local species = pool.species[pool.currentSpecies]
  845. local genome = species.genomes[pool.currentGenome]
  846.  
  847. inputs = getInputs()
  848. controller = evaluateNetwork(genome.network, inputs)
  849.  
  850. if controller["P1 Left"] and controller["P1 Right"] then
  851. controller["P1 Left"] = false
  852. controller["P1 Right"] = false
  853. end
  854. if controller["P1 Up"] and controller["P1 Down"] then
  855. controller["P1 Up"] = false
  856. controller["P1 Down"] = false
  857. end
  858.  
  859. joypad.set(controller)
  860. end
  861.  
  862. if pool == nil then
  863. initializePool()
  864. end
  865.  
  866.  
  867. function nextGenome()
  868. pool.currentGenome = pool.currentGenome + 1
  869. if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
  870. pool.currentGenome = 1
  871. pool.currentSpecies = pool.currentSpecies+1
  872. if pool.currentSpecies > #pool.species then
  873. newGeneration()
  874. pool.currentSpecies = 1
  875. end
  876. end
  877. end
  878.  
  879. function fitnessAlreadyMeasured()
  880. local species = pool.species[pool.currentSpecies]
  881. local genome = species.genomes[pool.currentGenome]
  882.  
  883. return genome.fitness ~= 0
  884. end
  885.  
--- Draw an overlay of `genome`'s network with BizHawk's gui.* calls.
-- Uses genome.network (as built by generateNetwork) and the neuron values
-- left by the last evaluation. Inputs appear as a grid, outputs as a labelled
-- column on the right, and hidden neurons are placed by a few relaxation
-- passes between them. Enabled genes are drawn as lines.
function displayGenome(genome)
    local network = genome.network
    local cells = {} -- neuron id -> {x, y, value} screen cell
    local i = 1
    local cell = {} -- scratch variable reused by the loops below
    -- Input grid: neurons 1..InputSize, 5 px apart, centred on (50, 70).
    for dy=-BoxRadius,BoxRadius do
        for dx=-BoxRadius,BoxRadius do
            cell = {}
            cell.x = 50+5*dx
            cell.y = 70+5*dy
            cell.value = network.neurons[i].value
            cells[i] = cell
            i = i + 1
        end
    end
    -- The bias input (neuron id Inputs) is drawn separately below the grid.
    local biasCell = {}
    biasCell.x = 80
    biasCell.y = 110
    biasCell.value = network.neurons[Inputs].value
    cells[Inputs] = biasCell

    -- Output neurons in a column at x = 220, each labelled with its button
    -- name; the label color changes when the output value is > 0.
    for o = 1,Outputs do
        cell = {}
        cell.x = 220
        cell.y = 30 + 8 * o
        cell.value = network.neurons[MaxNodes + o].value
        cells[MaxNodes+o] = cell
        local color
        if cell.value > 0 then
            color = 0xFF0000FF
        else
            color = 0xFF000000
        end
        gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
    end

    -- Hidden neurons (Inputs < id <= MaxNodes) all start at (140, 40).
    for n,neuron in pairs(network.neurons) do
        cell = {}
        if n > Inputs and n <= MaxNodes then
            cell.x = 140
            cell.y = 40
            cell.value = neuron.value
            cells[n] = cell
        end
    end

    -- Four layout passes. For each enabled gene, a hidden endpoint is moved
    -- 25% toward the other end, pushed 40 px further left (source) or right
    -- (target) if the two ends are not in left-to-right order, and its x is
    -- clamped to [90, 220].
    for n=1,4 do
        for _,gene in pairs(genome.genes) do
            if gene.enabled then
                local c1 = cells[gene.into]
                local c2 = cells[gene.out]
                if gene.into > Inputs and gene.into <= MaxNodes then
                    c1.x = 0.75*c1.x + 0.25*c2.x
                    if c1.x >= c2.x then
                        c1.x = c1.x - 40
                    end
                    if c1.x < 90 then
                        c1.x = 90
                    end

                    if c1.x > 220 then
                        c1.x = 220
                    end
                    c1.y = 0.75*c1.y + 0.25*c2.y

                end
                if gene.out > Inputs and gene.out <= MaxNodes then
                    c2.x = 0.25*c1.x + 0.75*c2.x
                    if c1.x >= c2.x then
                        c2.x = c2.x + 40
                    end
                    if c2.x < 90 then
                        c2.x = 90
                    end
                    if c2.x > 220 then
                        c2.x = 220
                    end
                    c2.y = 0.25*c1.y + 0.75*c2.y
                end
            end
        end
    end

    -- Background box behind the input grid.
    gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
    -- Neuron cells: grey level from value (-1 -> 0, +1 -> 255, clamped).
    -- Input cells with value 0 are skipped; other zero-valued cells get a
    -- fainter alpha.
    for n,cell in pairs(cells) do
        if n > Inputs or cell.value ~= 0 then
            local color = math.floor((cell.value+1)/2*256)
            if color > 255 then color = 255 end
            if color < 0 then color = 0 end
            local opacity = 0xFF000000
            if cell.value == 0 then
                opacity = 0x50000000
            end
            color = opacity + color*0x10000 + color*0x100 + color
            gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color)
        end
    end
    -- Enabled genes as lines. The hue depends on the weight's sign and the
    -- shade on |sigmoid(weight)|; lines from zero-valued sources get a
    -- fainter alpha.
    for _,gene in pairs(genome.genes) do
        if gene.enabled then
            local c1 = cells[gene.into]
            local c2 = cells[gene.out]
            local opacity = 0xA0000000
            if c1.value == 0 then
                opacity = 0x20000000
            end

            local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
            if gene.weight > 0 then
                color = opacity + 0x8000 + 0x10000*color
            else
                color = opacity + 0x800000 + 0x100*color
            end
            gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
        end
    end

    -- NOTE(review): small box just below the grid centre (50, 70),
    -- presumably marking Mario's position in the grid — confirm.
    gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)

    -- Optionally list the genome's mutation rates (pairs() order).
    if forms.ischecked(showMutationRates) then
        local pos = 100
        for mutation,rate in pairs(genome.mutationRates) do
            gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
            pos = pos + 8
        end
    end
end
  1012.  
  1013. function writeFile(filename)
  1014. local file = io.open(filename, "w")
  1015. file:write(pool.generation .. "\n")
  1016. file:write(pool.maxFitness .. "\n")
  1017. file:write(#pool.species .. "\n")
  1018. for n,species in pairs(pool.species) do
  1019. file:write(species.topFitness .. "\n")
  1020. file:write(species.staleness .. "\n")
  1021. file:write(#species.genomes .. "\n")
  1022. for m,genome in pairs(species.genomes) do
  1023. file:write(genome.fitness .. "\n")
  1024. file:write(genome.maxneuron .. "\n")
  1025. for mutation,rate in pairs(genome.mutationRates) do
  1026. file:write(mutation .. "\n")
  1027. file:write(rate .. "\n")
  1028. end
  1029. file:write("done\n")
  1030.  
  1031. file:write(#genome.genes .. "\n")
  1032. for l,gene in pairs(genome.genes) do
  1033. file:write(gene.into .. " ")
  1034. file:write(gene.out .. " ")
  1035. file:write(gene.weight .. " ")
  1036. file:write(gene.innovation .. " ")
  1037. if(gene.enabled) then
  1038. file:write("1\n")
  1039. else
  1040. file:write("0\n")
  1041. end
  1042. end
  1043. end
  1044. end
  1045. file:close()
  1046. end
  1047.  
  1048. function savePool()
  1049. local filename = forms.gettext(saveLoadFile)
  1050. writeFile(filename)
  1051. end
  1052.  
  1053. function loadFile(filename)
  1054. local file = io.open(filename, "r")
  1055. pool = newPool()
  1056. pool.generation = file:read("*number")
  1057. pool.maxFitness = file:read("*number")
  1058. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1059. local numSpecies = file:read("*number")
  1060. for s=1,numSpecies do
  1061. local species = newSpecies()
  1062. table.insert(pool.species, species)
  1063. species.topFitness = file:read("*number")
  1064. species.staleness = file:read("*number")
  1065. local numGenomes = file:read("*number")
  1066. for g=1,numGenomes do
  1067. local genome = newGenome()
  1068. table.insert(species.genomes, genome)
  1069. genome.fitness = file:read("*number")
  1070. genome.maxneuron = file:read("*number")
  1071. local line = file:read("*line")
  1072. while line ~= "done" do
  1073. genome.mutationRates[line] = file:read("*number")
  1074. line = file:read("*line")
  1075. end
  1076. local numGenes = file:read("*number")
  1077. for n=1,numGenes do
  1078. local gene = newGene()
  1079. table.insert(genome.genes, gene)
  1080. local enabled
  1081. gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
  1082. if enabled == 0 then
  1083. gene.enabled = false
  1084. else
  1085. gene.enabled = true
  1086. end
  1087.  
  1088. end
  1089. end
  1090. end
  1091. file:close()
  1092.  
  1093. while fitnessAlreadyMeasured() do
  1094. nextGenome()
  1095. end
  1096. initializeRun()
  1097. pool.currentFrame = pool.currentFrame + 1
  1098. end
  1099.  
  1100. function loadPool()
  1101. local filename = forms.gettext(saveLoadFile)
  1102. loadFile(filename)
  1103. end
  1104.  
  1105. function playTop()
  1106. local maxfitness = 0
  1107. local maxs, maxg
  1108. for s,species in pairs(pool.species) do
  1109. for g,genome in pairs(species.genomes) do
  1110. if genome.fitness > maxfitness then
  1111. maxfitness = genome.fitness
  1112. maxs = s
  1113. maxg = g
  1114. end
  1115. end
  1116. end
  1117.  
  1118. pool.currentSpecies = maxs
  1119. pool.currentGenome = maxg
  1120. pool.maxFitness = maxfitness
  1121. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1122. initializeRun()
  1123. pool.currentFrame = pool.currentFrame + 1
  1124. return
  1125. end
  1126.  
  1127. function onExit()
  1128. forms.destroy(form)
  1129. end
  1130.  
  1131. writeFile("temp.pool")
  1132.  
  1133. event.onexit(onExit)
  1134.  
  1135. form = forms.newform(200, 260, "Fitness")
  1136. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  1137. showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  1138. showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  1139. restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  1140. saveButton = forms.button(form, "Save", savePool, 5, 102)
  1141. loadButton = forms.button(form, "Load", loadPool, 80, 102)
  1142. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  1143. saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  1144. playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  1145. hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  1146.  
  1147.  
-- Per-frame driver. Each frame it:
--   draws the banner and (optionally) the network overlay,
--   re-evaluates the network every 5th frame and re-sends the joypad state,
--   tracks Mario's rightmost X and ends the run on timeout, scoring the genome,
--   moves on to the next unmeasured genome,
--   draws status text and advances the emulator by one frame.
while true do
    local backgroundColor = 0xD0FFFFFF
    if not forms.ischecked(hideBanner) then
        gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
    end

    local species = pool.species[pool.currentSpecies]
    local genome = species.genomes[pool.currentGenome]

    if forms.ischecked(showNetwork) then
        displayGenome(genome)
    end

    -- The network decides only every 5th frame; `controller` (set by
    -- evaluateCurrent) is re-sent on every frame in between.
    if pool.currentFrame%5 == 0 then
        evaluateCurrent()
    end

    joypad.set(controller)

    -- New rightmost progress refills the timeout.
    getPositions()
    if marioX > rightmost then
        rightmost = marioX
        timeout = TimeoutConstant
    end

    timeout = timeout - 1


    -- The run ends once timeout plus a bonus that grows with elapsed
    -- frames (currentFrame/4) drops to 0.
    local timeoutBonus = pool.currentFrame / 4
    if timeout + timeoutBonus <= 0 then
        -- Fitness = furthest X reached minus half the frames used, plus
        -- 1000 once past a per-ROM X threshold. NOTE(review): 4816 / 3186
        -- are presumably the level-end positions — confirm.
        local fitness = rightmost - pool.currentFrame / 2
        if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 then
            fitness = fitness + 1000
        end
        if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
            fitness = fitness + 1000
        end
        -- 0 means "not measured yet" (fitnessAlreadyMeasured), so store -1.
        if fitness == 0 then
            fitness = -1
        end
        genome.fitness = fitness

        -- New record: update the label and write a backup of the pool.
        if fitness > pool.maxFitness then
            pool.maxFitness = fitness
            forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
            writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
        end

        console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
        -- Rescan from the start for the next unmeasured genome; nextGenome
        -- starts a new generation when it runs past the last species.
        pool.currentSpecies = 1
        pool.currentGenome = 1
        while fitnessAlreadyMeasured() do
            nextGenome()
        end
        initializeRun()
    end

    -- Progress: share of genomes in the pool with a non-zero fitness.
    local measured = 0
    local total = 0
    for _,species in pairs(pool.species) do
        for _,genome in pairs(species.genomes) do
            total = total + 1
            if genome.fitness ~= 0 then
                measured = measured + 1
            end
        end
    end
    -- Banner text. The "Fitness" figure is a live estimate for display only;
    -- its formula differs from the one used when scoring.
    if not forms.ischecked(hideBanner) then
        gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 11)
        gui.drawText(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000, 11)
        gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11)
    end

    pool.currentFrame = pool.currentFrame + 1

    emu.frameadvance();
end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement