Advertisement
Guest User

Untitled

a guest
Aug 18th, 2015
233
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 31.24 KB | None | 0 0
  1. -- MarI/O by SethBling
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  4. -- For SMW, make sure you have a save state named "DP1.state" at the beginning of a level,
  5. -- and put a copy in both the Lua folder and the root directory of BizHawk.
  6.  
  7. if gameinfo.getromname() == "Super Mario World (USA)" then
  8. Filename = "DP1.state"
  9. ButtonNames = {
  10. "A",
  11. "B",
  12. "X",
  13. "Y",
  14. "Up",
  15. "Down",
  16. "Left",
  17. "Right",
  18. }
  19. elseif gameinfo.getromname() == "Super Mario Bros." then
  20. Filename = "SMB1-1.state"
  21. ButtonNames = {
  22. "A",
  23. "B",
  24. "Up",
  25. "Down",
  26. "Left",
  27. "Right",
  28. }
  29. end
  30.  
  31.  
  32.  
  33.  
  34. Population = 1000 -- edit 300
  35. TimeoutConstant = 100 --edit 20
  36.  
  37.  
  38. TotalTries = 1
  39. UniqueSpeciesId = 1 -- shows as 'us', counts up unique ids and stores them to species,
  40. CurrentTotalFitness = 0
  41. CurrentAF = 0
  42. LastAF = 0
  43. MaxAF = -1000
  44. AFStaleFor = 0
  45. FitnessesMeasured = 0
  46.  
  47.  
  48. BoxRadius = 6
  49. InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  50.  
  51. Inputs = InputSize+1
  52. Outputs = #ButtonNames
  53.  
  54.  
  55. DeltaDisjoint = 2.0
  56. DeltaWeights = 0.4
  57. DeltaThreshold = 1.0
  58.  
  59. StaleSpecies = 15
  60.  
  61. MutateConnectionsChance = 0.25
  62. PerturbChance = 0.90
  63. CrossoverChance = 0.75
  64. LinkMutationChance = 2.0
  65. NodeMutationChance = 0.50
  66. BiasMutationChance = 0.40
  67. StepSize = 0.1
  68. DisableMutationChance = 0.4
  69. EnableMutationChance = 0.2
  70.  
  71.  
  72.  
  73. MaxNodes = 1000000
  74.  
  75. function getPositions()
  76. if gameinfo.getromname() == "Super Mario World (USA)" then
  77. marioX = memory.read_s16_le(0x94)
  78. marioY = memory.read_s16_le(0x96)
  79.  
  80. local layer1x = memory.read_s16_le(0x1A);
  81. local layer1y = memory.read_s16_le(0x1C);
  82.  
  83. screenX = marioX-layer1x
  84. screenY = marioY-layer1y
  85. elseif gameinfo.getromname() == "Super Mario Bros." then
  86. marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
  87. marioY = memory.readbyte(0x03B8)+16
  88.  
  89. screenX = memory.readbyte(0x03AD)
  90. screenY = memory.readbyte(0x03B8)
  91. end
  92. end
  93.  
  94. function getTile(dx, dy)
  95. if gameinfo.getromname() == "Super Mario World (USA)" then
  96. x = math.floor((marioX+dx+8)/16)
  97. y = math.floor((marioY+dy)/16)
  98.  
  99. return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  100. elseif gameinfo.getromname() == "Super Mario Bros." then
  101. local x = marioX + dx + 8
  102. local y = marioY + dy - 16
  103. local page = math.floor(x/256)%2
  104.  
  105. local subx = math.floor((x%256)/16)
  106. local suby = math.floor((y - 32)/16)
  107. local addr = 0x500 + page*13*16+suby*16+subx
  108.  
  109. if suby >= 13 or suby < 0 then
  110. return 0
  111. end
  112.  
  113. if memory.readbyte(addr) ~= 0 then
  114. return 1
  115. else
  116. return 0
  117. end
  118. end
  119. end
  120.  
  121. function getSprites()
  122. if gameinfo.getromname() == "Super Mario World (USA)" then
  123. local sprites = {}
  124. for slot=0,11 do
  125. local status = memory.readbyte(0x14C8+slot)
  126. if status ~= 0 then
  127. spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
  128. spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
  129. sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
  130. end
  131. end
  132.  
  133. return sprites
  134. elseif gameinfo.getromname() == "Super Mario Bros." then
  135. local sprites = {}
  136. for slot=0,4 do
  137. local enemy = memory.readbyte(0xF+slot)
  138. if enemy ~= 0 then
  139. local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
  140. local ey = memory.readbyte(0xCF + slot)+24
  141. sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
  142. end
  143. end
  144.  
  145. return sprites
  146. end
  147. end
  148.  
  149. function getExtendedSprites()
  150. if gameinfo.getromname() == "Super Mario World (USA)" then
  151. local extended = {}
  152. for slot=0,11 do
  153. local number = memory.readbyte(0x170B+slot)
  154. if number ~= 0 then
  155. spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
  156. spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
  157. extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
  158. end
  159. end
  160.  
  161. return extended
  162. elseif gameinfo.getromname() == "Super Mario Bros." then
  163. return {}
  164. end
  165. end
  166.  
  167. function getInputs()
  168. getPositions()
  169.  
  170. sprites = getSprites()
  171. extended = getExtendedSprites()
  172.  
  173. local inputs = {}
  174.  
  175. for dy=-BoxRadius*16,BoxRadius*16,16 do
  176. for dx=-BoxRadius*16,BoxRadius*16,16 do
  177. inputs[#inputs+1] = 0
  178.  
  179. tile = getTile(dx, dy)
  180. if tile == 1 and marioY+dy < 0x1B0 then
  181. inputs[#inputs] = 1
  182. end
  183.  
  184. for i = 1,#sprites do
  185. distx = math.abs(sprites[i]["x"] - (marioX+dx))
  186. disty = math.abs(sprites[i]["y"] - (marioY+dy))
  187. if distx <= 8 and disty <= 8 then
  188. inputs[#inputs] = -1
  189. end
  190. end
  191.  
  192. for i = 1,#extended do
  193. distx = math.abs(extended[i]["x"] - (marioX+dx))
  194. disty = math.abs(extended[i]["y"] - (marioY+dy))
  195. if distx < 8 and disty < 8 then
  196. inputs[#inputs] = -1
  197. end
  198. end
  199. end
  200. end
  201.  
  202. --mariovx = memory.read_s8(0x7B)
  203. --mariovy = memory.read_s8(0x7D)
  204.  
  205. return inputs
  206. end
  207.  
  208. function sigmoid(x)
  209. return 2/(1+math.exp(-4.9*x))-1
  210. end
  211.  
  212. function newInnovation()
  213. pool.innovation = pool.innovation + 1
  214. return pool.innovation
  215. end
  216.  
  217. function newPool()
  218. local pool = {}
  219. pool.species = {}
  220. pool.generation = 0
  221. pool.innovation = Outputs
  222. pool.currentSpecies = 1
  223. pool.currentGenome = 1
  224. pool.currentFrame = 0
  225. pool.maxFitness = 0
  226.  
  227. return pool
  228. end
  229.  
  230. function newSpecies()
  231. local species = {}
  232. species.topFitness = 0
  233. species.staleness = 0
  234. species.genomes = {}
  235. species.averageFitness = 0
  236. species.uniqueId = UniqueSpeciesId
  237. species.bornGen = pool.generation
  238. species.Age = 0
  239. species.AF = 0;
  240. return species
  241. end
  242.  
  243. function newGenome()
  244. local genome = {}
  245. genome.genes = {}
  246. genome.fitness = 0
  247. genome.adjustedFitness = 0
  248. genome.network = {}
  249. genome.maxneuron = 0
  250. genome.globalRank = 0
  251. genome.mutationRates = {}
  252. genome.mutationRates["connections"] = MutateConnectionsChance
  253. genome.mutationRates["link"] = LinkMutationChance
  254. genome.mutationRates["bias"] = BiasMutationChance
  255. genome.mutationRates["node"] = NodeMutationChance
  256. genome.mutationRates["enable"] = EnableMutationChance
  257. genome.mutationRates["disable"] = DisableMutationChance
  258. genome.mutationRates["step"] = StepSize
  259.  
  260. return genome
  261. end
  262.  
  263. function copyGenome(genome)
  264. local genome2 = newGenome()
  265. for g=1,#genome.genes do
  266. table.insert(genome2.genes, copyGene(genome.genes[g]))
  267. end
  268. genome2.maxneuron = genome.maxneuron
  269. genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  270. genome2.mutationRates["link"] = genome.mutationRates["link"]
  271. genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  272. genome2.mutationRates["node"] = genome.mutationRates["node"]
  273. genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  274. genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  275.  
  276. return genome2
  277. end
  278.  
  279. function basicGenome()
  280. local genome = newGenome()
  281. local innovation = 1
  282.  
  283. genome.maxneuron = Inputs
  284. mutate(genome)
  285.  
  286. return genome
  287. end
  288.  
  289. function newGene()
  290. local gene = {}
  291. gene.into = 0
  292. gene.out = 0
  293. gene.weight = 0.0
  294. gene.enabled = true
  295. gene.innovation = 0
  296.  
  297. return gene
  298. end
  299.  
  300. function copyGene(gene)
  301. local gene2 = newGene()
  302. gene2.into = gene.into
  303. gene2.out = gene.out
  304. gene2.weight = gene.weight
  305. gene2.enabled = gene.enabled
  306. gene2.innovation = gene.innovation
  307.  
  308. return gene2
  309. end
  310.  
  311. function newNeuron()
  312. local neuron = {}
  313. neuron.incoming = {}
  314. neuron.value = 0.0
  315.  
  316. return neuron
  317. end
  318.  
  319. function generateNetwork(genome)
  320. local network = {}
  321. network.neurons = {}
  322.  
  323. for i=1,Inputs do
  324. network.neurons[i] = newNeuron()
  325. end
  326.  
  327. for o=1,Outputs do
  328. network.neurons[MaxNodes+o] = newNeuron()
  329. end
  330.  
  331. table.sort(genome.genes, function (a,b)
  332. return (a.out < b.out)
  333. end)
  334. for i=1,#genome.genes do
  335. local gene = genome.genes[i]
  336. if gene.enabled then
  337. if network.neurons[gene.out] == nil then
  338. network.neurons[gene.out] = newNeuron()
  339. end
  340. local neuron = network.neurons[gene.out]
  341. table.insert(neuron.incoming, gene)
  342. if network.neurons[gene.into] == nil then
  343. network.neurons[gene.into] = newNeuron()
  344. end
  345. end
  346. end
  347.  
  348. genome.network = network
  349. end
  350.  
  351. function evaluateNetwork(network, inputs)
  352. table.insert(inputs, 1)
  353. if #inputs ~= Inputs then
  354. console.writeline("Incorrect number of neural network inputs.")
  355. return {}
  356. end
  357.  
  358. for i=1,Inputs do
  359. network.neurons[i].value = inputs[i]
  360. end
  361.  
  362. for _,neuron in pairs(network.neurons) do
  363. local sum = 0
  364. for j = 1,#neuron.incoming do
  365. local incoming = neuron.incoming[j]
  366. local other = network.neurons[incoming.into]
  367. sum = sum + incoming.weight * other.value
  368. end
  369.  
  370. if #neuron.incoming > 0 then
  371. neuron.value = sigmoid(sum)
  372. end
  373. end
  374.  
  375. local outputs = {}
  376. for o=1,Outputs do
  377. local button = "P1 " .. ButtonNames[o]
  378. if network.neurons[MaxNodes+o].value > 0 then
  379. outputs[button] = true
  380. else
  381. outputs[button] = false
  382. end
  383. end
  384.  
  385. return outputs
  386. end
  387.  
  388. function crossover(g1, g2)
  389. -- Make sure g1 is the higher fitness genome
  390. if g2.fitness > g1.fitness then
  391. tempg = g1
  392. g1 = g2
  393. g2 = tempg
  394. end
  395.  
  396. local child = newGenome()
  397.  
  398. local innovations2 = {}
  399. for i=1,#g2.genes do
  400. local gene = g2.genes[i]
  401. innovations2[gene.innovation] = gene
  402. end
  403.  
  404. for i=1,#g1.genes do
  405. local gene1 = g1.genes[i]
  406. local gene2 = innovations2[gene1.innovation]
  407. if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  408. table.insert(child.genes, copyGene(gene2))
  409. else
  410. table.insert(child.genes, copyGene(gene1))
  411. end
  412. end
  413.  
  414. child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  415.  
  416. for mutation,rate in pairs(g1.mutationRates) do
  417. child.mutationRates[mutation] = rate
  418. end
  419.  
  420. return child
  421. end
  422.  
  423. function randomNeuron(genes, nonInput)
  424. local neurons = {}
  425. if not nonInput then
  426. for i=1,Inputs do
  427. neurons[i] = true
  428. end
  429. end
  430. for o=1,Outputs do
  431. neurons[MaxNodes+o] = true
  432. end
  433. for i=1,#genes do
  434. if (not nonInput) or genes[i].into > Inputs then
  435. neurons[genes[i].into] = true
  436. end
  437. if (not nonInput) or genes[i].out > Inputs then
  438. neurons[genes[i].out] = true
  439. end
  440. end
  441.  
  442. local count = 0
  443. for _,_ in pairs(neurons) do
  444. count = count + 1
  445. end
  446. local n = math.random(1, count)
  447.  
  448. for k,v in pairs(neurons) do
  449. n = n-1
  450. if n == 0 then
  451. return k
  452. end
  453. end
  454.  
  455. return 0
  456. end
  457.  
  458. function containsLink(genes, link)
  459. for i=1,#genes do
  460. local gene = genes[i]
  461. if gene.into == link.into and gene.out == link.out then
  462. return true
  463. end
  464. end
  465. end
  466.  
  467. function pointMutate(genome)
  468. local step = genome.mutationRates["step"]
  469.  
  470. for i=1,#genome.genes do
  471. local gene = genome.genes[i]
  472. if math.random() < PerturbChance then
  473. gene.weight = gene.weight + math.random() * step*2 - step
  474. else
  475. gene.weight = math.random()*4-2
  476. end
  477. end
  478. end
  479.  
  480. function linkMutate(genome, forceBias)
  481. local neuron1 = randomNeuron(genome.genes, false)
  482. local neuron2 = randomNeuron(genome.genes, true)
  483.  
  484. local newLink = newGene()
  485. if neuron1 <= Inputs and neuron2 <= Inputs then
  486. --Both input nodes
  487. return
  488. end
  489. if neuron2 <= Inputs then
  490. -- Swap output and input
  491. local temp = neuron1
  492. neuron1 = neuron2
  493. neuron2 = temp
  494. end
  495.  
  496. newLink.into = neuron1
  497. newLink.out = neuron2
  498. if forceBias then
  499. newLink.into = Inputs
  500. end
  501.  
  502. if containsLink(genome.genes, newLink) then
  503. return
  504. end
  505. newLink.innovation = newInnovation()
  506. newLink.weight = math.random()*4-2
  507.  
  508. table.insert(genome.genes, newLink)
  509. end
  510.  
  511. function nodeMutate(genome)
  512. if #genome.genes == 0 then
  513. return
  514. end
  515.  
  516. genome.maxneuron = genome.maxneuron + 1
  517.  
  518. local gene = genome.genes[math.random(1,#genome.genes)]
  519. if not gene.enabled then
  520. return
  521. end
  522. gene.enabled = false
  523.  
  524. local gene1 = copyGene(gene)
  525. gene1.out = genome.maxneuron
  526. gene1.weight = 1.0
  527. gene1.innovation = newInnovation()
  528. gene1.enabled = true
  529. table.insert(genome.genes, gene1)
  530.  
  531. local gene2 = copyGene(gene)
  532. gene2.into = genome.maxneuron
  533. gene2.innovation = newInnovation()
  534. gene2.enabled = true
  535. table.insert(genome.genes, gene2)
  536. end
  537.  
  538. function enableDisableMutate(genome, enable)
  539. local candidates = {}
  540. for _,gene in pairs(genome.genes) do
  541. if gene.enabled == not enable then
  542. table.insert(candidates, gene)
  543. end
  544. end
  545.  
  546. if #candidates == 0 then
  547. return
  548. end
  549.  
  550. local gene = candidates[math.random(1,#candidates)]
  551. gene.enabled = not gene.enabled
  552. end
  553.  
  554. function mutate(genome)
  555. for mutation,rate in pairs(genome.mutationRates) do
  556. if math.random(1,2) == 1 then
  557. genome.mutationRates[mutation] = 0.95*rate
  558. else
  559. genome.mutationRates[mutation] = 1.05263*rate
  560. end
  561. end
  562.  
  563. if math.random() < genome.mutationRates["connections"] then
  564. pointMutate(genome)
  565. end
  566.  
  567. local p = genome.mutationRates["link"]
  568. while p > 0 do
  569. if math.random() < p then
  570. linkMutate(genome, false)
  571. end
  572. p = p - 1
  573. end
  574.  
  575. p = genome.mutationRates["bias"]
  576. while p > 0 do
  577. if math.random() < p then
  578. linkMutate(genome, true)
  579. end
  580. p = p - 1
  581. end
  582.  
  583. p = genome.mutationRates["node"]
  584. while p > 0 do
  585. if math.random() < p then
  586. nodeMutate(genome)
  587. end
  588. p = p - 1
  589. end
  590.  
  591. p = genome.mutationRates["enable"]
  592. while p > 0 do
  593. if math.random() < p then
  594. enableDisableMutate(genome, true)
  595. end
  596. p = p - 1
  597. end
  598.  
  599. p = genome.mutationRates["disable"]
  600. while p > 0 do
  601. if math.random() < p then
  602. enableDisableMutate(genome, false)
  603. end
  604. p = p - 1
  605. end
  606. end
  607.  
  608. function disjoint(genes1, genes2)
  609. local i1 = {}
  610. for i = 1,#genes1 do
  611. local gene = genes1[i]
  612. i1[gene.innovation] = true
  613. end
  614.  
  615. local i2 = {}
  616. for i = 1,#genes2 do
  617. local gene = genes2[i]
  618. i2[gene.innovation] = true
  619. end
  620.  
  621. local disjointGenes = 0
  622. for i = 1,#genes1 do
  623. local gene = genes1[i]
  624. if not i2[gene.innovation] then
  625. disjointGenes = disjointGenes+1
  626. end
  627. end
  628.  
  629. for i = 1,#genes2 do
  630. local gene = genes2[i]
  631. if not i1[gene.innovation] then
  632. disjointGenes = disjointGenes+1
  633. end
  634. end
  635.  
  636. local n = math.max(#genes1, #genes2)
  637.  
  638. return disjointGenes / n
  639. end
  640.  
  641. function weights(genes1, genes2)
  642. local i2 = {}
  643. for i = 1,#genes2 do
  644. local gene = genes2[i]
  645. i2[gene.innovation] = gene
  646. end
  647.  
  648. local sum = 0
  649. local coincident = 0
  650. for i = 1,#genes1 do
  651. local gene = genes1[i]
  652. if i2[gene.innovation] ~= nil then
  653. local gene2 = i2[gene.innovation]
  654. sum = sum + math.abs(gene.weight - gene2.weight)
  655. coincident = coincident + 1
  656. end
  657. end
  658.  
  659. return sum / coincident
  660. end
  661.  
  662. function sameSpecies(genome1, genome2)
  663. local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  664. local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  665. return dd + dw < DeltaThreshold
  666. end
  667.  
  668. function rankGlobally()
  669. local global = {}
  670. for s = 1,#pool.species do
  671. local species = pool.species[s]
  672. for g = 1,#species.genomes do
  673. table.insert(global, species.genomes[g])
  674. end
  675. end
  676. table.sort(global, function (a,b)
  677. return (a.fitness < b.fitness)
  678. end)
  679.  
  680. for g=1,#global do
  681. global[g].globalRank = g
  682. end
  683. end
  684.  
  685. function calculateAverageFitness(species)
  686. local total = 0
  687.  
  688. for g=1,#species.genomes do
  689. local genome = species.genomes[g]
  690. total = total + genome.globalRank
  691. end
  692.  
  693. species.averageFitness = total / #species.genomes
  694. end
  695.  
  696. function totalAverageFitness()
  697. local total = 0
  698. for s = 1,#pool.species do
  699. local species = pool.species[s]
  700. total = total + species.averageFitness
  701. end
  702. LastSpecies = #pool.species
  703. return total
  704. end
  705.  
  706. function cullSpecies(cutToOne)
  707. for s = 1,#pool.species do
  708. local species = pool.species[s]
  709.  
  710. table.sort(species.genomes, function (a,b)
  711. return (a.fitness > b.fitness)
  712. end)
  713.  
  714. local remaining = math.ceil(#species.genomes/2)
  715. if cutToOne then
  716. remaining = 1
  717. end
  718. while #species.genomes > remaining do
  719. table.remove(species.genomes)
  720. end
  721. end
  722. end
  723.  
  724. function breedChild(species)
  725. local child = {}
  726. if math.random() < CrossoverChance then
  727. g1 = species.genomes[math.random(1, #species.genomes)]
  728. g2 = species.genomes[math.random(1, #species.genomes)]
  729. child = crossover(g1, g2)
  730. else
  731. g = species.genomes[math.random(1, #species.genomes)]
  732. child = copyGenome(g)
  733. end
  734.  
  735. mutate(child)
  736.  
  737. return child
  738. end
  739.  
  740. function removeStaleSpecies()
  741. local survived = {}
  742.  
  743. for s = 1,#pool.species do
  744. local species = pool.species[s]
  745.  
  746. table.sort(species.genomes, function (a,b)
  747. return (a.fitness > b.fitness)
  748. end)
  749.  
  750. if species.genomes[1].fitness > species.topFitness then
  751. species.topFitness = species.genomes[1].fitness
  752. species.staleness = 0
  753. else
  754. species.staleness = species.staleness + 1
  755. end
  756. if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness then
  757. table.insert(survived, species)
  758. end
  759. end
  760.  
  761. pool.species = survived
  762. end
  763.  
  764. function removeWeakSpecies()
  765. local survived = {}
  766.  
  767. local sum = totalAverageFitness()
  768. for s = 1,#pool.species do
  769. local species = pool.species[s]
  770. breed = math.floor(species.averageFitness / sum * Population)
  771. species.AF = (species.averageFitness / sum * Population)
  772. if breed >= 1 then
  773. table.insert(survived, species)
  774. end
  775. end
  776.  
  777. pool.species = survived
  778. end
  779.  
  780.  
  781. function addToSpecies(child)
  782. local foundSpecies = false
  783. for s=1,#pool.species do
  784. local species = pool.species[s]
  785. if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  786. table.insert(species.genomes, child)
  787. foundSpecies = true
  788. end
  789. end
  790.  
  791. if not foundSpecies then
  792. local childSpecies = newSpecies()
  793. UniqueSpeciesId = UniqueSpeciesId + 1 -- edit line
  794. table.insert(childSpecies.genomes, child)
  795. table.insert(pool.species, childSpecies)
  796. end
  797. end
  798.  
  799. function newGeneration()
  800. LastAF = CurrentTotalFitness / FitnessesMeasured
  801. if LastAF > MaxAF then
  802. MaxAF = LastAF
  803. AFStaleFor = 0
  804. else
  805. AFStaleFor = AFStaleFor + 1
  806. end
  807. CurrentTotalFitness = 0
  808. FitnessesMeasured = 0
  809. cullSpecies(false) -- Cull the bottom half of each species
  810. rankGlobally()
  811. removeStaleSpecies()
  812. rankGlobally()
  813. pool.generation = pool.generation + 1
  814. for s = 1,#pool.species do
  815. local species = pool.species[s]
  816. calculateAverageFitness(species)
  817. species.Age = pool.generation - species.bornGen
  818. end
  819. removeWeakSpecies()
  820. local sum = totalAverageFitness()
  821. local children = {}
  822. for s = 1,#pool.species do
  823. local species = pool.species[s]
  824. breed = math.floor(species.averageFitness / sum * Population) - 1
  825. for i=1,breed do
  826. table.insert(children, breedChild(species))
  827. end
  828. end
  829. cullSpecies(true) -- Cull all but the top member of each species
  830. while #children + #pool.species < Population do
  831. local species = pool.species[math.random(1, #pool.species)]
  832. table.insert(children, breedChild(species))
  833. end
  834.  
  835.  
  836.  
  837. for c=1,#children do
  838. local child = children[c]
  839. addToSpecies(child)
  840. end
  841.  
  842.  
  843.  
  844.  
  845. writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  846. end
  847.  
  848. function initializePool()
  849. UniqueSpeciesId = 1
  850. CurrentTotalFitness = 0
  851. FitnessesMeasured = 0
  852. LastAF = 0
  853. MaxAF = -1000
  854. TotalTries = 1
  855. AFStaleFor = 0
  856.  
  857.  
  858. pool = newPool()
  859.  
  860. for i=1,Population do
  861. basic = basicGenome()
  862. addToSpecies(basic)
  863. end
  864.  
  865. initializeRun()
  866. end
  867.  
  868. function clearJoypad()
  869. controller = {}
  870. for b = 1,#ButtonNames do
  871. controller["P1 " .. ButtonNames[b]] = false
  872. end
  873. joypad.set(controller)
  874. end
  875.  
  876. function initializeRun()
  877. savestate.load(Filename);
  878. rightmost = 0
  879. pool.currentFrame = 0
  880. timeout = TimeoutConstant
  881. clearJoypad()
  882.  
  883. local species = pool.species[pool.currentSpecies]
  884. local genome = species.genomes[pool.currentGenome]
  885. generateNetwork(genome)
  886. evaluateCurrent()
  887. end
  888.  
  889. function evaluateCurrent()
  890. local species = pool.species[pool.currentSpecies]
  891. local genome = species.genomes[pool.currentGenome]
  892.  
  893. inputs = getInputs()
  894. controller = evaluateNetwork(genome.network, inputs)
  895.  
  896. if controller["P1 Left"] and controller["P1 Right"] then
  897. controller["P1 Left"] = false
  898. controller["P1 Right"] = false
  899. end
  900. if controller["P1 Up"] and controller["P1 Down"] then
  901. controller["P1 Up"] = false
  902. controller["P1 Down"] = false
  903. end
  904.  
  905. joypad.set(controller)
  906. end
  907.  
  908. if pool == nil then
  909. initializePool()
  910. end
  911.  
  912.  
  913. function nextGenome()
  914. pool.currentGenome = pool.currentGenome + 1
  915. if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
  916. pool.currentGenome = 1
  917. pool.currentSpecies = pool.currentSpecies+1
  918. if pool.currentSpecies > #pool.species then
  919. newGeneration()
  920. pool.currentSpecies = 1
  921. end
  922. end
  923. end
  924.  
  925. function fitnessAlreadyMeasured()
  926. local species = pool.species[pool.currentSpecies]
  927. local genome = species.genomes[pool.currentGenome]
  928. return genome.fitness ~= 0
  929. end
  930.  
-- Draw `genome`'s network with the BizHawk gui.* functions:
--   * the (2*BoxRadius+1)^2 input grid centred at (50,70), 5px per cell, in
--     the same row-by-row order getInputs fills it, plus the bias cell at
--     (80,110);
--   * one cell per output at x=220 with its button name, drawn in 0xFF0000FF
--     when the output value is > 0 and 0xFF000000 otherwise;
--   * hidden neurons (Inputs < id <= MaxNodes), first placed at (140,40) and
--     then pulled toward the cells they connect to;
--   * one line per enabled gene.
-- Reads genome.network.neurons, so generateNetwork must have been run for
-- this genome.
-- NOTE(review): colors appear to be 0xAARRGGBB (alpha in the top byte);
-- confirm against the BizHawk gui API.
function displayGenome(genome)
	local network = genome.network
	local cells = {}
	local i = 1
	local cell = {}
	-- Input grid: neuron ids 1..InputSize.
	for dy=-BoxRadius,BoxRadius do
		for dx=-BoxRadius,BoxRadius do
			cell = {}
			cell.x = 50+5*dx
			cell.y = 70+5*dy
			cell.value = network.neurons[i].value
			cells[i] = cell
			i = i + 1
		end
	end
	-- Bias input (id Inputs), drawn below and to the right of the grid.
	local biasCell = {}
	biasCell.x = 80
	biasCell.y = 110
	biasCell.value = network.neurons[Inputs].value
	cells[Inputs] = biasCell

	-- Output cells in a column at x=220, one per button, with button labels.
	for o = 1,Outputs do
		cell = {}
		cell.x = 220
		cell.y = 30 + 8 * o
		cell.value = network.neurons[MaxNodes + o].value
		cells[MaxNodes+o] = cell
		local color
		if cell.value > 0 then
			color = 0xFF0000FF
		else
			color = 0xFF000000
		end
		gui.drawText(223, 24+8*o, ButtonNames[o], color, 9) -- drawButtonNames
	end

	-- Hidden neurons all start at (140,40); the passes below spread them out.
	for n,neuron in pairs(network.neurons) do
		cell = {}
		if n > Inputs and n <= MaxNodes then
			cell.x = 140
			cell.y = 40
			cell.value = neuron.value
			cells[n] = cell
		end
	end

	-- Four layout passes. For each enabled gene, a hidden source cell moves 25%
	-- toward its target (and 40px left if it is not left of the target), and a
	-- hidden target cell moves 25% toward its source (and 40px right if the
	-- source is not left of it). Hidden x is clamped to [90, 220].
	for n=1,4 do
		for _,gene in pairs(genome.genes) do
			if gene.enabled then
				local c1 = cells[gene.into]
				local c2 = cells[gene.out]
				if gene.into > Inputs and gene.into <= MaxNodes then
					c1.x = 0.75*c1.x + 0.25*c2.x
					if c1.x >= c2.x then
						c1.x = c1.x - 40
					end
					if c1.x < 90 then
						c1.x = 90
					end

					if c1.x > 220 then
						c1.x = 220
					end
					c1.y = 0.75*c1.y + 0.25*c2.y

				end
				if gene.out > Inputs and gene.out <= MaxNodes then
					c2.x = 0.25*c1.x + 0.75*c2.x
					if c1.x >= c2.x then
						c2.x = c2.x + 40
					end
					if c2.x < 90 then
						c2.x = 90
					end
					if c2.x > 220 then
						c2.x = 220
					end
					c2.y = 0.25*c1.y + 0.75*c2.y
				end
			end
		end
	end

	-- Background box behind the input grid. Its bottom edge uses 80 rather
	-- than 70, which makes it also enclose the bias cell at (80,110).
	gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,80+BoxRadius*5+2,0xFF000000, 0x80808080) --edit 50 70
	-- Draw every hidden/output cell, plus every input/bias cell whose value is
	-- non-zero, as a grey square: value -1 maps to black and +1 to white
	-- (clamped to 0..255). Zero-valued cells get a lower alpha (0x50).
	for n,cell in pairs(cells) do
		if n > Inputs or cell.value ~= 0 then
			local color = math.floor((cell.value+1)/2*256)
			if color > 255 then color = 255 end
			if color < 0 then color = 0 end
			local opacity = 0xFF000000
			if cell.value == 0 then
				opacity = 0x50000000
			end
			color = opacity + color*0x10000 + color*0x100 + color
			gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color) -- cells, nodes
		end
	end
	-- One line per enabled gene. Positive weights get a 0x80 base in the green
	-- byte and negative weights a 0x80 base in the red byte. The other channel
	-- shrinks toward 0 as |sigmoid(weight)| grows, so strong weights look more
	-- saturated. Lines from zero-valued sources are drawn fainter.
	for _,gene in pairs(genome.genes) do
		if gene.enabled then
			local c1 = cells[gene.into]
			local c2 = cells[gene.out]
			local opacity = 0xA0000000
			if c1.value == 0 then
				opacity = 0x20000000
			end

			local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
			if gene.weight > 0 then
				color = opacity + 0x8000 + 0x10000*color
			else
				color = opacity + 0x800000 + 0x100*color
			end
			gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
		end
	end

	-- Narrow marker box just below the grid centre (50,70), i.e. Mario's cell.
	gui.drawBox(49,71,51,78,0x00000000,0x80FF0000) -- draw mario bar

	-- Optionally list the genome's mutation rates (form checkbox showMutationRates).
	if forms.ischecked(showMutationRates) then
		local pos = 100
		for mutation,rate in pairs(genome.mutationRates) do
			gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
			pos = pos + 8
		end
	end
end
  1057.  
  1058. function writeFile(filename)
  1059. local file = io.open(filename, "w")
  1060. file:write(TotalTries .. "\n")
  1061. file:write(UniqueSpeciesId .. "\n")
  1062. file:write(CurrentTotalFitness .. "\n")
  1063. file:write(LastAF .. "\n")
  1064. file:write(MaxAF .. "\n")
  1065. file:write(AFStaleFor .. "\n")
  1066. file:write(FitnessesMeasured .. "\n")
  1067. file:write(CurrentAF .. "\n")
  1068. file:write(pool.generation .. "\n")
  1069. file:write(pool.maxFitness .. "\n")
  1070. file:write(#pool.species .. "\n")
  1071. for n,species in pairs(pool.species) do
  1072. file:write(species.uniqueId .. "\n")
  1073. file:write(species.bornGen .. "\n")
  1074. file:write(species.Age .. "\n")
  1075. file:write(species.AF .. "\n")
  1076. file:write(species.topFitness .. "\n")
  1077. file:write(species.staleness .. "\n")
  1078. file:write(#species.genomes .. "\n")
  1079. for m,genome in pairs(species.genomes) do
  1080. file:write(genome.fitness .. "\n")
  1081. file:write(genome.maxneuron .. "\n")
  1082. for mutation,rate in pairs(genome.mutationRates) do
  1083. file:write(mutation .. "\n")
  1084. file:write(rate .. "\n")
  1085. end
  1086. file:write("done\n")
  1087.  
  1088. file:write(#genome.genes .. "\n")
  1089. for l,gene in pairs(genome.genes) do
  1090. file:write(gene.into .. " ")
  1091. file:write(gene.out .. " ")
  1092. file:write(gene.weight .. " ")
  1093. file:write(gene.innovation .. " ")
  1094. if(gene.enabled) then
  1095. file:write("1\n")
  1096. else
  1097. file:write("0\n")
  1098. end
  1099. end
  1100. end
  1101. end
  1102. file:close()
  1103. end
  1104.  
  1105. function savePool()
  1106. local filename = forms.gettext(saveLoadFile)
  1107. writeFile(filename)
  1108. end
  1109.  
  1110. function loadFile(filename)
  1111. local file = io.open(filename, "r")
  1112. pool = newPool()
  1113. TotalTries = file:read("*number")
  1114. UniqueSpeciesId = file:read("*number")
  1115. CurrentTotalFitness = file:read("*number")
  1116. LastAF = file:read("*number")
  1117. MaxAF = file:read("*number")
  1118. AFStaleFor = file:read("*number")
  1119. FitnessesMeasured = file:read("*number")
  1120. CurrentAF = file:read("*number")
  1121. pool.generation = file:read("*number")
  1122. pool.maxFitness = file:read("*number")
  1123. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1124. local numSpecies = file:read("*number")
  1125. for s=1,numSpecies do
  1126. local species = newSpecies()
  1127. table.insert(pool.species, species)
  1128. species.uniqueId = file:read("*number")
  1129. species.bornGen = file:read("*number")
  1130. species.Age = file:read("*number")
  1131. species.AF = file:read("*number")
  1132. species.topFitness = file:read("*number")
  1133. species.staleness = file:read("*number")
  1134. local numGenomes = file:read("*number")
  1135. for g=1,numGenomes do
  1136. local genome = newGenome()
  1137. table.insert(species.genomes, genome)
  1138. genome.fitness = file:read("*number")
  1139. genome.maxneuron = file:read("*number")
  1140. local line = file:read("*line")
  1141. while line ~= "done" do
  1142. genome.mutationRates[line] = file:read("*number")
  1143. line = file:read("*line")
  1144. end
  1145. local numGenes = file:read("*number")
  1146. for n=1,numGenes do
  1147. local gene = newGene()
  1148. table.insert(genome.genes, gene)
  1149. local enabled
  1150. gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
  1151. if enabled == 0 then
  1152. gene.enabled = false
  1153. else
  1154. gene.enabled = true
  1155. end
  1156.  
  1157. end
  1158. end
  1159. end
  1160. file:close()
  1161.  
  1162. while fitnessAlreadyMeasured() do
  1163. nextGenome()
  1164. end
  1165. initializeRun()
  1166. pool.currentFrame = pool.currentFrame + 1
  1167. end
  1168.  
  1169. function loadPool()
  1170. local filename = forms.gettext(saveLoadFile)
  1171. loadFile(filename)
  1172. end
  1173.  
  1174. function playTop()
  1175. local maxfitness = 0
  1176. local maxs, maxg
  1177. for s,species in pairs(pool.species) do
  1178. for g,genome in pairs(species.genomes) do
  1179. if genome.fitness > maxfitness then
  1180. maxfitness = genome.fitness
  1181. maxs = s
  1182. maxg = g
  1183. end
  1184. end
  1185. end
  1186.  
  1187. pool.currentSpecies = maxs
  1188. pool.currentGenome = maxg
  1189. pool.maxFitness = maxfitness
  1190. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1191. initializeRun()
  1192. pool.currentFrame = pool.currentFrame + 1
  1193. return
  1194. end
  1195.  
  1196. function onExit()
  1197. forms.destroy(form)
  1198. end
  1199.  
  1200. writeFile("temp.pool")
  1201.  
  1202. event.onexit(onExit)
  1203.  
  1204. form = forms.newform(2*150, 2*260, "Fitness")
  1205. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 2*5, 2*8)
  1206. showNetwork = forms.checkbox(form, "Show Map", 2*5, 2*30)
  1207. showMutationRates = forms.checkbox(form, "Show M-Rates", 2*5, 2*52)
  1208. restartButton = forms.button(form, "Restart", initializePool, 2*5, 2*77)
  1209. saveButton = forms.button(form, "Save", savePool, 2*5, 2*102, 100, 50)
  1210. loadButton = forms.button(form, "Load", loadPool, 2*80, 2*102, 100, 50)
  1211. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 2*25, nil, 2*5, 2*148)
  1212. saveLoadLabel = forms.label(form, "Save/Load:", 2*5, 2*129)
  1213. playTopButton = forms.button(form, "Play Top", playTop, 2*5, 2*170)
  1214. hideBanner = forms.checkbox(form, "Hide Banner", 2*5, 2*190)
  1215.  
  1216.  
  1217. while true do
  1218. local backgroundColor = 0xD0FFFFFF
  1219. if not forms.ischecked(hideBanner) then
  1220. gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
  1221. end
  1222.  
  1223. local species = pool.species[pool.currentSpecies]
  1224. local genome = species.genomes[pool.currentGenome]
  1225.  
  1226. if forms.ischecked(showNetwork) then
  1227. displayGenome(genome)
  1228. end
  1229.  
  1230. if pool.currentFrame%5 == 0 then
  1231. evaluateCurrent()
  1232. end
  1233.  
  1234. joypad.set(controller)
  1235.  
  1236. getPositions()
  1237. if marioX > rightmost then
  1238. rightmost = marioX
  1239. timeout = TimeoutConstant
  1240. end
  1241.  
  1242. timeout = timeout - 1
  1243.  
  1244.  
  1245. local timeoutBonus = pool.currentFrame / 4
  1246. if timeout + timeoutBonus <= 0 then
  1247. local fitness = rightmost - pool.currentFrame / 2
  1248. if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 then -- makes the level
  1249. fitness = fitness + 1000 -- gets 1000 fitness bonus
  1250. end
  1251. if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
  1252. fitness = fitness + 1000
  1253. end
  1254. if fitness == 0 then
  1255. fitness = -1
  1256. end
  1257. genome.fitness = fitness
  1258.  
  1259. if fitness > pool.maxFitness then
  1260. pool.maxFitness = fitness
  1261. forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1262. writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  1263. end
  1264.  
  1265. console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  1266. CurrentTotalFitness = CurrentTotalFitness + fitness
  1267. FitnessesMeasured = FitnessesMeasured + 1
  1268. TotalTries = TotalTries + 1
  1269. pool.currentSpecies = 1
  1270. pool.currentGenome = 1
  1271. while fitnessAlreadyMeasured() do
  1272. nextGenome()
  1273. end
  1274. initializeRun()
  1275. -- savestate.loadslot(1) -- edit line
  1276. end
  1277.  
  1278. local measured = 0
  1279. local total = 0
  1280. for _,species in pairs(pool.species) do
  1281. for _,genome in pairs(species.genomes) do
  1282. total = total + 1
  1283. if genome.fitness ~= 0 then
  1284. measured = measured + 1
  1285. end
  1286. end
  1287. end
  1288. CurrentAF = CurrentTotalFitness / FitnessesMeasured
  1289. if FitnessesMeasured == 0 then
  1290. CurrentAF = 0
  1291. end
  1292. if not forms.ischecked(hideBanner) then
  1293. gui.drawText(0, 0, "G:" .. pool.generation .. "-" .. math.floor(measured/total*100) .. "%" .. "AF:" .. math.floor(CurrentAF) .. "/" .. math.floor(LastAF) .. "/" .. math.floor(MaxAF).. "TT:" .. TotalTries .. "AFS:" .. AFStaleFor, 0xFF000000, 11)
  1294. gui.drawText(0, 12,"ID:" .. species.uniqueId .. "a:" .. species.Age .. "F:" .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3) .. "/" .. math.floor(pool.maxFitness) .. " s:" .. pool.currentSpecies .. "-" .. pool.currentGenome .. "try:" .. FitnessesMeasured .. "ss:" .. species.staleness, 0xFF000000, 11)
  1295.  
  1296.  
  1297.  
  1298.  
  1299. end
  1300.  
  1301. pool.currentFrame = pool.currentFrame + 1
  1302.  
  1303. emu.frameadvance();
  1304. end
  1305.  
  1306. -- math.floor (totalAverageFitness() / (#pool.species/total) ),
  1307.  
  1308. -- (species.averageFitness / sum * Population)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement