Advertisement
Guest User

Untitled

a guest
Jun 15th, 2015
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 23.39 KB | None | 0 0
  1. -- MarI/O by SethBling
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  4.  
  -- Per-ROM configuration.
  --   Filename    : the ROM's start file name (a ".state" file); it is also the
  --                 prefix of the default ".pool" save file in the control form.
  --   ButtonNames : controller buttons the network drives, in output-neuron
  --                 order (output o presses "P1 " .. ButtonNames[o]).
  local romName = gameinfo.getromname()
  if romName == "Super Mario World (USA)" then
      Filename = "DP.state"
      ButtonNames = {
          "A",
          "B",
          "X",
          "Y",
          "Up",
          "Down",
          "Left",
          "Right",
      }
  elseif romName == "Super Mario Bros." then
      Filename = "SMB1-1.state"
      ButtonNames = {
          "A",
          "B",
          "Up",
          "Down",
          "Left",
          "Right",
      }
  else
      -- Fail fast with a readable message. Otherwise ButtonNames stays nil and
      -- the script dies later with "attempt to get length of a nil value".
      error("MarI/O: unsupported ROM \"" .. tostring(romName) .. "\"; expected \"Super Mario World (USA)\" or \"Super Mario Bros.\"")
  end
  28.  
  -- Network inputs: a square of (2*BoxRadius+1)^2 tiles around Mario, plus one
  -- extra input (id Inputs) that the rest of the file treats as the bias.
  BoxRadius = 6
  local boxWidth = BoxRadius * 2 + 1
  InputSize = boxWidth * boxWidth

  Inputs = InputSize + 1
  -- One output neuron per controller button of the loaded ROM.
  Outputs = #ButtonNames

  -- Population size and speciation. Two genomes share a species when
  -- DeltaDisjoint*disjoint + DeltaWeights*weights < DeltaThreshold.
  Population = 300
  DeltaDisjoint = 2.0
  DeltaWeights = 0.4
  DeltaThreshold = 1.0

  -- Limit compared against species.staleness (the number of generations
  -- without a new species top fitness).
  StaleSpecies = 15

  -- Default per-genome mutation rates (copied into each genome by newGenome),
  -- plus the global weight-perturbation and crossover probabilities.
  MutateConnectionsChance = 0.25
  PerturbChance = 0.90
  CrossoverChance = 0.75
  LinkMutationChance = 2.0
  NodeMutationChance = 0.50
  BiasMutationChance = 0.40
  StepSize = 0.1
  DisableMutationChance = 0.4
  EnableMutationChance = 0.2

  -- The run's timeout counter is reset to this value whenever Mario reaches a
  -- new rightmost X, and it is decremented once per frame.
  TimeoutConstant = 20

  -- Output neurons use ids MaxNodes+1 .. MaxNodes+Outputs (presumably chosen
  -- to stay clear of input and hidden neuron ids).
  MaxNodes = 1000000

  -- Number of savestate loads observed (incremented by restartCount).
  restart = 0
  57.  
  -- Read Mario's position from RAM for the loaded ROM and store it in globals:
  --   marioX, marioY   : position in level coordinates
  --   screenX, screenY : position relative to the visible screen
  -- Leaves the globals untouched for any other ROM.
  function getPositions()
      local rom = gameinfo.getromname()
      if rom == "Super Mario World (USA)" then
          marioX = memory.read_s16_le(0x94)
          marioY = memory.read_s16_le(0x96)
          -- Layer 1 scroll offsets (the camera position).
          local cameraX = memory.read_s16_le(0x1A)
          local cameraY = memory.read_s16_le(0x1C)
          screenX = marioX - cameraX
          screenY = marioY - cameraY
      elseif rom == "Super Mario Bros." then
          -- Level X = page number * 256 + offset within the page.
          local page = memory.readbyte(0x6D)
          local offsetInPage = memory.readbyte(0x86)
          marioX = page * 0x100 + offsetInPage
          local rawY = memory.readbyte(0x03B8)
          marioY = rawY + 16
          screenX = memory.readbyte(0x03AD)
          screenY = rawY
      end
  end
  76.  
  -- Look up the level tile at offset (dx, dy) pixels from Mario's position
  -- (uses the globals marioX/marioY set by getPositions).
  --   SMW: returns the raw tile byte. The address formula treats the map as
  --        blocks 16 tiles wide, 0x1B0 bytes apart, each holding rows of 16 bytes.
  --   SMB: returns 1 if the tile byte is non-zero, else 0. The map is two pages
  --        of 13 rows x 16 columns starting at 0x500, and rows outside 0..12
  --        read as 0.
  -- Returns nil for any other ROM.
  -- Tile coordinates are locals; previously the SMW branch wrote the globals
  -- `x` and `y`.
  function getTile(dx, dy)
      local rom = gameinfo.getromname()
      if rom == "Super Mario World (USA)" then
          local x = math.floor((marioX + dx + 8) / 16)
          local y = math.floor((marioY + dy) / 16)
          return memory.readbyte(0x1C800 + math.floor(x / 0x10) * 0x1B0 + y * 0x10 + x % 0x10)
      elseif rom == "Super Mario Bros." then
          local x = marioX + dx + 8
          local y = marioY + dy - 16
          local page = math.floor(x / 256) % 2
          local subx = math.floor((x % 256) / 16)
          local suby = math.floor((y - 32) / 16)

          if suby >= 13 or suby < 0 then
              return 0
          end

          local addr = 0x500 + page * 13 * 16 + suby * 16 + subx
          if memory.readbyte(addr) ~= 0 then
              return 1
          end
          return 0
      end
  end
  103.  
  -- Collect the level positions of active sprites as a list of {x=, y=}.
  --   SMW: 12 slots; a slot is active when its status byte (0x14C8+slot) is
  --        non-zero. Position = low byte + high byte * 256.
  --   SMB: 5 enemy slots; a slot is active when 0x0F+slot is non-zero. X is
  --        page * 256 + offset, and Y is the raw byte + 24.
  -- Returns nil for any other ROM.
  -- Coordinates are locals; previously the SMW branch wrote the globals
  -- `spritex`/`spritey`.
  function getSprites()
      local rom = gameinfo.getromname()
      if rom == "Super Mario World (USA)" then
          local sprites = {}
          for slot = 0, 11 do
              local status = memory.readbyte(0x14C8 + slot)
              if status ~= 0 then
                  local sx = memory.readbyte(0xE4 + slot) + memory.readbyte(0x14E0 + slot) * 256
                  local sy = memory.readbyte(0xD8 + slot) + memory.readbyte(0x14D4 + slot) * 256
                  sprites[#sprites + 1] = {["x"] = sx, ["y"] = sy}
              end
          end
          return sprites
      elseif rom == "Super Mario Bros." then
          local sprites = {}
          for slot = 0, 4 do
              local enemy = memory.readbyte(0xF + slot)
              if enemy ~= 0 then
                  local ex = memory.readbyte(0x6E + slot) * 0x100 + memory.readbyte(0x87 + slot)
                  local ey = memory.readbyte(0xCF + slot) + 24
                  sprites[#sprites + 1] = {["x"] = ex, ["y"] = ey}
              end
          end
          return sprites
      end
  end
  131.  
  -- Collect the level positions of active extended sprites as a list of
  -- {x=, y=}.
  --   SMW: 12 slots; a slot is active when its number byte (0x170B+slot) is
  --        non-zero. Position = low byte + high byte * 256.
  --   SMB: extended sprites are not tracked, so the list is always empty.
  -- Returns nil for any other ROM.
  -- Coordinates are locals; previously the SMW branch wrote the globals
  -- `spritex`/`spritey`.
  function getExtendedSprites()
      local rom = gameinfo.getromname()
      if rom == "Super Mario World (USA)" then
          local extended = {}
          for slot = 0, 11 do
              local number = memory.readbyte(0x170B + slot)
              if number ~= 0 then
                  local sx = memory.readbyte(0x171F + slot) + memory.readbyte(0x1733 + slot) * 256
                  local sy = memory.readbyte(0x1715 + slot) + memory.readbyte(0x1729 + slot) * 256
                  extended[#extended + 1] = {["x"] = sx, ["y"] = sy}
              end
          end
          return extended
      elseif rom == "Super Mario Bros." then
          return {}
      end
  end
  149.  
  -- Build the network's input vector from the current game state.
  -- Scans a (2*BoxRadius+1) x (2*BoxRadius+1) grid of 16-pixel cells centred on
  -- Mario, row by row (top to bottom, left to right within a row). Each cell is:
  --    1  if getTile() returns 1 there and the cell's level Y is < 0x1B0,
  --   -1  if a sprite is within 8 px (<= 8) or an extended sprite is strictly
  --       within 8 px (< 8) of the cell on both axes (overrides the tile value),
  --    0  otherwise.
  -- Returns InputSize values (Inputs counts one more).
  -- All scratch values are locals; previously sprites/extended/tile/distx/disty
  -- were written as globals.
  function getInputs()
      getPositions()

      local sprites = getSprites()
      local extended = getExtendedSprites()
      local span = BoxRadius * 16

      local inputs = {}
      for dy = -span, span, 16 do
          for dx = -span, span, 16 do
              local cellX = marioX + dx
              local cellY = marioY + dy
              local value = 0

              if getTile(dx, dy) == 1 and cellY < 0x1B0 then
                  value = 1
              end

              for i = 1, #sprites do
                  local s = sprites[i]
                  if math.abs(s["x"] - cellX) <= 8 and math.abs(s["y"] - cellY) <= 8 then
                      value = -1
                  end
              end

              -- Extended sprites use a strict < 8 test, unlike the <= 8 above.
              for i = 1, #extended do
                  local e = extended[i]
                  if math.abs(e["x"] - cellX) < 8 and math.abs(e["y"] - cellY) < 8 then
                      value = -1
                  end
              end

              inputs[#inputs + 1] = value
          end
      end

      return inputs
  end
  190.  
  -- Steepened logistic activation rescaled to the open interval (-1, 1):
  -- 2 / (1 + e^(-4.9x)) - 1, so sigmoid(0) == 0.
  function sigmoid(x)
      local e = math.exp(-4.9 * x)
      return 2 / (1 + e) - 1
  end
  194.  
  -- Allocate the next innovation number from the global pool counter and
  -- return it.
  function newInnovation()
      local nextId = pool.innovation + 1
      pool.innovation = nextId
      return nextId
  end
  199.  
  -- Create an empty pool at generation 0, positioned at species 1 / genome 1.
  -- The innovation counter starts at Outputs, so the first newInnovation()
  -- returns Outputs + 1.
  function newPool()
      return {
          species = {},
          generation = 0,
          innovation = Outputs,
          currentSpecies = 1,
          currentGenome = 1,
          currentFrame = 0,
          maxFitness = 0,
      }
  end
  212.  
  -- Create an empty species with zeroed fitness bookkeeping.
  function newSpecies()
      return {
          topFitness = 0,
          staleness = 0,
          genomes = {},
          averageFitness = 0,
      }
  end
  222.  
  -- Create an empty genome with zero fitness and its own copy of the default
  -- mutation rates (mutate() adjusts these per genome).
  function newGenome()
      -- Built key by key, in the same order as before: pairs() over this
      -- table drives mutate() and writeFile(), so its layout is kept stable.
      local rates = {}
      rates["connections"] = MutateConnectionsChance
      rates["link"] = LinkMutationChance
      rates["bias"] = BiasMutationChance
      rates["node"] = NodeMutationChance
      rates["enable"] = EnableMutationChance
      rates["disable"] = DisableMutationChance
      rates["step"] = StepSize

      return {
          genes = {},
          fitness = 0,
          adjustedFitness = 0,
          network = {},
          maxneuron = 0,
          globalRank = 0,
          mutationRates = rates,
      }
  end
  242.  
  -- Deep-copy a genome's heritable state: its genes (each copied with
  -- copyGene), maxneuron, and every mutation rate.
  -- Fitness, rank and network are not copied; they start fresh from
  -- newGenome().
  -- All rates are copied in a loop. The previous field-by-field version
  -- omitted "step", so copies silently reset their step size to StepSize.
  function copyGenome(genome)
      local genome2 = newGenome()
      for g = 1, #genome.genes do
          table.insert(genome2.genes, copyGene(genome.genes[g]))
      end
      genome2.maxneuron = genome.maxneuron
      for mutation, rate in pairs(genome.mutationRates) do
          genome2.mutationRates[mutation] = rate
      end

      return genome2
  end
  258.  
  -- Create a starting genome. It has no genes and maxneuron = Inputs, so ids
  -- that nodeMutate later allocates for hidden neurons come after the inputs.
  -- One round of mutate() is then applied to it.
  -- (The unused local `innovation` from the earlier version is removed.)
  function basicGenome()
      local genome = newGenome()
      genome.maxneuron = Inputs
      mutate(genome)
      return genome
  end
  268.  
  -- Create a blank, enabled connection gene (into -> out) with zero weight
  -- and innovation number 0.
  function newGene()
      return {
          into = 0,
          out = 0,
          weight = 0.0,
          enabled = true,
          innovation = 0,
      }
  end
  279.  
  -- Return a new gene (from newGene) carrying the same endpoints, weight,
  -- enabled flag and innovation number as `gene`.
  function copyGene(gene)
      local clone = newGene()
      for _, field in ipairs({"into", "out", "weight", "enabled", "innovation"}) do
          clone[field] = gene[field]
      end
      return clone
  end
  290.  
  -- Create a neuron with no incoming genes and an activation value of 0.
  function newNeuron()
      return {incoming = {}, value = 0.0}
  end
  298.  
  -- generateNetwork(genome): builds the runnable network for `genome`.
  -- Input neurons occupy ids 1..Inputs and output neurons occupy
  -- MaxNodes+1..MaxNodes+Outputs. The genome's genes are then sorted by their
  -- `out` neuron.
  -- NOTE(review): this copy is corrupted from the middle of the sort comparator
  -- onward. Text between a `<` and a later `>` was stripped (it looks like
  -- HTML-tag removal by the paste site), so the rest of generateNetwork and the
  -- start of a separate evaluation function are missing. Restore these lines
  -- from the upstream source before running.
  299. function generateNetwork(genome)
  300.     local network = {}
  301.     network.neurons = {}
  302.    
  303.     for i=1,Inputs do
  304.         network.neurons[i] = newNeuron()
  305.     end
  306.    
  307.     for o=1,Outputs do
  308.         network.neurons[MaxNodes+o] = newNeuron()
  309.     end
  310.    
  311.     table.sort(genome.genes, function (a,b)
  -- NOTE(review): garbled line. The comparator's tail and the code between it
  -- and the activation step below are missing.
  312.         return (a.out <b> 0 then
  -- Surviving tail of the evaluation pass: a neuron's value becomes
  -- sigmoid(sum), presumably the weighted sum of its incoming values.
  313.             neuron.value = sigmoid(sum)
  314.         end
  315.     end
  316.    
  -- Map each output neuron to a controller press: "P1 <button>" is true when
  -- the output neuron's value is positive, false otherwise.
  317.     local outputs = {}
  318.     for o=1,Outputs do
  319.         local button = "P1 " .. ButtonNames[o]
  320.         if network.neurons[MaxNodes+o].value > 0 then
  321.             outputs[button] = true
  322.         else
  323.             outputs[button] = false
  324.         end
  325.     end
  326.    
  327.     return outputs
  328. end
  329.  
  -- Breed a child genome from parents g1 and g2.
  -- The fitter parent (g2 if strictly fitter, otherwise g1) defines the child's
  -- gene list. Each of its genes is copied, except that when the other parent
  -- has an enabled gene with the same innovation number, a coin flip may take
  -- that parent's copy instead. Genes found only in the less fit parent are
  -- not inherited. The child's maxneuron is the larger of the parents' values,
  -- and its mutation rates come from the fitter parent.
  function crossover(g1, g2)
      -- Make g1 the higher-fitness genome. The swap uses multiple assignment,
      -- so no global temporary (formerly `tempg`) leaks into the script.
      if g2.fitness > g1.fitness then
          g1, g2 = g2, g1
      end

      local child = newGenome()

      local innovations2 = {}
      for i = 1, #g2.genes do
          local gene = g2.genes[i]
          innovations2[gene.innovation] = gene
      end

      for i = 1, #g1.genes do
          local gene1 = g1.genes[i]
          local gene2 = innovations2[gene1.innovation]
          -- Operand order matters: math.random is only consumed for matching genes.
          if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
              table.insert(child.genes, copyGene(gene2))
          else
              table.insert(child.genes, copyGene(gene1))
          end
      end

      child.maxneuron = math.max(g1.maxneuron, g2.maxneuron)

      for mutation, rate in pairs(g1.mutationRates) do
          child.mutationRates[mutation] = rate
      end

      return child
  end
  364.  
  -- Pick a uniformly random neuron id among the candidates for `genes`.
  -- Candidates are every output id (MaxNodes+1..MaxNodes+Outputs) plus every
  -- gene endpoint. Input ids 1..Inputs are also candidates unless `nonInput` is
  -- set, in which case gene endpoints <= Inputs are skipped too.
  -- The final `or 0` fallback mirrors the original and is unreachable in practice.
  function randomNeuron(genes, nonInput)
      local candidates = {}
      local allowInputs = not nonInput

      if allowInputs then
          for id = 1, Inputs do
              candidates[id] = true
          end
      end
      for o = 1, Outputs do
          candidates[MaxNodes + o] = true
      end
      for idx = 1, #genes do
          local g = genes[idx]
          if allowInputs or g.into > Inputs then
              candidates[g.into] = true
          end
          if allowInputs or g.out > Inputs then
              candidates[g.out] = true
          end
      end

      -- Flatten in pairs() order so the n-th pick is the same neuron the
      -- countdown walk over pairs() would have returned.
      local ids = {}
      for id in pairs(candidates) do
          ids[#ids + 1] = id
      end
      local pick = math.random(1, #ids)
      return ids[pick] or 0
  end
  399.  
  -- Return true if `genes` already contains a connection with the same
  -- endpoints (into -> out) as `link`, otherwise false. Weight, innovation
  -- and enabled state are not compared.
  -- Now returns an explicit false instead of falling off the end with nil.
  function containsLink(genes, link)
      for _, gene in ipairs(genes) do
          if gene.into == link.into and gene.out == link.out then
              return true
          end
      end
      return false
  end
  408.  
  -- Mutate every gene weight. With probability PerturbChance the weight is
  -- nudged by a uniform amount in [-step, +step), where step is the genome's
  -- "step" rate. Otherwise it is replaced by a fresh uniform value in [-2, 2).
  function pointMutate(genome)
      local step = genome.mutationRates["step"]
      local genes = genome.genes

      for n = 1, #genes do
          local g = genes[n]
          local perturb = math.random() < PerturbChance
          if perturb then
              g.weight = g.weight + math.random() * step*2 - step
          else
              g.weight = math.random()*4-2
          end
      end
  end
  421.  
  -- Try to add one new random connection gene to `genome`.
  -- The source is any neuron and the target is a non-input neuron. If both
  -- happen to be inputs nothing is added, and if only the target is an input
  -- the two are swapped. (With nonInput = true the target should never be an
  -- input, so these guards are defensive.) With `forceBias` the link instead
  -- starts at the bias input (id Inputs). Duplicate endpoints are rejected.
  -- Otherwise the gene gets a new innovation number and a random weight in [-2, 2).
  function linkMutate(genome, forceBias)
      local source = randomNeuron(genome.genes, false)
      local target = randomNeuron(genome.genes, true)

      if source <= Inputs and target <= Inputs then
          return
      end
      if target <= Inputs then
          source, target = target, source
      end

      local link = newGene()
      link.into = forceBias and Inputs or source
      link.out = target

      if containsLink(genome.genes, link) then
          return
      end

      link.innovation = newInnovation()
      link.weight = math.random()*4-2
      table.insert(genome.genes, link)
  end
  452.  
  -- Split a random connection by inserting a new hidden neuron.
  -- Does nothing for a genome without genes. Otherwise a new neuron id is
  -- reserved (genome.maxneuron + 1) before the gene is checked, so the id is
  -- consumed even if the chosen gene is disabled and nothing else changes.
  -- For an enabled gene A->B with weight w, that gene is disabled and two
  -- enabled genes are appended: A->new (weight 1.0) and new->B (weight w),
  -- each with a fresh innovation number.
  function nodeMutate(genome)
      local genes = genome.genes
      local count = #genes
      if count == 0 then
          return
      end

      genome.maxneuron = genome.maxneuron + 1
      local hidden = genome.maxneuron

      local split = genes[math.random(1, count)]
      if split.enabled then
          split.enabled = false

          local inbound = copyGene(split)
          inbound.out = hidden
          inbound.weight = 1.0
          inbound.innovation = newInnovation()
          inbound.enabled = true
          table.insert(genes, inbound)

          local outbound = copyGene(split)
          outbound.into = hidden
          outbound.innovation = newInnovation()
          outbound.enabled = true
          table.insert(genes, outbound)
      end
  end
  479.  
  -- Flip the `enabled` flag of one randomly chosen gene.
  -- With enable == true a disabled gene is switched on; with enable == false
  -- an enabled gene is switched off. Does nothing if no gene is in the
  -- required state.
  function enableDisableMutate(genome, enable)
      local eligible = {}
      for _, gene in pairs(genome.genes) do
          if gene.enabled == not enable then
              eligible[#eligible + 1] = gene
          end
      end

      local n = #eligible
      if n > 0 then
          local chosen = eligible[math.random(1, n)]
          chosen.enabled = not chosen.enabled
      end
  end
  495.  
  -- mutate(genome): applies one round of random mutations to `genome` in place.
  -- First, each per-genome mutation rate is scaled, with equal probability, by
  -- 0.95 or by 1.05263 (about 1/0.95, so the two moves roughly cancel out).
  -- NOTE(review): lines 505-509 of this copy are garbled. Text between a `<` and
  -- a later `>` was stripped by the paste site, so the code that applies the
  -- individual mutations is missing. Presumably it called pointMutate,
  -- linkMutate, nodeMutate and enableDisableMutate(..., true). Restore from the
  -- upstream source before running.
  496. function mutate(genome)
  497.     for mutation,rate in pairs(genome.mutationRates) do
  498.         if math.random(1,2) == 1 then
  499.             genome.mutationRates[mutation] = 0.95*rate
  500.         else
  501.             genome.mutationRates[mutation] = 1.05263*rate
  502.         end
  503.     end
  504.  
  -- NOTE(review): the next five lines are garbled (see above).
  505.     if math.random() <genome> 0 do
  506.         if math.random() <p> 0 do
  507.         if math.random() <p> 0 do
  508.         if math.random() <p> 0 do
  509.         if math.random() <p> 0 do
  -- Surviving tail of the last loop, whose header is lost (presumably
  -- `while p > 0 do`). Each pass disables a random enabled gene with
  -- probability min(p, 1), then lowers p by 1.
  510.         if math.random() < p then
  511.             enableDisableMutate(genome, false)
  512.         end
  513.         p = p - 1
  514.     end
  515. end
  516.  
  -- Fraction of disjoint genes between two gene lists: the number of genes
  -- whose innovation number appears in only one list, divided by the length
  -- of the longer list.
  -- Two empty lists are identical, so this returns 0. The plain division
  -- would give 0/0 = NaN, and sameSpecies would then never match.
  function disjoint(genes1, genes2)
      local longest = math.max(#genes1, #genes2)
      if longest == 0 then
          return 0
      end

      local in1, in2 = {}, {}
      for _, gene in ipairs(genes1) do
          in1[gene.innovation] = true
      end
      for _, gene in ipairs(genes2) do
          in2[gene.innovation] = true
      end

      local disjointGenes = 0
      for _, gene in ipairs(genes1) do
          if not in2[gene.innovation] then
              disjointGenes = disjointGenes + 1
          end
      end
      for _, gene in ipairs(genes2) do
          if not in1[gene.innovation] then
              disjointGenes = disjointGenes + 1
          end
      end

      return disjointGenes / longest
  end
  549.  
  -- Mean absolute weight difference over the genes both lists share (matched
  -- by innovation number).
  -- Returns 0 when no genes are shared, instead of the 0/0 = NaN the plain
  -- division would produce.
  function weights(genes1, genes2)
      local byInnovation = {}
      for _, gene in ipairs(genes2) do
          byInnovation[gene.innovation] = gene
      end

      local sum = 0
      local coincident = 0
      for _, gene in ipairs(genes1) do
          local other = byInnovation[gene.innovation]
          if other ~= nil then
              sum = sum + math.abs(gene.weight - other.weight)
              coincident = coincident + 1
          end
      end

      if coincident == 0 then
          return 0
      end
      return sum / coincident
  end
  570.    
  -- Two genomes belong to the same species when their compatibility distance
  -- DeltaDisjoint*disjoint(...) + DeltaWeights*weights(...) is strictly below
  -- DeltaThreshold.
  function sameSpecies(genome1, genome2)
      local a, b = genome1.genes, genome2.genes
      local distance = DeltaDisjoint * disjoint(a, b) + DeltaWeights * weights(a, b)
      return distance < DeltaThreshold
  end
  576.  
  -- rankGlobally(): gathers every genome from every species into one list and
  -- sorts it by fitness.
  -- NOTE(review): line 586 of this copy is garbled. Text between a `<` and a
  -- later `>` was stripped by the paste site, so the rest of rankGlobally (the
  -- comparator and rank assignment) and the start of the following function
  -- are missing. Restore from the upstream source before running.
  577. function rankGlobally()
  578.     local global = {}
  579.     for s = 1,#pool.species do
  580.         local species = pool.species[s]
  581.         for g = 1,#species.genomes do
  582.             table.insert(global, species.genomes[g])
  583.         end
  584.     end
  585.     table.sort(global, function (a,b)
  -- NOTE(review): garbled line. Judging by the `b.fitness)` fragment, what
  -- follows belongs to the species-culling function that newGeneration calls
  -- as cullSpecies(false) / cullSpecies(true).
  586.         return (a.fitness <b> b.fitness)
  587.         end)
  588.        
  -- Surviving tail of the culling routine. Each species keeps
  -- ceil(#genomes / 2) genomes, or just 1 when `cutToOne` is set, and the rest
  -- are removed from the end of the list (presumably the list was just sorted
  -- best-first -- TODO confirm against upstream).
  589.         local remaining = math.ceil(#species.genomes/2)
  590.         if cutToOne then
  591.             remaining = 1
  592.         end
  593.         while #species.genomes > remaining do
  594.             table.remove(species.genomes)
  595.         end
  596.     end
  597. end
  598.  
  -- breedChild(species): creates one offspring from `species`. The visible
  -- start compares math.random() against CrossoverChance, presumably choosing
  -- between crossover of two parents and copying a single parent.
  -- NOTE(review): line 601 of this copy is garbled. Text between a `<` and a
  -- later `>` was stripped by the paste site, so the rest of breedChild and
  -- the start of the stale-species removal function (called as
  -- removeStaleSpecies() in newGeneration) are missing. Restore from the
  -- upstream source before running.
  599. function breedChild(species)
  600.     local child = {}
  601.     if math.random() <CrossoverChance> b.fitness)
  602.         end)
  603.        
  -- Surviving tail of the stale-species pass. When the species' first genome
  -- beats the recorded topFitness, it becomes the new topFitness and staleness
  -- resets to 0; otherwise staleness goes up by 1.
  604.         if species.genomes[1].fitness > species.topFitness then
  605.             species.topFitness = species.genomes[1].fitness
  606.             species.staleness = 0
  607.         else
  608.             species.staleness = species.staleness + 1
  609.         end
  -- NOTE(review): garbled condition. A species survives if its staleness is
  -- below StaleSpecies, or presumably if it holds a fitness >= pool.maxFitness.
  610.         if species.staleness <StaleSpecies>= pool.maxFitness then
  611.             table.insert(survived, species)
  612.         end
  613.     end
  614.  
  615.     pool.species = survived
  616. end
  617.  
  -- Drop every species whose share of the total average fitness would earn
  -- it fewer than one offspring:
  --   floor(species.averageFitness / total * Population) < 1.
  -- newGeneration computes each species' averageFitness before calling this.
  -- `breed` is now local; it used to be written as a global.
  function removeWeakSpecies()
      local survived = {}

      local sum = totalAverageFitness()
      for _, species in ipairs(pool.species) do
          local breed = math.floor(species.averageFitness / sum * Population)
          if breed >= 1 then
              table.insert(survived, species)
          end
      end

      pool.species = survived
  end
  632.  
  633.  
  -- Put `child` into the first species whose representative (its first
  -- genome) is compatible according to sameSpecies(). If none matches, found
  -- a new species containing only the child.
  function addToSpecies(child)
      local speciesList = pool.species
      for index = 1, #speciesList do
          local candidate = speciesList[index]
          if sameSpecies(child, candidate.genomes[1]) then
              table.insert(candidate.genomes, child)
              return
          end
      end

      local founded = newSpecies()
      table.insert(founded.genomes, child)
      table.insert(speciesList, founded)
  end
  650.  
  -- newGeneration(): advances the pool by one generation.
  --   1. cullSpecies(false) culls the bottom half of each species. Genomes are
  --      then ranked globally, removeStaleSpecies() runs, and ranks are
  --      recomputed.
  --   2. Each species' average fitness is computed, and removeWeakSpecies()
  --      drops species that would earn fewer than one offspring.
  --   3. Each remaining species breeds
  --      floor(averageFitness / sum * Population) - 1 children. The "- 1"
  --      presumably leaves room for the one member cullSpecies(true) keeps
  --      per species -- TODO confirm.
  -- NOTE(review): `breed` below is assigned as a global (no `local`).
  -- NOTE(review): line 671 of this copy is garbled. Text between a `<` and a
  -- later `>` was stripped by the paste site, so the rest of newGeneration is
  -- lost. So, apparently, are several whole functions that this file calls but
  -- never defines in this copy (initializePool, initializeRun, evaluateCurrent,
  -- nextGenome). Restore from the upstream source before running.
  651. function newGeneration()
  652.     cullSpecies(false) -- Cull the bottom half of each species
  653.     rankGlobally()
  654.     removeStaleSpecies()
  655.     rankGlobally()
  656.     for s = 1,#pool.species do
  657.         local species = pool.species[s]
  658.         calculateAverageFitness(species)
  659.     end
  660.     removeWeakSpecies()
  661.     local sum = totalAverageFitness()
  662.     local children = {}
  663.     for s = 1,#pool.species do
  664.         local species = pool.species[s]
  665.         breed = math.floor(species.averageFitness / sum * Population) - 1
  666.         for i=1,breed do
  667.             table.insert(children, breedChild(species))
  668.         end
  669.     end
  670.     cullSpecies(true) -- Cull all but the top member of each species
  -- NOTE(review): garbled line (see above).
  671.     while #children + #pool.species <Population> #pool.species[pool.currentSpecies].genomes then
  -- Surviving tail of the "advance to next genome" routine. When the current
  -- species runs out of genomes, move to genome 1 of the next species. After
  -- the last species, breed a new generation and restart at species 1.
  672.         pool.currentGenome = 1
  673.         pool.currentSpecies = pool.currentSpecies+1
  674.         if pool.currentSpecies > #pool.species then
  675.             newGeneration()
  676.             pool.currentSpecies = 1
  677.         end
  678.     end
  679. end
  680.  
  -- True when the genome selected by pool.currentSpecies / pool.currentGenome
  -- already has a recorded fitness (any non-zero value).
  function fitnessAlreadyMeasured()
      local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
      return current.fitness ~= 0
  end
  687.  
  -- displayGenome(genome): draws genome.network on the BizHawk GUI overlay.
  --   * Input neurons are drawn as a (2*BoxRadius+1)^2 grid with 5-pixel
  --     spacing, centred on (50, 70).
  --   * The bias input (id Inputs) is drawn at (80, 110).
  --   * Outputs are drawn at x = 220, 8 pixels apart. Each button name is
  --     printed beside its output in colour 0xFF0000FF when the output value is
  --     positive, otherwise 0xFF000000.
  --   * When the "Show M-Rates" checkbox is ticked, the genome's mutation
  --     rates are listed from (100, 100) downwards.
  -- NOTE(review): lines 726, 729, 735, 738 and 752 of this copy are garbled
  -- (text between `<` and `>` was stripped by the paste site). The
  -- hidden-neuron layout, gene-drawing loop and part of the colour logic are
  -- missing, and `c1`, `c2`, `gene` and `opacity` are used below without a
  -- visible definition. Restore from the upstream source before running.
  688. function displayGenome(genome)
  689.     local network = genome.network
  690.     local cells = {}
  691.     local i = 1
  692.     local cell = {}
  693.     for dy=-BoxRadius,BoxRadius do
  694.         for dx=-BoxRadius,BoxRadius do
  695.             cell = {}
  696.             cell.x = 50+5*dx
  697.             cell.y = 70+5*dy
  698.             cell.value = network.neurons[i].value
  699.             cells[i] = cell
  700.             i = i + 1
  701.         end
  702.     end
  703.     local biasCell = {}
  704.     biasCell.x = 80
  705.     biasCell.y = 110
  706.     biasCell.value = network.neurons[Inputs].value
  707.     cells[Inputs] = biasCell
  708.    
  709.     for o = 1,Outputs do
  710.         cell = {}
  711.         cell.x = 220
  712.         cell.y = 30 + 8 * o
  713.         cell.value = network.neurons[MaxNodes + o].value
  714.         cells[MaxNodes+o] = cell
  715.         local color
  716.         if cell.value > 0 then
  717.             color = 0xFF0000FF
  718.         else
  719.             color = 0xFF000000
  720.         end
  721.         gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
  722.     end
  723.    
  724.     for n,neuron in pairs(network.neurons) do
  725.         cell = {}
  -- NOTE(review): garbled line (hidden-neuron cell setup and start of the
  -- layout pass are missing).
  726.         if n > Inputs and n <MaxNodes> Inputs and gene.into <MaxNodes>= c2.x then
  727.                         c1.x = c1.x - 40
  728.                     end
  -- NOTE(review): garbled line (presumably clamps c1.x into [90, 220]).
  729.                     if c1.x <90> 220 then
  730.                         c1.x = 220
  731.                     end
  732.                     c1.y = 0.75*c1.y + 0.25*c2.y
  733.                    
  734.                 end
  -- NOTE(review): garbled line.
  735.                 if gene.out > Inputs and gene.out <MaxNodes>= c2.x then
  736.                         c2.x = c2.x + 40
  737.                     end
  -- NOTE(review): garbled line (presumably clamps c2.x into [90, 220]).
  738.                     if c2.x <90> 220 then
  739.                         c2.x = 220
  740.                     end
  741.                     c2.y = 0.25*c1.y + 0.75*c2.y
  742.                 end
  743.             end
  744.         end
  745.     end
  746.    
  -- Frame behind the input grid.
  747.     gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
  748.     for n,cell in pairs(cells) do
  749.         if n > Inputs or cell.value ~= 0 then
  -- Map the cell value from [-1, 1] onto a 0..255 shade, capped at 255.
  750.             local color = math.floor((cell.value+1)/2*256)
  751.             if color > 255 then color = 255 end
  -- NOTE(review): garbled line. Cell drawing and the start of the gene/link
  -- drawing loop are missing.
  752.             if color <0> 0 then
  753.                 color = opacity + 0x8000 + 0x10000*color
  754.             else
  755.                 color = opacity + 0x800000 + 0x100*color
  756.             end
  757.             gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
  758.         end
  759.     end
  760.    
  -- Small fixed box inside the grid, presumably marking Mario's cell.
  761.     gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)
  762.    
  763.     if forms.ischecked(showMutationRates) then
  764.         local pos = 100
  765.         for mutation,rate in pairs(genome.mutationRates) do
  766.             gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
  767.             pos = pos + 8
  768.         end
  769.     end
  770. end
  771.  
  -- Count one more savestate load. Registered with event.onloadstate; the
  -- total is shown as "Restarts:" in the banner.
  function restartCount()
      restart = 1 + restart
  end
  775.  
  -- "Save" button handler: write the pool to the file named in the form's
  -- Save/Load text box.
  function savePool()
      writeFile(forms.gettext(saveLoadFile))
  end
  780.  
  -- Serialize the global `pool` to `filename` in the plain-text format that
  -- loadFile() reads back:
  --   generation, maxFitness, species count, then for each species:
  --     topFitness, staleness, genome count, then for each genome:
  --       fitness, maxneuron, mutation rates as name/value line pairs ended by
  --       "done", gene count, then one "into out weight innovation 0|1" line
  --       per gene.
  -- Raises an error with the OS message if the file cannot be opened, instead
  -- of failing on a nil file handle.
  function writeFile(filename)
      local file, err = io.open(filename, "w")
      if not file then
          error("MarI/O: cannot open pool file '" .. tostring(filename) .. "' for writing: " .. tostring(err))
      end

      file:write(pool.generation .. "\n")
      file:write(pool.maxFitness .. "\n")
      file:write(#pool.species .. "\n")
      -- ipairs: the counts written here are sequence lengths, so the entries
      -- that follow must be exactly those sequences, in index order.
      for _, species in ipairs(pool.species) do
          file:write(species.topFitness .. "\n")
          file:write(species.staleness .. "\n")
          file:write(#species.genomes .. "\n")
          for _, genome in ipairs(species.genomes) do
              file:write(genome.fitness .. "\n")
              file:write(genome.maxneuron .. "\n")
              for mutation, rate in pairs(genome.mutationRates) do
                  file:write(mutation .. "\n")
                  file:write(rate .. "\n")
              end
              file:write("done\n")

              file:write(#genome.genes .. "\n")
              for _, gene in ipairs(genome.genes) do
                  local enabledFlag = gene.enabled and "1" or "0"
                  file:write(gene.into .. " " .. gene.out .. " " .. gene.weight .. " " .. gene.innovation .. " " .. enabledFlag .. "\n")
              end
          end
      end
      file:close()
  end
  815.  
  -- "Load" button handler: load the pool from the file named in the form's
  -- Save/Load text box.
  function loadPool()
      loadFile(forms.gettext(saveLoadFile))
  end
  820.  
  -- Load a pool written by writeFile() from `filename` into the global `pool`.
  -- It then refreshes the max-fitness label, skips to the first genome with
  -- no measured fitness, and starts a run with it.
  -- Raises an error if the file cannot be opened or ends inside a genome's
  -- mutation-rate section.
  function loadFile(filename)
      local file, err = io.open(filename, "r")
      if not file then
          error("MarI/O: cannot open pool file '" .. tostring(filename) .. "': " .. tostring(err))
      end

      -- Read one whole line with surrounding whitespace stripped (this also
      -- tolerates "\r\n" endings). Errors at end of file.
      local function readTrimmedLine()
          local line = file:read("*line")
          if line == nil then
              file:close()
              error("MarI/O: unexpected end of pool file '" .. tostring(filename) .. "'")
          end
          return (line:match("^%s*(.-)%s*$"))
      end

      pool = newPool()
      pool.generation = file:read("*number")
      pool.maxFitness = file:read("*number")
      forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
      local numSpecies = file:read("*number")
      for s = 1, numSpecies do
          local species = newSpecies()
          table.insert(pool.species, species)
          species.topFitness = file:read("*number")
          species.staleness = file:read("*number")
          local numGenomes = file:read("*number")
          for g = 1, numGenomes do
              local genome = newGenome()
              table.insert(species.genomes, genome)
              genome.fitness = file:read("*number")
              genome.maxneuron = file:read("*number")
              -- Finish the maxneuron line, then read the mutation rates as
              -- explicit name/value line pairs up to the "done" marker. The
              -- previous loop relied on read("*number") failing harmlessly on
              -- every name line, which depends on the C number scanner (C99
              -- scanf may swallow the "n" of "node" while looking for "nan").
              file:read("*line")
              local name = readTrimmedLine()
              while name ~= "done" do
                  if name ~= "" then
                      genome.mutationRates[name] = tonumber(readTrimmedLine())
                  end
                  name = readTrimmedLine()
              end
              local numGenes = file:read("*number")
              for n = 1, numGenes do
                  local gene = newGene()
                  table.insert(genome.genes, gene)
                  local enabled
                  gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
                  gene.enabled = (enabled ~= 0)
              end
          end
      end
      file:close()

      while fitnessAlreadyMeasured() do
          nextGenome()
      end
      initializeRun()
      pool.currentFrame = pool.currentFrame + 1
  end
  868.  
  -- "Play Top" button handler: select the genome with the highest fitness
  -- strictly above 0 (ties go to the first one found in pairs() order), record
  -- it as pool.maxFitness, and restart the run with it.
  -- If no genome has a positive fitness yet there is nothing to play. The old
  -- code then set currentSpecies/currentGenome to nil, which crashes the main
  -- loop; now a message is logged and the current selection is kept.
  function playTop()
      local maxfitness = 0
      local maxs, maxg
      for s, species in pairs(pool.species) do
          for g, genome in pairs(species.genomes) do
              if genome.fitness > maxfitness then
                  maxfitness = genome.fitness
                  maxs = s
                  maxg = g
              end
          end
      end

      if maxs == nil then
          console.writeline("Play Top: no genome with a positive fitness yet.")
          return
      end

      pool.currentSpecies = maxs
      pool.currentGenome = maxg
      pool.maxFitness = maxfitness
      forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
      initializeRun()
      pool.currentFrame = pool.currentFrame + 1
  end
  890.  
  -- Script-exit handler (registered with event.onexit): close the control form.
  function onExit()
      local controlForm = form
      forms.destroy(controlForm)
  end
  894.  
  -- Startup: snapshot the current pool to "temp1.pool".
  -- NOTE(review): this reads the global `pool`, but nothing visible in this
  -- copy creates it beforehand. Presumably the lost code in the garbled region
  -- around line 671 initializes it (e.g. via initializePool) -- verify.
  895. writeFile("temp1.pool")
  896.  
  -- Close the form when the script exits, and count savestate loads.
  897. event.onexit(onExit)
  898. event.onloadstate(restartCount)
  899.  
  -- Control window: max-fitness label, display toggles, Restart / Save / Load /
  -- Play Top buttons, the Save/Load file name box (defaulting to
  -- Filename .. ".pool"), and the banner toggle. Other functions read these
  -- widget handles through the globals assigned here.
  900. form = forms.newform(200, 260, "Fitness")
  901. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  902. showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  903. showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  904. restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  905. saveButton = forms.button(form, "Save", savePool, 5, 102)
  906. loadButton = forms.button(form, "Load", loadPool, 80, 102)
  907. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  908. saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  909. playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  910. hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  911.  
  912.  
  -- Main loop: one iteration per emulated frame.
  --   * Draw the translucent banner (unless "Hide Banner" is ticked) and, if
  --     "Show Map" is ticked, the current genome's network.
  --   * Every 5th frame run evaluateCurrent(); `controller` is then applied
  --     to the joypad on every frame.
  --   * Track progress: a new rightmost X resets `timeout` to TimeoutConstant,
  --     and `timeout` drops by 1 each frame.
  --   * When the run ends, score the genome, record new maxima, and move on
  --     to the next unmeasured genome.
  --   * Draw the banner statistics, then advance the emulator one frame.
  913. while true do
  914.     local backgroundColor = 0xD0FFFFFF
  915.     if not forms.ischecked(hideBanner) then
  916.         gui.drawBox(0, 0, 300, 30, backgroundColor, backgroundColor)
  917.     end
  918.  
  919.     local species = pool.species[pool.currentSpecies]
  920.     local genome = species.genomes[pool.currentGenome]
  921.    
  922.     if forms.ischecked(showNetwork) then
  923.         displayGenome(genome)
  924.     end
  925.    
  926.     if pool.currentFrame%5 == 0 then
  927.         evaluateCurrent()
  928.     end
  929.  
  930.     joypad.set(controller)
  931.  
  932.     getPositions()
  933.     if marioX > rightmost then
  934.         rightmost = marioX
  935.         timeout = TimeoutConstant
  936.     end
  937.    
  938.     timeout = timeout - 1
  939.    
  940.    
  -- The allowance grows with run length: currentFrame / 4 frames of slack.
  941.     local timeoutBonus = pool.currentFrame / 4
  -- NOTE(review): garbled line. Text between `<` and `>` was stripped by the
  -- paste site, so the end-of-run test (presumably timeout + timeoutBonus <= 0),
  -- the base fitness formula and the SMW completion check are missing.
  942.     if timeout + timeoutBonus <0> 4816 then
  943.             fitness = fitness + 1000
  944.         end
  -- SMB: +1000 bonus once the rightmost X passes 3186.
  945.         if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
  946.             fitness = fitness + 1000
  947.         end
  -- 0 is the "not measured yet" marker (see fitnessAlreadyMeasured), so a
  -- genuine score of 0 is stored as -1.
  948.         if fitness == 0 then
  949.             fitness = -1
  950.         end
  951.         genome.fitness = fitness
  952.        
  -- New overall best: update the label and write a per-generation backup.
  953.         if fitness > pool.maxFitness then
  954.             pool.maxFitness = fitness
  955.             forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  956.             writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  957.         end
  958.        
  -- Log the result, then rescan from the start for the next unmeasured genome.
  959.         console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  960.         pool.currentSpecies = 1
  961.         pool.currentGenome = 1
  962.         while fitnessAlreadyMeasured() do
  963.             nextGenome()
  964.         end
  965.         initializeRun()
  966.     end
  967.  
  -- Progress shown in the banner: the share of genomes with a measured
  -- (non-zero) fitness.
  968.     local measured = 0
  969.     local total = 0
  970.     for _,species in pairs(pool.species) do
  971.         for _,genome in pairs(species.genomes) do
  972.             total = total + 1
  973.             if genome.fitness ~= 0 then
  974.                 measured = measured + 1
  975.             end
  976.         end
  977.     end
  978.     if not forms.ischecked(hideBanner) then
  979.         gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 10)
  980.         gui.drawText(0, 10, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000, 10)
  981.         gui.drawText(100, 10, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 10)
  982.         gui.drawText(0, 20, "Restarts: " .. restart, 0xFF000000, 10)
  983.     end
  984.        
  985.     pool.currentFrame = pool.currentFrame + 1
  986.  
  987.     emu.frameadvance();
  988. end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement