Advertisement
Guest User

Untitled

a guest
Aug 19th, 2015
225
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 31.32 KB | None | 0 0
  1. -- MarI/O by SethBling
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  4. -- For SMW, make sure you have a save state named "DP1.state" at the beginning of a level,
  5. -- and put a copy in both the Lua folder and the root directory of BizHawk.
  6.  
  -- ROM-specific setup: the save state every run starts from, and the
  -- controller buttons the network's outputs drive (in output order).
  local romName = gameinfo.getromname()
  if romName == "Super Mario World (USA)" then
    Filename = "DP1.state"
    ButtonNames = {
      "A",
      "B",
      "X",
      "Y",
      "Up",
      "Down",
      "Left",
      "Right",
    }
  elseif romName == "Super Mario Bros." then
    Filename = "SMB1-1.state"
    ButtonNames = {
      "A",
      "B",
      "Up",
      "Down",
      "Left",
      "Right",
    }
  else
    -- Fail fast: otherwise an unsupported ROM only surfaces later as an
    -- opaque "attempt to get length of a nil value" on #ButtonNames.
    error("MarI/O: unsupported ROM \"" .. tostring(romName) ..
          "\"; expected \"Super Mario World (USA)\" or \"Super Mario Bros.\"")
  end

  -- Evolution settings.
  Population = 500 -- genomes per generation (was 300)
  TimeoutConstant = 100 -- value `timeout` is reset to by initializeRun (was 20)

  -- Run statistics; persisted by writeFile and restored by loadFile.
  TotalTries = 1
  UniqueSpeciesId = 1 -- shown as 'us'; id stamped on the next new species (advanced by addToSpecies)
  CurrentTotalFitness = 0 -- running fitness total for this generation (averaged and reset in newGeneration)
  CurrentAF = 0
  LastAF = 0 -- average fitness of the previous generation
  MaxAF = -1000 -- best generation-average fitness seen so far
  AFStaleFor = 0 -- generations since MaxAF last improved
  FitnessesMeasured = 0 -- fitness samples in CurrentTotalFitness (reset in newGeneration)

  -- Network inputs: a (2*BoxRadius+1) x (2*BoxRadius+1) tile grid around Mario.
  BoxRadius = 6
  InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)

  -- No extra bias input: evaluateNetwork's `table.insert(inputs, 1)` is disabled.
  Inputs = InputSize
  Outputs = #ButtonNames

  -- Speciation: two genomes share a species when
  -- DeltaDisjoint*disjoint(...) + DeltaWeights*weights(...) < DeltaThreshold.
  DeltaDisjoint = 2.0
  DeltaWeights = 0.4
  DeltaThreshold = 1.0

  StaleSpecies = 15 -- generations without a new top fitness before a species may be dropped

  -- Global probabilities.
  PerturbChance = 0.90 -- pointMutate: nudge a weight rather than re-rolling it
  CrossoverChance = 0.75 -- breedChild: crossover rather than cloning

  -- Initial per-genome mutation rates (copied into each genome by newGenome;
  -- mutate() then lets each genome's copy drift).
  MutateConnectionsChance = 0.25
  LinkMutationChance = 2.0
  NodeMutationChance = 0.50
  BiasMutationChance = 0.40
  StepSize = 0.1
  DisableMutationChance = 0.4
  EnableMutationChance = 0.2

  -- Output neuron o lives at index MaxNodes + o; hidden neurons sit between
  -- the inputs (1..Inputs) and MaxNodes.
  MaxNodes = 1000000
  74.  
  -- Read Mario's coordinates from emulator RAM into the globals
  --   marioX, marioY   - position in level coordinates (used by getTile/getInputs)
  --   screenX, screenY - position on screen
  -- Leaves them untouched for any ROM other than SMW (USA) or SMB.
  function getPositions()
    local rom = gameinfo.getromname()

    if rom == "Super Mario World (USA)" then
      marioX = memory.read_s16_le(0x94)
      marioY = memory.read_s16_le(0x96)

      -- Screen position = level position minus the layer-1 offsets.
      local layerOffsetX = memory.read_s16_le(0x1A)
      local layerOffsetY = memory.read_s16_le(0x1C)
      screenX = marioX - layerOffsetX
      screenY = marioY - layerOffsetY

    elseif rom == "Super Mario Bros." then
      local page = memory.readbyte(0x6D)
      local withinPage = memory.readbyte(0x86)
      marioX = page * 0x100 + withinPage
      marioY = memory.readbyte(0x03B8) + 16

      screenX = memory.readbyte(0x03AD)
      screenY = memory.readbyte(0x03B8)
    end
  end
  93.  
  -- Return the tile at an offset of (dx, dy) pixels from Mario (marioX/marioY,
  -- as last set by getPositions).
  --   SMW: the raw byte from the level tile table based at 0x1C800
  --        (getInputs treats a value of 1 as solid).
  --   SMB: 1 if the byte in the 0x500-based tile table is non-zero, otherwise
  --        0; rows outside 0..12 (13 rows per page) also give 0.
  -- Returns nil for any other ROM.
  function getTile(dx, dy)
    local rom = gameinfo.getromname()

    if rom == "Super Mario World (USA)" then
      -- Local on purpose: these used to leak into the globals `x` and `y`.
      local x = math.floor((marioX+dx+8)/16)
      local y = math.floor((marioY+dy)/16)

      return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)

    elseif rom == "Super Mario Bros." then
      local x = marioX + dx + 8
      local y = marioY + dy - 16
      local page = math.floor(x/256)%2

      local subx = math.floor((x%256)/16)
      local suby = math.floor((y - 32)/16)
      if suby >= 13 or suby < 0 then
        return 0
      end

      local addr = 0x500 + page*13*16+suby*16+subx
      if memory.readbyte(addr) ~= 0 then
        return 1
      end
      return 0
    end
  end
  120.  
  -- Return a list of {x=, y=} positions (in Mario's level coordinates, as
  -- compared in getInputs) for every active enemy/sprite slot.
  --   SMW: slots 0..11 whose status byte (0x14C8+slot) is non-zero.
  --   SMB: slots 0..4 whose enemy flag (0x0F+slot) is non-zero; y gets +24.
  -- Returns nil for any other ROM.
  function getSprites()
    local rom = gameinfo.getromname()

    if rom == "Super Mario World (USA)" then
      local sprites = {}
      for slot=0,11 do
        local status = memory.readbyte(0x14C8+slot)
        if status ~= 0 then
          -- Local on purpose: these used to leak into globals spritex/spritey.
          local spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
          local spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
          sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
        end
      end

      return sprites

    elseif rom == "Super Mario Bros." then
      local sprites = {}
      for slot=0,4 do
        local enemy = memory.readbyte(0xF+slot)
        if enemy ~= 0 then
          local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
          local ey = memory.readbyte(0xCF + slot)+24
          sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
        end
      end

      return sprites
    end
  end
  148.  
  -- Return a list of {x=, y=} positions for active SMW "extended" sprite slots
  -- (slots 0..11 whose number byte at 0x170B+slot is non-zero).
  -- SMB has no equivalent here and always yields an empty list; any other ROM
  -- yields nil.
  function getExtendedSprites()
    local rom = gameinfo.getromname()

    if rom == "Super Mario World (USA)" then
      local extended = {}
      for slot=0,11 do
        local number = memory.readbyte(0x170B+slot)
        if number ~= 0 then
          -- Local on purpose: these used to leak into globals spritex/spritey.
          local spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
          local spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
          extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
        end
      end

      return extended

    elseif rom == "Super Mario Bros." then
      return {}
    end
  end
  166.  
  -- Build the network input vector: one value per 16px cell of a
  -- (2*BoxRadius+1)^2 grid centred on Mario, row by row (dy outer, dx inner).
  --   1  - solid tile (getTile == 1) above level y 0x1B0
  --   -1 - a sprite or extended sprite overlaps the cell (overrides the tile)
  --   0  - otherwise
  -- Also refreshes Mario's position globals via getPositions().
  function getInputs()
    getPositions()

    -- Local on purpose: sprites/extended/tile/distx/disty used to be globals.
    local sprites = getSprites()
    local extended = getExtendedSprites()

    local inputs = {}

    for dy=-BoxRadius*16,BoxRadius*16,16 do
      for dx=-BoxRadius*16,BoxRadius*16,16 do
        local cellX = marioX + dx
        local cellY = marioY + dy
        local value = 0

        local tile = getTile(dx, dy)
        if tile == 1 and cellY < 0x1B0 then
          value = 1
        end

        -- Sprites count within 8px inclusive ...
        for i = 1,#sprites do
          local distx = math.abs(sprites[i]["x"] - cellX)
          local disty = math.abs(sprites[i]["y"] - cellY)
          if distx <= 8 and disty <= 8 then
            value = -1
          end
        end

        -- ... extended sprites strictly within 8px (kept as originally written).
        for i = 1,#extended do
          local distx = math.abs(extended[i]["x"] - cellX)
          local disty = math.abs(extended[i]["y"] - cellY)
          if distx < 8 and disty < 8 then
            value = -1
          end
        end

        inputs[#inputs+1] = value
      end
    end

    --mariovx = memory.read_s8(0x7B)
    --mariovy = memory.read_s8(0x7D)

    return inputs
  end
  207.  
  -- Steep sigmoid rescaled to the open range (-1, 1):
  --   sigmoid(x) = 2 / (1 + e^(-4.9*x)) - 1
  function sigmoid(x)
    local decay = math.exp(-4.9*x)
    return 2/(1 + decay) - 1
  end
  211.  
  -- Allocate and return the next global innovation number (stored in pool).
  function newInnovation()
    local nextId = pool.innovation + 1
    pool.innovation = nextId
    return nextId
  end
  216.  
  -- Create an empty population. Innovation numbers start at Outputs, so the
  -- first newInnovation() call returns Outputs + 1.
  function newPool()
    return {
      species = {},
      generation = 0,
      innovation = Outputs,
      currentSpecies = 1,
      currentGenome = 1,
      currentFrame = 0,
      maxFitness = 0,
    }
  end
  229.  
  -- Create an empty species stamped with the current UniqueSpeciesId and the
  -- generation it was created in (pool.generation).
  function newSpecies()
    return {
      topFitness = 0,
      staleness = 0, -- generations without a new topFitness
      genomes = {},
      averageFitness = 0,
      uniqueId = UniqueSpeciesId,
      bornGen = pool.generation,
      Age = 0,
      AF = 0,
    }
  end
  242.  
  -- Create an empty genome whose mutation rates start at the global defaults.
  function newGenome()
    -- Filled key by key, in this order, so pairs() over the rates walks them
    -- exactly as before (mutate and writeFile iterate with pairs).
    local rates = {}
    rates["connections"] = MutateConnectionsChance
    rates["link"] = LinkMutationChance
    rates["bias"] = BiasMutationChance
    rates["node"] = NodeMutationChance
    rates["enable"] = EnableMutationChance
    rates["disable"] = DisableMutationChance
    rates["step"] = StepSize

    return {
      genes = {},
      fitness = 0,
      adjustedFitness = 0,
      network = {},
      maxneuron = 0,
      globalRank = 0,
      mutationRates = rates,
    }
  end
  262.  
  -- Deep-copy a genome's genes, maxneuron and mutation rates into a fresh
  -- genome (fitness, rank and network are not copied).
  function copyGenome(genome)
    local genome2 = newGenome()
    for g=1,#genome.genes do
      table.insert(genome2.genes, copyGene(genome.genes[g]))
    end
    genome2.maxneuron = genome.maxneuron

    -- Copy every rate, including "step": the old field-by-field copy left it
    -- out, so every clone silently reset its step size to StepSize.
    for mutation, rate in pairs(genome.mutationRates) do
      genome2.mutationRates[mutation] = rate
    end

    return genome2
  end
  278.  
  -- Create a starting genome: hidden-neuron ids begin right after the inputs
  -- (maxneuron = Inputs), then one mutate() pass gives it initial structure.
  -- (Dropped an unused `innovation` local.)
  function basicGenome()
    local genome = newGenome()
    genome.maxneuron = Inputs
    mutate(genome)
    return genome
  end
  288.  
  -- Create a blank, enabled connection gene.
  function newGene()
    return {
      into = 0,       -- source neuron id
      out = 0,        -- target neuron id
      weight = 0.0,
      enabled = true,
      innovation = 0, -- historical marker used to align genes between genomes
    }
  end
  299.  
  -- Return an independent copy of a connection gene.
  function copyGene(gene)
    local clone = newGene()
    clone.into, clone.out = gene.into, gene.out
    clone.weight = gene.weight
    clone.enabled = gene.enabled
    clone.innovation = gene.innovation
    return clone
  end
  310.  
  -- Create a neuron with no incoming genes and a value of 0.
  function newNeuron()
    return { incoming = {}, value = 0.0 }
  end
  318.  
  -- Build genome.network from the genome's enabled genes: input neurons
  -- 1..Inputs, output neurons MaxNodes+1..MaxNodes+Outputs, plus any neuron
  -- an enabled gene touches. Each gene is attached to its target neuron's
  -- `incoming` list.
  -- Side effect: genome.genes is sorted in place by target neuron (gene.out).
  function generateNetwork(genome)
    local neurons = {}

    for i = 1, Inputs do
      neurons[i] = newNeuron()
    end
    for o = 1, Outputs do
      neurons[MaxNodes + o] = newNeuron()
    end

    table.sort(genome.genes, function(a, b) return a.out < b.out end)

    for _, gene in ipairs(genome.genes) do
      if gene.enabled then
        neurons[gene.out] = neurons[gene.out] or newNeuron()
        table.insert(neurons[gene.out].incoming, gene)
        neurons[gene.into] = neurons[gene.into] or newNeuron()
      end
    end

    genome.network = { neurons = neurons }
  end
  350.  
  -- Run one forward pass and return a joypad table {"P1 <button>" = bool},
  -- pressed when the corresponding output neuron's value is positive.
  -- Returns {} (after logging) when the input count does not match Inputs.
  -- NOTE(review): neurons are visited in pairs() order, which is unspecified,
  -- so a hidden neuron may read a source not yet updated on this pass.
  function evaluateNetwork(network, inputs)
    -- table.insert(inputs, 1)
    if #inputs ~= Inputs then
      console.writeline("#inputs:" .. #inputs .. "Inputs:" .. Inputs)
      console.writeline("Incorrect number of neural network inputs.")
      return {}
    end

    local neurons = network.neurons
    for i = 1, Inputs do
      neurons[i].value = inputs[i]
    end

    for _, neuron in pairs(neurons) do
      local links = neuron.incoming
      if #links > 0 then
        local total = 0
        for j = 1, #links do
          local link = links[j]
          total = total + link.weight * neurons[link.into].value
        end
        neuron.value = sigmoid(total)
      end
    end

    local outputs = {}
    for o = 1, Outputs do
      outputs["P1 " .. ButtonNames[o]] = neurons[MaxNodes + o].value > 0
    end
    return outputs
  end
  388.  
  -- NEAT crossover: build a child genome from two parents.
  --   Genes are walked in the fitter parent's order. A gene whose innovation
  --   number also appears in the other parent comes from either parent at
  --   random (the weaker one's copy only if it is enabled); all other genes
  --   come from the fitter parent.
  --   The child's mutation rates are copied from the fitter parent.
  function crossover(g1, g2)
    -- Make sure g1 is the higher fitness genome (local swap; the old code
    -- leaked a global `tempg`).
    if g2.fitness > g1.fitness then
      g1, g2 = g2, g1
    end

    local child = newGenome()

    local innovations2 = {}
    for i=1,#g2.genes do
      local gene = g2.genes[i]
      innovations2[gene.innovation] = gene
    end

    for i=1,#g1.genes do
      local gene1 = g1.genes[i]
      local gene2 = innovations2[gene1.innovation]
      if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
        table.insert(child.genes, copyGene(gene2))
      else
        table.insert(child.genes, copyGene(gene1))
      end
    end

    child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)

    for mutation,rate in pairs(g1.mutationRates) do
      child.mutationRates[mutation] = rate
    end

    return child
  end
  423.  
  -- Pick a uniformly random neuron id among: the inputs (unless nonInput),
  -- every output, and every gene endpoint (only ids > Inputs when nonInput).
  -- Returns 0 only if the candidate set were somehow empty.
  function randomNeuron(genes, nonInput)
    local candidates = {}

    if not nonInput then
      for i = 1, Inputs do
        candidates[i] = true
      end
    end
    for o = 1, Outputs do
      candidates[MaxNodes + o] = true
    end
    for i = 1, #genes do
      local gene = genes[i]
      if not nonInput or gene.into > Inputs then
        candidates[gene.into] = true
      end
      if not nonInput or gene.out > Inputs then
        candidates[gene.out] = true
      end
    end

    local total = 0
    for _ in pairs(candidates) do
      total = total + 1
    end

    -- Walk to the chosen position in traversal order.
    local remaining = math.random(1, total)
    for neuron in pairs(candidates) do
      remaining = remaining - 1
      if remaining == 0 then
        return neuron
      end
    end

    return 0
  end
  458.  
  -- Return true if any gene in `genes` connects link.into -> link.out,
  -- false otherwise (previously fell off the end and returned nil).
  function containsLink(genes, link)
    for i=1,#genes do
      local gene = genes[i]
      if gene.into == link.into and gene.out == link.out then
        return true
      end
    end
    return false
  end
  467.  
  -- Mutate every gene weight: with probability PerturbChance nudge it by a
  -- uniform amount in [-step, step), otherwise re-roll it uniformly in [-2, 2).
  function pointMutate(genome)
    local step = genome.mutationRates["step"]

    for _, gene in ipairs(genome.genes) do
      local perturb = math.random() < PerturbChance
      if perturb then
        gene.weight = gene.weight + math.random() * step*2 - step
      else
        gene.weight = math.random()*4-2
      end
    end
  end
  480.  
  -- Try to add one new connection gene with a random weight in [-2, 2).
  --   Source: any neuron (inputs included); target: any non-input neuron.
  --   Two inputs are never linked, and an input always ends up as the source.
  --   With `forceBias`, the source is forced to neuron index `Inputs`.
  --   NOTE(review): the bias input is disabled (Inputs == InputSize), so that
  --   index is the last tile of the input grid rather than a constant bias.
  -- Nothing is added if a gene with the same endpoints already exists.
  function linkMutate(genome, forceBias)
    local source = randomNeuron(genome.genes, false)
    local target = randomNeuron(genome.genes, true)

    if source <= Inputs and target <= Inputs then
      return -- both input nodes
    end
    if target <= Inputs then
      source, target = target, source
    end

    local link = newGene()
    if forceBias then
      link.into = Inputs
    else
      link.into = source
    end
    link.out = target

    if containsLink(genome.genes, link) then
      return
    end

    link.innovation = newInnovation()
    link.weight = math.random()*4-2
    table.insert(genome.genes, link)
  end
  511.  
  -- Split a random enabled gene A->B into A->H (weight 1) and H->B (original
  -- weight) through a new hidden neuron H, disabling the original gene.
  -- The new neuron id is reserved (maxneuron bumped) even when the randomly
  -- picked gene turns out to be disabled and no split happens.
  function nodeMutate(genome)
    local genes = genome.genes
    if #genes == 0 then
      return
    end

    genome.maxneuron = genome.maxneuron + 1
    local hidden = genome.maxneuron

    local split = genes[math.random(1, #genes)]
    if not split.enabled then
      return
    end
    split.enabled = false

    -- A -> H
    local first = copyGene(split)
    first.out = hidden
    first.weight = 1.0
    first.innovation = newInnovation()
    first.enabled = true
    table.insert(genes, first)

    -- H -> B, keeping the split gene's weight
    local second = copyGene(split)
    second.into = hidden
    second.innovation = newInnovation()
    second.enabled = true
    table.insert(genes, second)
  end
  538.  
  -- Flip one random gene: when `enable` is true pick among disabled genes and
  -- enable it; when false pick among enabled genes and disable it.
  -- Does nothing if no gene qualifies.
  function enableDisableMutate(genome, enable)
    local wanted = not enable -- current state of the genes eligible to flip
    local candidates = {}
    for _, gene in pairs(genome.genes) do
      if gene.enabled == wanted then
        candidates[#candidates + 1] = gene
      end
    end

    if #candidates > 0 then
      local chosen = candidates[math.random(1, #candidates)]
      chosen.enabled = not chosen.enabled
    end
  end
  554.  
  -- Call `action(...)` according to `rate`: each pass draws math.random() and
  -- fires when the draw is below the remaining rate, which then drops by 1.
  -- So a rate of 2.0 always fires twice, and 0.4 fires with chance 0.4.
  local function applyWithRate(rate, action, ...)
    local remaining = rate
    while remaining > 0 do
      if math.random() < remaining then
        action(...)
      end
      remaining = remaining - 1
    end
  end

  -- Mutate a genome in place: first let each of its mutation rates drift by
  -- a factor of 0.95 or 1.05263 (~1/0.95), then apply the weight, link,
  -- bias-link, node, enable and disable mutations according to those rates.
  function mutate(genome)
    local rates = genome.mutationRates
    for mutation, rate in pairs(rates) do
      local factor = (math.random(1, 2) == 1) and 0.95 or 1.05263
      rates[mutation] = factor*rate
    end

    if math.random() < rates["connections"] then
      pointMutate(genome)
    end

    applyWithRate(rates["link"], linkMutate, genome, false)
    applyWithRate(rates["bias"], linkMutate, genome, true)
    applyWithRate(rates["node"], nodeMutate, genome)
    applyWithRate(rates["enable"], enableDisableMutate, genome, true)
    applyWithRate(rates["disable"], enableDisableMutate, genome, false)
  end
  608.  
  -- Fraction of genes whose innovation number appears in only one of the two
  -- lists: (#unmatched in genes1 + #unmatched in genes2) / max(#genes1, #genes2).
  -- Returns 0 when both lists are empty (previously 0/0 = NaN).
  function disjoint(genes1, genes2)
    local n = math.max(#genes1, #genes2)
    if n == 0 then
      return 0
    end

    local i1 = {}
    for i = 1,#genes1 do
      i1[genes1[i].innovation] = true
    end

    local i2 = {}
    for i = 1,#genes2 do
      i2[genes2[i].innovation] = true
    end

    local disjointGenes = 0
    for i = 1,#genes1 do
      if not i2[genes1[i].innovation] then
        disjointGenes = disjointGenes+1
      end
    end

    for i = 1,#genes2 do
      if not i1[genes2[i].innovation] then
        disjointGenes = disjointGenes+1
      end
    end

    return disjointGenes / n
  end
  641.  
  -- Mean absolute weight difference over genes present in both lists
  -- (matched by innovation number).
  -- Returns 0 when no genes match (previously 0/0 = NaN). In that case
  -- disjoint() already separates the genomes in sameSpecies.
  function weights(genes1, genes2)
    local i2 = {}
    for i = 1,#genes2 do
      local gene = genes2[i]
      i2[gene.innovation] = gene
    end

    local sum = 0
    local coincident = 0
    for i = 1,#genes1 do
      local gene = genes1[i]
      local gene2 = i2[gene.innovation]
      if gene2 ~= nil then
        sum = sum + math.abs(gene.weight - gene2.weight)
        coincident = coincident + 1
      end
    end

    if coincident == 0 then
      return 0
    end
    return sum / coincident
  end
  662.  
  -- Two genomes share a species when their weighted compatibility distance
  -- DeltaDisjoint*disjoint + DeltaWeights*weights is below DeltaThreshold.
  function sameSpecies(genome1, genome2)
    local distance = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
                   + DeltaWeights*weights(genome1.genes, genome2.genes)
    return distance < DeltaThreshold
  end
  668.  
  -- Set globalRank on every genome in the pool: 1 for the lowest fitness up
  -- to N (the population size) for the highest.
  function rankGlobally()
    local everyone = {}
    for _, species in ipairs(pool.species) do
      for _, genome in ipairs(species.genomes) do
        everyone[#everyone + 1] = genome
      end
    end

    table.sort(everyone, function(a, b) return a.fitness < b.fitness end)

    for rank, genome in ipairs(everyone) do
      genome.globalRank = rank
    end
  end
  685.  
  -- Store in species.averageFitness the mean *global rank* (not raw fitness)
  -- of the species' genomes; rankGlobally must have run first.
  function calculateAverageFitness(species)
    local rankSum = 0
    for _, genome in ipairs(species.genomes) do
      rankSum = rankSum + genome.globalRank
    end
    species.averageFitness = rankSum / #species.genomes
  end
  696.  
  -- Sum of averageFitness over all species; also records the species count
  -- in the global LastSpecies.
  function totalAverageFitness()
    local total = 0
    for _, species in ipairs(pool.species) do
      total = total + species.averageFitness
    end
    LastSpecies = #pool.species
    return total
  end
  706.  
  -- Sort each species' genomes by descending fitness and drop the tail,
  -- keeping the top half (rounded up), or only the best one when cutToOne.
  function cullSpecies(cutToOne)
    for _, species in ipairs(pool.species) do
      local genomes = species.genomes
      table.sort(genomes, function(a, b) return a.fitness > b.fitness end)

      local keep = cutToOne and 1 or math.ceil(#genomes/2)
      for i = #genomes, keep + 1, -1 do
        genomes[i] = nil
      end
    end
  end
  724.  
  -- Produce one mutated child from a species: with probability
  -- CrossoverChance cross two random members (possibly the same one),
  -- otherwise clone one random member.
  -- Parents are now locals; they used to leak into globals g1, g2 and g.
  function breedChild(species)
    local child
    if math.random() < CrossoverChance then
      local g1 = species.genomes[math.random(1, #species.genomes)]
      local g2 = species.genomes[math.random(1, #species.genomes)]
      child = crossover(g1, g2)
    else
      local g = species.genomes[math.random(1, #species.genomes)]
      child = copyGenome(g)
    end

    mutate(child)

    return child
  end
  740.  
  -- Update each species' topFitness/staleness from its best genome, then keep
  -- only species that are still fresh (staleness < StaleSpecies) or that hold
  -- the pool's best fitness. Also sorts each species' genomes best-first.
  function removeStaleSpecies()
    local survivors = {}

    for _, species in ipairs(pool.species) do
      table.sort(species.genomes, function(a, b) return a.fitness > b.fitness end)

      local best = species.genomes[1].fitness
      if best > species.topFitness then
        species.topFitness = best
        species.staleness = 0
      else
        species.staleness = species.staleness + 1
      end

      local fresh = species.staleness < StaleSpecies
      local holdsRecord = species.topFitness >= pool.maxFitness
      if fresh or holdsRecord then
        survivors[#survivors + 1] = species
      end
    end

    pool.species = survivors
  end
  764.  
  -- Drop species whose share of the next generation,
  -- averageFitness / total * Population, rounds down to zero. Each species'
  -- unrounded share is stored in species.AF.
  -- The share is now computed once into a local (it used to be computed
  -- twice and leaked through the global `breed`).
  function removeWeakSpecies()
    local survived = {}

    local sum = totalAverageFitness()
    for s = 1,#pool.species do
      local species = pool.species[s]
      local share = species.averageFitness / sum * Population
      species.AF = share
      if math.floor(share) >= 1 then
        table.insert(survived, species)
      end
    end

    pool.species = survived
  end
  780.  
  781.  
  -- Put `child` into the first species whose representative (its first
  -- genome) passes sameSpecies. Otherwise start a new species for it and
  -- advance UniqueSpeciesId.
  function addToSpecies(child)
    for _, species in ipairs(pool.species) do
      if sameSpecies(child, species.genomes[1]) then
        table.insert(species.genomes, child)
        return
      end
    end

    local childSpecies = newSpecies()
    UniqueSpeciesId = UniqueSpeciesId + 1
    table.insert(childSpecies.genomes, child)
    table.insert(pool.species, childSpecies)
  end
  799.  
  -- Advance the population by one generation:
  --   1. fold this generation's average fitness into LastAF/MaxAF/AFStaleFor;
  --   2. halve every species, then drop stale and weak species;
  --   3. breed children per species in proportion to its average global rank,
  --      keep only each species' best genome, and top up to Population with
  --      children from random species;
  --   4. re-speciate the children and write a "backup.<gen>.<name>" file.
  function newGeneration()
    -- Skip the average when nothing was measured: 0/0 would store NaN in
    -- LastAF. writeFile would then save NaN in a platform-dependent text form
    -- that loadFile's "*number" reads may not parse back.
    if FitnessesMeasured > 0 then
      LastAF = CurrentTotalFitness / FitnessesMeasured
      if LastAF > MaxAF then
        MaxAF = LastAF
        AFStaleFor = 0
      else
        AFStaleFor = AFStaleFor + 1
      end
    end
    CurrentTotalFitness = 0
    FitnessesMeasured = 0

    cullSpecies(false) -- Cull the bottom half of each species
    rankGlobally()
    removeStaleSpecies()
    rankGlobally()
    pool.generation = pool.generation + 1
    for s = 1,#pool.species do
      local species = pool.species[s]
      calculateAverageFitness(species)
      species.Age = pool.generation - species.bornGen
    end
    removeWeakSpecies()

    local sum = totalAverageFitness()
    local children = {}
    for s = 1,#pool.species do
      local species = pool.species[s]
      -- "- 1": each species' best genome also survives (cullSpecies(true) below).
      -- Local now; this used to leak into the global `breed`.
      local breed = math.floor(species.averageFitness / sum * Population) - 1
      for i=1,breed do
        table.insert(children, breedChild(species))
      end
    end
    cullSpecies(true) -- Cull all but the top member of each species
    while #children + #pool.species < Population do
      local species = pool.species[math.random(1, #pool.species)]
      table.insert(children, breedChild(species))
    end

    for c=1,#children do
      addToSpecies(children[c])
    end

    writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  end
  848.  
  -- Reset the run statistics, build a fresh pool of Population basic genomes
  -- (grouped into species by addToSpecies), and start running the first one.
  function initializePool()
    UniqueSpeciesId = 1
    CurrentTotalFitness = 0
    FitnessesMeasured = 0
    LastAF = 0
    MaxAF = -1000
    TotalTries = 1
    AFStaleFor = 0

    pool = newPool()

    for i=1,Population do
      -- Local: the previous version left the last genome in a global `basic`.
      local genome = basicGenome()
      addToSpecies(genome)
    end

    initializeRun()
  end
  868.  
  -- Release every player-1 button. The all-false table is left in the global
  -- `controller` and sent with joypad.set.
  function clearJoypad()
    controller = {}
    for _, name in ipairs(ButtonNames) do
      controller["P1 " .. name] = false
    end
    joypad.set(controller)
  end
  876.  
  -- Start a run of the current genome: reload the save state, reset the
  -- per-run globals (frame counter, `rightmost`, `timeout`), release all
  -- buttons, build the genome's network and evaluate it once.
  function initializeRun()
    savestate.load(Filename)
    pool.currentFrame = 0
    rightmost = 0
    timeout = TimeoutConstant
    clearJoypad()

    local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
    generateNetwork(current)
    evaluateCurrent()
  end
  889.  
  -- Feed the current inputs through the current genome's network and press
  -- the resulting buttons. Results stay in the globals `inputs` and
  -- `controller`.
  function evaluateCurrent()
    local genome = pool.species[pool.currentSpecies].genomes[pool.currentGenome]

    inputs = getInputs()
    controller = evaluateNetwork(genome.network, inputs)

    -- Never hold two opposing directions: release both instead.
    for _, pair in ipairs({{"P1 Left", "P1 Right"}, {"P1 Up", "P1 Down"}}) do
      if controller[pair[1]] and controller[pair[2]] then
        controller[pair[1]] = false
        controller[pair[2]] = false
      end
    end

    joypad.set(controller)
  end
  908.  
  -- Create the initial population only if no pool exists yet.
  if pool == nil then initializePool() end
  912.  
  913.  
  -- Advance pool.currentGenome / pool.currentSpecies to the next genome.
  -- Past the last species, breed a new generation and restart at species 1.
  function nextGenome()
    pool.currentGenome = pool.currentGenome + 1
    if pool.currentGenome <= #pool.species[pool.currentSpecies].genomes then
      return
    end

    pool.currentGenome = 1
    pool.currentSpecies = pool.currentSpecies + 1
    if pool.currentSpecies > #pool.species then
      newGeneration()
      pool.currentSpecies = 1
    end
  end
  925.  
  -- True when the current genome already has a non-zero fitness recorded.
  function fitnessAlreadyMeasured()
    local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
    return current.fitness ~= 0
  end
  931.  
  -- Draw the genome's network as an on-screen overlay: the input grid around
  -- Mario (5px per tile, centred on (50, 70)), hidden neurons, the output
  -- buttons at x=220, and every enabled gene as a line. When the
  -- showMutationRates checkbox is ticked, also list the genome's mutation rates.
  function displayGenome(genome)
    local network = genome.network
    local cells = {}

    -- Input grid cells, indexed 1..Inputs like the input neurons.
    local index = 1
    for dy = -BoxRadius, BoxRadius do
      for dx = -BoxRadius, BoxRadius do
        cells[index] = { x = 50+5*dx, y = 70+5*dy, value = network.neurons[index].value }
        index = index + 1
      end
    end

    -- Output cells plus button labels (blue when the button is pressed).
    for o = 1, Outputs do
      local value = network.neurons[MaxNodes + o].value
      cells[MaxNodes+o] = { x = 220, y = 30 + 8 * o, value = value }
      local labelColor = (value > 0) and 0xFF0000FF or 0xFF000000
      gui.drawText(223, 24+8*o, ButtonNames[o], labelColor, 9)
    end

    -- Hidden neurons all start at (140, 40); the passes below spread them out.
    for n, neuron in pairs(network.neurons) do
      if n > Inputs and n <= MaxNodes then
        cells[n] = { x = 140, y = 40, value = neuron.value }
      end
    end

    -- Four relaxation passes pulling hidden cells toward the cells they link
    -- to, keeping sources left of targets and x within [90, 220].
    for _ = 1, 4 do
      for _, gene in pairs(genome.genes) do
        if gene.enabled then
          local from, to = cells[gene.into], cells[gene.out]
          if gene.into > Inputs and gene.into <= MaxNodes then
            from.x = 0.75*from.x + 0.25*to.x
            if from.x >= to.x then from.x = from.x - 40 end
            if from.x < 90 then from.x = 90 end
            if from.x > 220 then from.x = 220 end
            from.y = 0.75*from.y + 0.25*to.y
          end
          if gene.out > Inputs and gene.out <= MaxNodes then
            to.x = 0.25*from.x + 0.75*to.x
            if from.x >= to.x then to.x = to.x + 40 end
            if to.x < 90 then to.x = 90 end
            if to.x > 220 then to.x = 220 end
            to.y = 0.25*from.y + 0.75*to.y
          end
        end
      end
    end

    -- Background panel behind the input grid.
    gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)

    -- Neuron boxes, grey-scaled from value -1 (black) to +1 (white); input
    -- cells with value 0 are skipped, other zero cells drawn faint.
    for n, cell in pairs(cells) do
      if n > Inputs or cell.value ~= 0 then
        local shade = math.floor((cell.value+1)/2*256)
        if shade > 255 then shade = 255 end
        if shade < 0 then shade = 0 end
        local opacity = (cell.value == 0) and 0x50000000 or 0xFF000000
        local fill = opacity + shade*0x10000 + shade*0x100 + shade
        gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,fill)
      end
    end

    -- Gene lines: green for positive weights, red for negative; stronger
    -- weights are more saturated, and lines from zero-valued sources faint.
    for _, gene in pairs(genome.genes) do
      if gene.enabled then
        local from, to = cells[gene.into], cells[gene.out]
        local opacity = (from.value == 0) and 0x20000000 or 0xA0000000
        local strength = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
        local lineColor
        if gene.weight > 0 then
          lineColor = opacity + 0x8000 + 0x10000*strength
        else
          lineColor = opacity + 0x800000 + 0x100*strength
        end
        gui.drawLine(from.x+1, from.y, to.x-3, to.y, lineColor)
      end
    end

    -- Marker for Mario at the centre of the grid.
    gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)

    if forms.ischecked(showMutationRates) then
      local textY = 100
      for mutation, rate in pairs(genome.mutationRates) do
        gui.drawText(100, textY, mutation .. ": " .. rate, 0xFF000000, 10)
        textY = textY + 8
      end
    end
  end
  1058.  
  -- Save the run statistics and the whole pool to `filename` as plain text,
  -- one value per line, in the order loadFile reads it back:
  --   statistics, generation, maxFitness, species count; per species its
  --   fields and genome count; per genome fitness, maxneuron, alternating
  --   mutation-rate name / value lines ended by "done", the gene count, then
  --   one "into out weight innovation enabled(1|0)" line per gene.
  -- Raises an error (with the OS message) if the file cannot be opened.
  function writeFile(filename)
    local file = assert(io.open(filename, "w"))
    file:write(TotalTries .. "\n")
    file:write(UniqueSpeciesId .. "\n")
    file:write(CurrentTotalFitness .. "\n")
    file:write(LastAF .. "\n")
    file:write(MaxAF .. "\n")
    file:write(AFStaleFor .. "\n")
    file:write(FitnessesMeasured .. "\n")
    file:write(CurrentAF .. "\n")
    file:write(pool.generation .. "\n")
    file:write(pool.maxFitness .. "\n")
    file:write(#pool.species .. "\n")
    -- ipairs: guarantees index order, which loadFile's sequential reads need.
    for _, species in ipairs(pool.species) do
      file:write(species.uniqueId .. "\n")
      file:write(species.bornGen .. "\n")
      file:write(species.Age .. "\n")
      file:write(species.AF .. "\n")
      file:write(species.topFitness .. "\n")
      file:write(species.staleness .. "\n")
      file:write(#species.genomes .. "\n")
      for _, genome in ipairs(species.genomes) do
        file:write(genome.fitness .. "\n")
        file:write(genome.maxneuron .. "\n")
        -- Rates are keyed by name, so their order does not matter.
        for mutation,rate in pairs(genome.mutationRates) do
          file:write(mutation .. "\n")
          file:write(rate .. "\n")
        end
        file:write("done\n")

        file:write(#genome.genes .. "\n")
        for _, gene in ipairs(genome.genes) do
          file:write(gene.into .. " " .. gene.out .. " " .. gene.weight .. " " .. gene.innovation .. " ")
          file:write(gene.enabled and "1\n" or "0\n")
        end
      end
    end
    file:close()
  end
  1105.  
  -- Save the pool to the file named in the save/load text box.
  function savePool()
    writeFile(forms.gettext(saveLoadFile))
  end
  1110.  
  -- Replace the pool and run statistics with the contents of `filename` (the
  -- format written by writeFile). Then skip to the first genome whose fitness
  -- is still 0 and start running it.
  -- Raises an error (with the OS message) if the file cannot be opened.
  function loadFile(filename)
    local file = assert(io.open(filename, "r"))
    pool = newPool()
    TotalTries = file:read("*number")
    UniqueSpeciesId = file:read("*number")
    CurrentTotalFitness = file:read("*number")
    LastAF = file:read("*number")
    MaxAF = file:read("*number")
    AFStaleFor = file:read("*number")
    FitnessesMeasured = file:read("*number")
    CurrentAF = file:read("*number")
    pool.generation = file:read("*number")
    pool.maxFitness = file:read("*number")
    forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
    local numSpecies = file:read("*number")
    for s=1,numSpecies do
      local species = newSpecies()
      table.insert(pool.species, species)
      species.uniqueId = file:read("*number")
      species.bornGen = file:read("*number")
      species.Age = file:read("*number")
      species.AF = file:read("*number")
      species.topFitness = file:read("*number")
      species.staleness = file:read("*number")
      local numGenomes = file:read("*number")
      for g=1,numGenomes do
        local genome = newGenome()
        table.insert(species.genomes, genome)
        genome.fitness = file:read("*number")
        genome.maxneuron = file:read("*number")

        -- Mutation rates: alternating name / value lines ended by "done".
        -- Consume each line end explicitly, rather than relying on a failed
        -- "*number" read to skip it. Stop at EOF too: a truncated file
        -- previously looped forever because nil ~= "done".
        file:read("*line") -- rest of the maxneuron line
        local name = file:read("*line")
        while name ~= nil and name ~= "done" do
          genome.mutationRates[name] = file:read("*number")
          file:read("*line") -- rest of the rate line
          name = file:read("*line")
        end

        local numGenes = file:read("*number")
        for n=1,numGenes do
          local gene = newGene()
          table.insert(genome.genes, gene)
          local enabled
          gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
          gene.enabled = (enabled ~= 0)
        end
      end
    end
    file:close()

    while fitnessAlreadyMeasured() do
      nextGenome()
    end
    initializeRun()
    pool.currentFrame = pool.currentFrame + 1
  end
  1169.  
  -- Button handler for "Load": replaces the population with the one stored
  -- at the path typed into the Save/Load textbox.
  function loadPool()
    loadFile(forms.gettext(saveLoadFile))
  end
  1174.  
  -- Button handler for "Play Top": switch to the highest-fitness genome in
  -- the current population and restart the run with it.
  --
  -- Only genomes with a recorded fitness are considered: this script treats
  -- a fitness of 0 as "not yet measured" (measured runs scoring exactly 0
  -- are stored as -1). Negative fitnesses are valid candidates, so a
  -- population whose runs all scored below zero no longer leaves
  -- pool.currentSpecies set to nil. If nothing has been measured, the
  -- current run is left unchanged.
  function playTop()
    local maxfitness = -math.huge
    local maxs, maxg
    for s, species in pairs(pool.species) do
      for g, genome in pairs(species.genomes) do
        if genome.fitness ~= 0 and genome.fitness > maxfitness then
          maxfitness = genome.fitness
          maxs, maxg = s, g
        end
      end
    end

    if maxs == nil then
      console.writeline("Play Top: no genome has a measured fitness yet")
      return
    end

    pool.currentSpecies = maxs
    pool.currentGenome = maxg
    pool.maxFitness = maxfitness
    forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
    initializeRun()
    pool.currentFrame = pool.currentFrame + 1
  end
  1196.  
  -- Registered with event.onexit: closes the control window when the script
  -- is stopped.
  function onExit()
    local window = form
    forms.destroy(window)
  end
  1200.  
  -- Startup: snapshot the freshly created pool, hook script exit, and build
  -- the control window. Widget handles are globals read by the main loop and
  -- the button handlers; they are created in the same order as before.
  writeFile("temp.pool")
  event.onexit(onExit)

  do
    -- Most coordinates are written on a half-size grid and doubled.
    local scale = 2
    form = forms.newform(scale*150, scale*260, "Fitness")
    maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), scale*5, scale*8)
    showNetwork = forms.checkbox(form, "Show Map", scale*5, scale*30)
    showMutationRates = forms.checkbox(form, "Show M-Rates", scale*5, scale*52)
    restartButton = forms.button(form, "Restart", initializePool, scale*5, scale*77)
    saveButton = forms.button(form, "Save", savePool, scale*5, scale*102, 100, 50)
    loadButton = forms.button(form, "Load", loadPool, scale*80, scale*102, 100, 50)
    saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, scale*25, nil, scale*5, scale*148)
    saveLoadLabel = forms.label(form, "Save/Load:", scale*5, scale*129)
    playTopButton = forms.button(form, "Play Top", playTop, scale*5, scale*170)
    hideBanner = forms.checkbox(form, "Hide Banner", scale*5, scale*190)
  end
  1216.  
  1217.  
  -- Main per-frame driver. Each frame it draws the overlay and feeds
  -- `controller` to the joypad. It also tracks Mario's rightmost progress.
  -- When the idle timeout expires, it scores the current genome, records
  -- the statistics, and moves on to the next genome that
  -- fitnessAlreadyMeasured() does not report as done.
  while true do
    local backgroundColor = 0xD0FFFFFF
    if not forms.ischecked(hideBanner) then
      gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
    end

    local species = pool.species[pool.currentSpecies]
    local genome = species.genomes[pool.currentGenome]

    if forms.ischecked(showNetwork) then
      displayGenome(genome)
    end

    -- The network is evaluated only every 5th frame; joypad.set is still
    -- called every frame with whatever `controller` currently holds.
    if pool.currentFrame % 5 == 0 then
      evaluateCurrent()
    end

    joypad.set(controller)

    getPositions()
    -- New rightmost progress resets the idle timeout.
    if marioX > rightmost then
      rightmost = marioX
      timeout = TimeoutConstant
    end

    timeout = timeout - 1

    -- Longer runs get extra grace time before they are cut off.
    local timeoutBonus = pool.currentFrame / 4
    if timeout + timeoutBonus <= 0 then
      -- Reward distance travelled, penalise frames spent.
      local fitness = rightmost - pool.currentFrame / 2
      local romName = gameinfo.getromname()
      -- 1000-point bonus for getting past each game's end-of-level x position.
      if romName == "Super Mario World (USA)" and rightmost > 4816 then
        fitness = fitness + 1000
      elseif romName == "Super Mario Bros." and rightmost > 3186 then
        fitness = fitness + 1000
      end
      -- 0 means "not yet measured" elsewhere in this script, so a genuine
      -- score of 0 is stored as -1.
      if fitness == 0 then
        fitness = -1
      end
      genome.fitness = fitness

      if fitness > pool.maxFitness then
        pool.maxFitness = fitness
        forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
        writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
      end

      console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
      CurrentTotalFitness = CurrentTotalFitness + fitness
      FitnessesMeasured = FitnessesMeasured + 1
      TotalTries = TotalTries + 1
      -- Rescan from the first genome for the next one still to be measured.
      pool.currentSpecies = 1
      pool.currentGenome = 1
      while fitnessAlreadyMeasured() do
        nextGenome()
      end
      initializeRun()
      -- Keep the banner below consistent with pool.currentSpecies, which
      -- now points at the species being played next.
      species = pool.species[pool.currentSpecies]
      -- savestate.loadslot(1) -- edit line
    end

    -- Generation progress: genomes with a non-zero (measured) fitness.
    local measured = 0
    local total = 0
    for _, s in pairs(pool.species) do
      for _, g in pairs(s.genomes) do
        total = total + 1
        if g.fitness ~= 0 then
          measured = measured + 1
        end
      end
    end
    local percent = 0
    if total > 0 then
      percent = math.floor(measured/total*100)
    end

    -- Running average fitness; guard before dividing so no NaN is produced.
    if FitnessesMeasured == 0 then
      CurrentAF = 0
    else
      CurrentAF = CurrentTotalFitness / FitnessesMeasured
    end

    if not forms.ischecked(hideBanner) then
      gui.drawText(0, 0, "G:" .. pool.generation .. "-" .. percent .. "%" .. "AF:" .. math.floor(CurrentAF) .. "/" .. math.floor(LastAF) .. "/" .. math.floor(MaxAF).. "TT:" .. TotalTries .. "AFS:" .. AFStaleFor, 0xFF000000, 11)
      gui.drawText(0, 12,"ID:" .. species.uniqueId .. "a:" .. species.Age .. "F:" .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3) .. "/" .. math.floor(pool.maxFitness) .. " s:" .. pool.currentSpecies .. "-" .. pool.currentGenome .. "t:" .. FitnessesMeasured .. "ss:" .. species.staleness, 0xFF000000, 11)
    end

    pool.currentFrame = pool.currentFrame + 1

    emu.frameadvance();
  end
  1306.  
  1307. -- math.floor (totalAverageFitness() / (#pool.species/total) ),
  1308.  
  1309. -- (species.averageFitness / sum * Population)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement