Advertisement
Guest User

Untitled

a guest
Jun 15th, 2015
417
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 23.61 KB | None | 0 0
  1. -- MarI/O by SethBling
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  4.  
  5. if gameinfo.getromname() == "Super Mario World (USA)" then
  6.     Filename = "YI1.state"
  7.     ButtonNames = {
  8.         "A",
  9.         "B",
  10.         "X",
  11.         "Y",
  12.         "Up",
  13.         "Down",
  14.         "Left",
  15.         "Right",
  16.     }
  17. elseif gameinfo.getromname() == "Super Mario Bros." then
  18.     Filename = "SMB1-1.state"
  19.     ButtonNames = {
  20.         "A",
  21.         "B",
  22.         "Up",
  23.         "Down",
  24.         "Left",
  25.         "Right",
  26.     }
  27. end
  28.  
  29. BoxRadius = 6
  30. InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  31.  
  32. Inputs = InputSize+1
  33. Outputs = #ButtonNames
  34.  
  35. Population = 300
  36. DeltaDisjoint = 2.0
  37. DeltaWeights = 0.4
  38. DeltaThreshold = 1.0
  39.  
  40. StaleSpecies = 15
  41.  
  42. MutateConnectionsChance = 0.25
  43. PerturbChance = 0.90
  44. CrossoverChance = 0.75
  45. LinkMutationChance = 2.0
  46. NodeMutationChance = 0.50
  47. BiasMutationChance = 0.40
  48. StepSize = 0.1
  49. DisableMutationChance = 0.4
  50. EnableMutationChance = 0.2
  51.  
  52. TimeoutConstant = 20
  53.  
  54. MaxNodes = 1000000
  55.  
  56. restart = 0
  57.  
  58. function countLoad()
  59.     restart = restart + 1
  60.     forms.settext(restartLabel, "Restarts: " .. restart)
  61. end
  62.  
  63. function getPositions()
  64.     if gameinfo.getromname() == "Super Mario World (USA)" then
  65.         marioX = memory.read_s16_le(0x94)
  66.         marioY = memory.read_s16_le(0x96)
  67.        
  68.         local layer1x = memory.read_s16_le(0x1A);
  69.         local layer1y = memory.read_s16_le(0x1C);
  70.        
  71.         screenX = marioX-layer1x
  72.         screenY = marioY-layer1y
  73.     elseif gameinfo.getromname() == "Super Mario Bros." then
  74.         marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
  75.         marioY = memory.readbyte(0x03B8)+16
  76.    
  77.         screenX = memory.readbyte(0x03AD)
  78.         screenY = memory.readbyte(0x03B8)
  79.     end
  80. end
  81.  
  82. function getTile(dx, dy)
  83.     if gameinfo.getromname() == "Super Mario World (USA)" then
  84.         x = math.floor((marioX+dx+8)/16)
  85.         y = math.floor((marioY+dy)/16)
  86.        
  87.         return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  88.     elseif gameinfo.getromname() == "Super Mario Bros." then
  89.         local x = marioX + dx + 8
  90.         local y = marioY + dy - 16
  91.         local page = math.floor(x/256)%2
  92.  
  93.         local subx = math.floor((x%256)/16)
  94.         local suby = math.floor((y - 32)/16)
  95.         local addr = 0x500 + page*13*16+suby*16+subx
  96.        
  97.         if suby >= 13 or suby < 0 then
  98.             return 0
  99.         end
  100.        
  101.         if memory.readbyte(addr) ~= 0 then
  102.             return 1
  103.         else
  104.             return 0
  105.         end
  106.     end
  107. end
  108.  
  109. function getSprites()
  110.     if gameinfo.getromname() == "Super Mario World (USA)" then
  111.         local sprites = {}
  112.         for slot=0,11 do
  113.             local status = memory.readbyte(0x14C8+slot)
  114.             if status ~= 0 then
  115.                 spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
  116.                 spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
  117.                 sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
  118.             end
  119.         end    
  120.        
  121.         return sprites
  122.     elseif gameinfo.getromname() == "Super Mario Bros." then
  123.         local sprites = {}
  124.         for slot=0,4 do
  125.             local enemy = memory.readbyte(0xF+slot)
  126.             if enemy ~= 0 then
  127.                 local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
  128.                 local ey = memory.readbyte(0xCF + slot)+24
  129.                 sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
  130.             end
  131.         end
  132.        
  133.         return sprites
  134.     end
  135. end
  136.  
  137. function getExtendedSprites()
  138.     if gameinfo.getromname() == "Super Mario World (USA)" then
  139.         local extended = {}
  140.         for slot=0,11 do
  141.             local number = memory.readbyte(0x170B+slot)
  142.             if number ~= 0 then
  143.                 spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
  144.                 spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
  145.                 extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
  146.             end
  147.         end    
  148.        
  149.         return extended
  150.     elseif gameinfo.getromname() == "Super Mario Bros." then
  151.         return {}
  152.     end
  153. end
  154.  
  155. function getInputs()
  156.     getPositions()
  157.    
  158.     sprites = getSprites()
  159.     extended = getExtendedSprites()
  160.    
  161.     local inputs = {}
  162.    
  163.     for dy=-BoxRadius*16,BoxRadius*16,16 do
  164.         for dx=-BoxRadius*16,BoxRadius*16,16 do
  165.             inputs[#inputs+1] = 0
  166.            
  167.             tile = getTile(dx, dy)
  168.             if tile == 1 and marioY+dy < 0x1B0 then
  169.                 inputs[#inputs] = 1
  170.             end
  171.            
  172.             for i = 1,#sprites do
  173.                 distx = math.abs(sprites[i]["x"] - (marioX+dx))
  174.                 disty = math.abs(sprites[i]["y"] - (marioY+dy))
  175.                 if distx <= 8 and disty <= 8 then
  176.                     inputs[#inputs] = -1
  177.                 end
  178.             end
  179.  
  180.             for i = 1,#extended do
  181.                 distx = math.abs(extended[i]["x"] - (marioX+dx))
  182.                 disty = math.abs(extended[i]["y"] - (marioY+dy))
  183.                 if distx < 8 and disty < 8 then
  184.                     inputs[#inputs] = -1
  185.                 end
  186.             end
  187.         end
  188.     end
  189.    
  190.     --mariovx = memory.read_s8(0x7B)
  191.     --mariovy = memory.read_s8(0x7D)
  192.    
  193.     return inputs
  194. end
  195.  
  196. function sigmoid(x)
  197.     return 2/(1+math.exp(-4.9*x))-1
  198. end
  199.  
  200. function newInnovation()
  201.     pool.innovation = pool.innovation + 1
  202.     return pool.innovation
  203. end
  204.  
  205. function newPool()
  206.     local pool = {}
  207.     pool.species = {}
  208.     pool.generation = 0
  209.     pool.innovation = Outputs
  210.     pool.currentSpecies = 1
  211.     pool.currentGenome = 1
  212.     pool.currentFrame = 0
  213.     pool.maxFitness = 0
  214.    
  215.     return pool
  216. end
  217.  
  218. function newSpecies()
  219.     local species = {}
  220.     species.topFitness = 0
  221.     species.staleness = 0
  222.     species.genomes = {}
  223.     species.averageFitness = 0
  224.    
  225.     return species
  226. end
  227.  
  228. function newGenome()
  229.     local genome = {}
  230.     genome.genes = {}
  231.     genome.fitness = 0
  232.     genome.adjustedFitness = 0
  233.     genome.network = {}
  234.     genome.maxneuron = 0
  235.     genome.globalRank = 0
  236.     genome.mutationRates = {}
  237.     genome.mutationRates["connections"] = MutateConnectionsChance
  238.     genome.mutationRates["link"] = LinkMutationChance
  239.     genome.mutationRates["bias"] = BiasMutationChance
  240.     genome.mutationRates["node"] = NodeMutationChance
  241.     genome.mutationRates["enable"] = EnableMutationChance
  242.     genome.mutationRates["disable"] = DisableMutationChance
  243.     genome.mutationRates["step"] = StepSize
  244.    
  245.     return genome
  246. end
  247.  
  248. function copyGenome(genome)
  249.     local genome2 = newGenome()
  250.     for g=1,#genome.genes do
  251.         table.insert(genome2.genes, copyGene(genome.genes[g]))
  252.     end
  253.     genome2.maxneuron = genome.maxneuron
  254.     genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  255.     genome2.mutationRates["link"] = genome.mutationRates["link"]
  256.     genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  257.     genome2.mutationRates["node"] = genome.mutationRates["node"]
  258.     genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  259.     genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  260.    
  261.     return genome2
  262. end
  263.  
  264. function basicGenome()
  265.     local genome = newGenome()
  266.     local innovation = 1
  267.  
  268.     genome.maxneuron = Inputs
  269.     mutate(genome)
  270.    
  271.     return genome
  272. end
  273.  
  274. function newGene()
  275.     local gene = {}
  276.     gene.into = 0
  277.     gene.out = 0
  278.     gene.weight = 0.0
  279.     gene.enabled = true
  280.     gene.innovation = 0
  281.    
  282.     return gene
  283. end
  284.  
  285. function copyGene(gene)
  286.     local gene2 = newGene()
  287.     gene2.into = gene.into
  288.     gene2.out = gene.out
  289.     gene2.weight = gene.weight
  290.     gene2.enabled = gene.enabled
  291.     gene2.innovation = gene.innovation
  292.    
  293.     return gene2
  294. end
  295.  
  296. function newNeuron()
  297.     local neuron = {}
  298.     neuron.incoming = {}
  299.     neuron.value = 0.0
  300.    
  301.     return neuron
  302. end
  303.  
-- NOTE(review): this block is damaged in the paste. On the sort-comparator
-- line, the text between `<b>` and `0 then` appears to have been stripped as
-- if it were an HTML tag. As a result, the rest of generateNetwork and the
-- start of the next function (presumably evaluateNetwork) are missing. As
-- shown, this region is not valid Lua; restore it from the original MarI/O
-- source before running.
--
-- generateNetwork(genome), visible part: creates a neuron for every input id
-- 1..Inputs and for every output id MaxNodes+1..MaxNodes+Outputs, then
-- appears to sort genome.genes by their `out` neuron id (comparator truncated).
--
-- Visible tail of the following function: a neuron's value is set to
-- sigmoid(sum) inside a truncated `... 0 then` branch. The function then
-- returns a table mapping "P1 <button name>" to true when that button's
-- output neuron value is > 0, and false otherwise.
  304. function generateNetwork(genome)
  305.     local network = {}
  306.     network.neurons = {}
  307.    
  308.     for i=1,Inputs do
  309.         network.neurons[i] = newNeuron()
  310.     end
  311.    
  312.     for o=1,Outputs do
  313.         network.neurons[MaxNodes+o] = newNeuron()
  314.     end
  315.    
  316.     table.sort(genome.genes, function (a,b)
  317.         return (a.out <b> 0 then -- NOTE(review): corrupted line, see the note above
  318.             neuron.value = sigmoid(sum)
  319.         end
  320.     end
  321.    
  322.     local outputs = {}
  323.     for o=1,Outputs do
  324.         local button = "P1 " .. ButtonNames[o]
  325.         if network.neurons[MaxNodes+o].value > 0 then
  326.             outputs[button] = true
  327.         else
  328.             outputs[button] = false
  329.         end
  330.     end
  331.    
  332.     return outputs
  333. end
  334.  
  335. function crossover(g1, g2)
  336.     -- Make sure g1 is the higher fitness genome
  337.     if g2.fitness > g1.fitness then
  338.         tempg = g1
  339.         g1 = g2
  340.         g2 = tempg
  341.     end
  342.  
  343.     local child = newGenome()
  344.    
  345.     local innovations2 = {}
  346.     for i=1,#g2.genes do
  347.         local gene = g2.genes[i]
  348.         innovations2[gene.innovation] = gene
  349.     end
  350.    
  351.     for i=1,#g1.genes do
  352.         local gene1 = g1.genes[i]
  353.         local gene2 = innovations2[gene1.innovation]
  354.         if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  355.             table.insert(child.genes, copyGene(gene2))
  356.         else
  357.             table.insert(child.genes, copyGene(gene1))
  358.         end
  359.     end
  360.    
  361.     child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  362.    
  363.     for mutation,rate in pairs(g1.mutationRates) do
  364.         child.mutationRates[mutation] = rate
  365.     end
  366.    
  367.     return child
  368. end
  369.  
  370. function randomNeuron(genes, nonInput)
  371.     local neurons = {}
  372.     if not nonInput then
  373.         for i=1,Inputs do
  374.             neurons[i] = true
  375.         end
  376.     end
  377.     for o=1,Outputs do
  378.         neurons[MaxNodes+o] = true
  379.     end
  380.     for i=1,#genes do
  381.         if (not nonInput) or genes[i].into > Inputs then
  382.             neurons[genes[i].into] = true
  383.         end
  384.         if (not nonInput) or genes[i].out > Inputs then
  385.             neurons[genes[i].out] = true
  386.         end
  387.     end
  388.  
  389.     local count = 0
  390.     for _,_ in pairs(neurons) do
  391.         count = count + 1
  392.     end
  393.     local n = math.random(1, count)
  394.    
  395.     for k,v in pairs(neurons) do
  396.         n = n-1
  397.         if n == 0 then
  398.             return k
  399.         end
  400.     end
  401.    
  402.     return 0
  403. end
  404.  
  405. function containsLink(genes, link)
  406.     for i=1,#genes do
  407.         local gene = genes[i]
  408.         if gene.into == link.into and gene.out == link.out then
  409.             return true
  410.         end
  411.     end
  412. end
  413.  
  414. function pointMutate(genome)
  415.     local step = genome.mutationRates["step"]
  416.    
  417.     for i=1,#genome.genes do
  418.         local gene = genome.genes[i]
  419.         if math.random() < PerturbChance then
  420.             gene.weight = gene.weight + math.random() * step*2 - step
  421.         else
  422.             gene.weight = math.random()*4-2
  423.         end
  424.     end
  425. end
  426.  
  427. function linkMutate(genome, forceBias)
  428.     local neuron1 = randomNeuron(genome.genes, false)
  429.     local neuron2 = randomNeuron(genome.genes, true)
  430.      
  431.     local newLink = newGene()
  432.     if neuron1 <= Inputs and neuron2 <= Inputs then
  433.         --Both input nodes
  434.         return
  435.     end
  436.     if neuron2 <= Inputs then
  437.         -- Swap output and input
  438.         local temp = neuron1
  439.         neuron1 = neuron2
  440.         neuron2 = temp
  441.     end
  442.  
  443.     newLink.into = neuron1
  444.     newLink.out = neuron2
  445.     if forceBias then
  446.         newLink.into = Inputs
  447.     end
  448.    
  449.     if containsLink(genome.genes, newLink) then
  450.         return
  451.     end
  452.     newLink.innovation = newInnovation()
  453.     newLink.weight = math.random()*4-2
  454.    
  455.     table.insert(genome.genes, newLink)
  456. end
  457.  
  458. function nodeMutate(genome)
  459.     if #genome.genes == 0 then
  460.         return
  461.     end
  462.  
  463.     genome.maxneuron = genome.maxneuron + 1
  464.  
  465.     local gene = genome.genes[math.random(1,#genome.genes)]
  466.     if not gene.enabled then
  467.         return
  468.     end
  469.     gene.enabled = false
  470.    
  471.     local gene1 = copyGene(gene)
  472.     gene1.out = genome.maxneuron
  473.     gene1.weight = 1.0
  474.     gene1.innovation = newInnovation()
  475.     gene1.enabled = true
  476.     table.insert(genome.genes, gene1)
  477.    
  478.     local gene2 = copyGene(gene)
  479.     gene2.into = genome.maxneuron
  480.     gene2.innovation = newInnovation()
  481.     gene2.enabled = true
  482.     table.insert(genome.genes, gene2)
  483. end
  484.  
  485. function enableDisableMutate(genome, enable)
  486.     local candidates = {}
  487.     for _,gene in pairs(genome.genes) do
  488.         if gene.enabled == not enable then
  489.             table.insert(candidates, gene)
  490.         end
  491.     end
  492.    
  493.     if #candidates == 0 then
  494.         return
  495.     end
  496.    
  497.     local gene = candidates[math.random(1,#candidates)]
  498.     gene.enabled = not gene.enabled
  499. end
  500.  
-- NOTE(review): damaged in the paste. Starting at
-- `if math.random() <genome> 0 do`, the text between each `<` and `>` was
-- lost, so most of the mutation dispatch is missing. As shown, this is not
-- valid Lua; restore it from the original source before running.
--
-- mutate(genome), visible part:
--   * Each entry of genome.mutationRates is multiplied by 0.95 or by
--     1.05263, chosen by a fair coin. 1.05263 is approximately 1/0.95, so
--     the rates drift up or down over time without a systematic bias.
--   * The surviving tail is a countdown loop over `p`. Each pass performs
--     enableDisableMutate(genome, false), which turns one enabled gene off,
--     with probability p, then decrements p by 1. A rate above 1 therefore
--     yields several attempts.
--     NOTE(review): the loop header and p's initial value are lost;
--     `while p > 0 do` is assumed from the surviving `... 0 do` fragments.
  501. function mutate(genome)
  502.     for mutation,rate in pairs(genome.mutationRates) do
  503.         if math.random(1,2) == 1 then
  504.             genome.mutationRates[mutation] = 0.95*rate
  505.         else
  506.             genome.mutationRates[mutation] = 1.05263*rate
  507.         end
  508.     end
  509.  
  510.     if math.random() <genome> 0 do
  511.         if math.random() <p> 0 do
  512.         if math.random() <p> 0 do
  513.         if math.random() <p> 0 do
  514.         if math.random() <p> 0 do
  515.         if math.random() < p then
  516.             enableDisableMutate(genome, false)
  517.         end
  518.         p = p - 1
  519.     end
  520. end
  521.  
  522. function disjoint(genes1, genes2)
  523.     local i1 = {}
  524.     for i = 1,#genes1 do
  525.         local gene = genes1[i]
  526.         i1[gene.innovation] = true
  527.     end
  528.  
  529.     local i2 = {}
  530.     for i = 1,#genes2 do
  531.         local gene = genes2[i]
  532.         i2[gene.innovation] = true
  533.     end
  534.    
  535.     local disjointGenes = 0
  536.     for i = 1,#genes1 do
  537.         local gene = genes1[i]
  538.         if not i2[gene.innovation] then
  539.             disjointGenes = disjointGenes+1
  540.         end
  541.     end
  542.    
  543.     for i = 1,#genes2 do
  544.         local gene = genes2[i]
  545.         if not i1[gene.innovation] then
  546.             disjointGenes = disjointGenes+1
  547.         end
  548.     end
  549.    
  550.     local n = math.max(#genes1, #genes2)
  551.    
  552.     return disjointGenes / n
  553. end
  554.  
  555. function weights(genes1, genes2)
  556.     local i2 = {}
  557.     for i = 1,#genes2 do
  558.         local gene = genes2[i]
  559.         i2[gene.innovation] = gene
  560.     end
  561.  
  562.     local sum = 0
  563.     local coincident = 0
  564.     for i = 1,#genes1 do
  565.         local gene = genes1[i]
  566.         if i2[gene.innovation] ~= nil then
  567.             local gene2 = i2[gene.innovation]
  568.             sum = sum + math.abs(gene.weight - gene2.weight)
  569.             coincident = coincident + 1
  570.         end
  571.     end
  572.    
  573.     return sum / coincident
  574. end
  575.    
  576. function sameSpecies(genome1, genome2)
  577.     local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  578.     local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  579.     return dd + dw < DeltaThreshold
  580. end
  581.  
-- NOTE(review): damaged in the paste. Starting at the sort comparator
-- `return (a.fitness <b> b.fitness)`, the text between `<` and `>` was lost.
-- This merged the end of rankGlobally with the middle of the next function,
-- presumably cullSpecies: newGeneration calls cullSpecies(false/true), and
-- the tail below uses a `cutToOne` flag. As shown, this is not valid Lua;
-- restore it from the original source before running.
--
-- rankGlobally(), visible part: gathers every genome of every species in
-- pool.species into one list and sorts it by fitness. The comparator and
-- any rank assignment are missing.
--
-- Visible tail of the culling function: keeps the first
-- ceil(#species.genomes / 2) genomes of a species, or only the first one when
-- cutToOne is set, by removing entries from the end of the list. Whether
-- these are the fittest depends on a sort that is not visible here.
  582. function rankGlobally()
  583.     local global = {}
  584.     for s = 1,#pool.species do
  585.         local species = pool.species[s]
  586.         for g = 1,#species.genomes do
  587.             table.insert(global, species.genomes[g])
  588.         end
  589.     end
  590.     table.sort(global, function (a,b)
  591.         return (a.fitness <b> b.fitness)
  592.         end)
  593.        
  594.         local remaining = math.ceil(#species.genomes/2)
  595.         if cutToOne then
  596.             remaining = 1
  597.         end
  598.         while #species.genomes > remaining do
  599.             table.remove(species.genomes)
  600.         end
  601.     end
  602. end
  603.  
-- NOTE(review): damaged in the paste. From `if math.random() <CrossoverChance>`
-- onward, the text between `<` and `>` was lost. This merged breedChild with
-- a later function whose tail matches the stale-species pruning step
-- (presumably removeStaleSpecies, which newGeneration calls). As shown, this
-- is not valid Lua; restore it from the original source before running.
--
-- breedChild(species), visible part: a crossover branch is taken with
-- probability CrossoverChance; the rest is missing.
--
-- Visible tail of the stale-species function:
--   * If a species' first genome beats species.topFitness, topFitness is
--     updated and staleness reset to 0; otherwise staleness is incremented.
--   * Species that pass a check involving StaleSpecies and pool.maxFitness
--     (truncated) are collected into `survived`, which replaces pool.species.
  604. function breedChild(species)
  605.     local child = {}
  606.     if math.random() <CrossoverChance> b.fitness)
  607.         end)
  608.        
  609.         if species.genomes[1].fitness > species.topFitness then
  610.             species.topFitness = species.genomes[1].fitness
  611.             species.staleness = 0
  612.         else
  613.             species.staleness = species.staleness + 1
  614.         end
  615.         if species.staleness <StaleSpecies>= pool.maxFitness then
  616.             table.insert(survived, species)
  617.         end
  618.     end
  619.  
  620.     pool.species = survived
  621. end
  622.  
  623. function removeWeakSpecies()
  624.     local survived = {}
  625.  
  626.     local sum = totalAverageFitness()
  627.     for s = 1,#pool.species do
  628.         local species = pool.species[s]
  629.         breed = math.floor(species.averageFitness / sum * Population)
  630.         if breed >= 1 then
  631.             table.insert(survived, species)
  632.         end
  633.     end
  634.  
  635.     pool.species = survived
  636. end
  637.  
  638.  
  639. function addToSpecies(child)
  640.     local foundSpecies = false
  641.     for s=1,#pool.species do
  642.         local species = pool.species[s]
  643.         if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  644.             table.insert(species.genomes, child)
  645.             foundSpecies = true
  646.         end
  647.     end
  648.    
  649.     if not foundSpecies then
  650.         local childSpecies = newSpecies()
  651.         table.insert(childSpecies.genomes, child)
  652.         table.insert(pool.species, childSpecies)
  653.     end
  654. end
  655.  
-- NOTE(review): damaged in the paste. At
-- `while #children + #pool.species <Population> ...` the text between `<`
-- and `>` was lost. Everything from the end of newGeneration up to the
-- middle of the genome-advancing function (presumably nextGenome, which
-- loadFile and the main loop call) is missing. Helpers used elsewhere in this
-- file, such as initializePool, initializeRun, evaluateCurrent,
-- calculateAverageFitness and totalAverageFitness, are defined nowhere in
-- this paste and were likely lost here or in the other damaged regions.
-- Restore from the original source before running.
--
-- newGeneration(), visible part, in order:
--   1. cullSpecies(false): cull the bottom half of each species.
--   2. rankGlobally(), removeStaleSpecies(), then rankGlobally() again.
--   3. calculateAverageFitness() for every species, then removeWeakSpecies().
--   4. Each species breeds floor(averageFitness / total * Population) - 1
--      children via breedChild. The -1 presumably leaves room for the
--      species' surviving top member.
--      NOTE(review): `breed` is assigned without `local`, creating a global.
--   5. cullSpecies(true): keep only the top member of each species.
--   6. (truncated) presumably a loop that fills the population up to
--      Population.
--
-- Visible tail of the advancing function: when pool.currentGenome runs past
-- the current species' genome list, it moves to genome 1 of the next
-- species. Past the last species, it runs newGeneration() and wraps to
-- species 1.
  656. function newGeneration()
  657.     cullSpecies(false) -- Cull the bottom half of each species
  658.     rankGlobally()
  659.     removeStaleSpecies()
  660.     rankGlobally()
  661.     for s = 1,#pool.species do
  662.         local species = pool.species[s]
  663.         calculateAverageFitness(species)
  664.     end
  665.     removeWeakSpecies()
  666.     local sum = totalAverageFitness()
  667.     local children = {}
  668.     for s = 1,#pool.species do
  669.         local species = pool.species[s]
  670.         breed = math.floor(species.averageFitness / sum * Population) - 1
  671.         for i=1,breed do
  672.             table.insert(children, breedChild(species))
  673.         end
  674.     end
  675.     cullSpecies(true) -- Cull all but the top member of each species
  676.     while #children + #pool.species <Population> #pool.species[pool.currentSpecies].genomes then
  677.         pool.currentGenome = 1
  678.         pool.currentSpecies = pool.currentSpecies+1
  679.         if pool.currentSpecies > #pool.species then
  680.             newGeneration()
  681.             pool.currentSpecies = 1
  682.         end
  683.     end
  684. end
  685.  
  686. function fitnessAlreadyMeasured()
  687.     local species = pool.species[pool.currentSpecies]
  688.     local genome = species.genomes[pool.currentGenome]
  689.    
  690.     return genome.fitness ~= 0
  691. end
  692.  
-- NOTE(review): damaged in the paste. Lines such as
-- `if n > Inputs and n <MaxNodes> Inputs and gene.into <MaxNodes>= c2.x then`
-- and `if color <0> 0 then` have lost the text between `<` and `>`. This
-- removed the hidden-neuron placement and the gene-drawing loops: `c1`, `c2`,
-- `opacity` and `gene` are used below but never defined in what remains.
-- Restore from the original source before running.
--
-- displayGenome(genome): draws genome.network as an overlay using gui.* calls.
--   * Input cells: one per grid position, 5 px apart, centred on (50, 70),
--     with values taken from network.neurons[1..] in row order.
--   * The extra input network.neurons[Inputs] is drawn at (80, 110).
--   * Outputs: one cell per button at x = 220. The button name is drawn at
--     x = 223 in color 0xFF0000FF when the output value is > 0, otherwise in
--     0xFF000000.
--   * A box outlines the input grid, and a small box is drawn at (49,71)-(51,78).
--   * When the showMutationRates checkbox is ticked, each mutation rate is
--     listed from y = 100 in 8 px steps.
  693. function displayGenome(genome)
  694.     local network = genome.network
  695.     local cells = {}
  696.     local i = 1
  697.     local cell = {}
  698.     for dy=-BoxRadius,BoxRadius do
  699.         for dx=-BoxRadius,BoxRadius do
  700.             cell = {}
  701.             cell.x = 50+5*dx
  702.             cell.y = 70+5*dy
  703.             cell.value = network.neurons[i].value
  704.             cells[i] = cell
  705.             i = i + 1
  706.         end
  707.     end
  708.     local biasCell = {}
  709.     biasCell.x = 80
  710.     biasCell.y = 110
  711.     biasCell.value = network.neurons[Inputs].value
  712.     cells[Inputs] = biasCell
  713.    
  714.     for o = 1,Outputs do
  715.         cell = {}
  716.         cell.x = 220
  717.         cell.y = 30 + 8 * o
  718.         cell.value = network.neurons[MaxNodes + o].value
  719.         cells[MaxNodes+o] = cell
  720.         local color
  721.         if cell.value > 0 then
  722.             color = 0xFF0000FF
  723.         else
  724.             color = 0xFF000000
  725.         end
  726.         gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
  727.     end
  728.    
  729.     for n,neuron in pairs(network.neurons) do
  730.         cell = {}
  731.         if n > Inputs and n <MaxNodes> Inputs and gene.into <MaxNodes>= c2.x then -- NOTE(review): corrupted line
  732.                         c1.x = c1.x - 40
  733.                     end
  734.                     if c1.x <90> 220 then
  735.                         c1.x = 220
  736.                     end
  737.                     c1.y = 0.75*c1.y + 0.25*c2.y
  738.                    
  739.                 end
  740.                 if gene.out > Inputs and gene.out <MaxNodes>= c2.x then -- NOTE(review): corrupted line
  741.                         c2.x = c2.x + 40
  742.                     end
  743.                     if c2.x <90> 220 then
  744.                         c2.x = 220
  745.                     end
  746.                     c2.y = 0.25*c1.y + 0.75*c2.y
  747.                 end
  748.             end
  749.         end
  750.     end
  751.    
  752.     gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
  753.     for n,cell in pairs(cells) do
  754.         if n > Inputs or cell.value ~= 0 then
  755.             local color = math.floor((cell.value+1)/2*256)
  756.             if color > 255 then color = 255 end
  757.             if color <0> 0 then -- NOTE(review): corrupted line
  758.                 color = opacity + 0x8000 + 0x10000*color
  759.             else
  760.                 color = opacity + 0x800000 + 0x100*color
  761.             end
  762.             gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
  763.         end
  764.     end
  765.    
  766.     gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)
  767.    
  768.     if forms.ischecked(showMutationRates) then
  769.         local pos = 100
  770.         for mutation,rate in pairs(genome.mutationRates) do
  771.             gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
  772.             pos = pos + 8
  773.         end
  774.     end
  775. end
  776.  
  777. function writeFile(filename)
  778.         local file = io.open(filename, "w")
  779.     file:write(pool.generation .. "\n")
  780.     file:write(pool.maxFitness .. "\n")
  781.     file:write(restart .. "\n")
  782.     file:write(#pool.species .. "\n")
  783.         for n,species in pairs(pool.species) do
  784.         file:write(species.topFitness .. "\n")
  785.         file:write(species.staleness .. "\n")
  786.         file:write(#species.genomes .. "\n")
  787.         for m,genome in pairs(species.genomes) do
  788.             file:write(genome.fitness .. "\n")
  789.             file:write(genome.maxneuron .. "\n")
  790.             for mutation,rate in pairs(genome.mutationRates) do
  791.                 file:write(mutation .. "\n")
  792.                 file:write(rate .. "\n")
  793.             end
  794.             file:write("done\n")
  795.            
  796.             file:write(#genome.genes .. "\n")
  797.             for l,gene in pairs(genome.genes) do
  798.                 file:write(gene.into .. " ")
  799.                 file:write(gene.out .. " ")
  800.                 file:write(gene.weight .. " ")
  801.                 file:write(gene.innovation .. " ")
  802.                 if(gene.enabled) then
  803.                     file:write("1\n")
  804.                 else
  805.                     file:write("0\n")
  806.                 end
  807.             end
  808.         end
  809.         end
  810.         file:close()
  811. end
  812.  
  813. function savePool()
  814.     local filename = forms.gettext(saveLoadFile)
  815.     writeFile(filename)
  816. end
  817.  
  818. function loadFile(filename)
  819.         local file = io.open(filename, "r")
  820.     pool = newPool()
  821.     pool.generation = file:read("*number")
  822.     pool.maxFitness = file:read("*number")
  823.     restart = file:read("*number")
  824.     forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  825.     forms.settext(restartLabel, "Restarts: " .. restart)
  826.         local numSpecies = file:read("*number")
  827.         for s=1,numSpecies do
  828.         local species = newSpecies()
  829.         table.insert(pool.species, species)
  830.         species.topFitness = file:read("*number")
  831.         species.staleness = file:read("*number")
  832.         local numGenomes = file:read("*number")
  833.         for g=1,numGenomes do
  834.             local genome = newGenome()
  835.             table.insert(species.genomes, genome)
  836.             genome.fitness = file:read("*number")
  837.             genome.maxneuron = file:read("*number")
  838.             local line = file:read("*line")
  839.             while line ~= "done" do
  840.                 genome.mutationRates[line] = file:read("*number")
  841.                 line = file:read("*line")
  842.             end
  843.             local numGenes = file:read("*number")
  844.             for n=1,numGenes do
  845.                 local gene = newGene()
  846.                 table.insert(genome.genes, gene)
  847.                 local enabled
  848.                 gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
  849.                 if enabled == 0 then
  850.                     gene.enabled = false
  851.                 else
  852.                     gene.enabled = true
  853.                 end
  854.                
  855.             end
  856.         end
  857.     end
  858.         file:close()
  859.    
  860.     while fitnessAlreadyMeasured() do
  861.         nextGenome()
  862.     end
  863.     initializeRun()
  864.     pool.currentFrame = pool.currentFrame + 1
  865. end
  866.  
  867. function loadPool()
  868.     local filename = forms.gettext(saveLoadFile)
  869.     loadFile(filename)
  870. end
  871.  
  872. function playTop()
  873.     local maxfitness = 0
  874.     local maxs, maxg
  875.     for s,species in pairs(pool.species) do
  876.         for g,genome in pairs(species.genomes) do
  877.             if genome.fitness > maxfitness then
  878.                 maxfitness = genome.fitness
  879.                 maxs = s
  880.                 maxg = g
  881.             end
  882.         end
  883.     end
  884.    
  885.     pool.currentSpecies = maxs
  886.     pool.currentGenome = maxg
  887.     pool.maxFitness = maxfitness
  888.     forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  889.     initializeRun()
  890.     pool.currentFrame = pool.currentFrame + 1
  891.     return
  892. end
  893.  
  894.  
  895. function onExit()
  896.     forms.destroy(form)
  897. end
  898.  
  899. writeFile("temp.pool")
  900.  
  901. event.onexit(onExit)
  902. event.onloadstate(countLoad)
  903.  
  904. form = forms.newform(200, 260, "Fitness")
  905. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  906. restartLabel = forms.label(form, "Restarts: " .. restart, 110, 8)
  907. showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  908. showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  909. restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  910. saveButton = forms.button(form, "Save", savePool, 5, 102)
  911. loadButton = forms.button(form, "Load", loadPool, 80, 102)
  912. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  913. saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  914. playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  915. hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  916.  
  917.  
-- Main loop: one iteration per emulated frame (it ends with
-- emu.frameadvance()).
--   * Draws the banner background unless "Hide Banner" is checked, and the
--     network overlay when "Show Map" is checked.
--   * Calls evaluateCurrent() on every 5th frame (pool.currentFrame % 5 == 0).
--     joypad.set(controller) is applied on every frame, so presumably the
--     last computed buttons are held between evaluations.
--   * Whenever marioX exceeds `rightmost`, rightmost is updated and timeout
--     is reset to TimeoutConstant. timeout drops by 1 each frame, and
--     timeoutBonus is pool.currentFrame / 4.
--
-- NOTE(review): damaged in the paste at `if timeout + timeoutBonus <0> 4816`.
-- The text between `<` and `>` was lost, so the run-end condition and the
-- start of the fitness calculation are missing. Restore from the original
-- source before running.
--
-- Visible tail of the run-end handling:
--   * +1000 fitness when rightmost > 3186 in Super Mario Bros., and another
--     +1000 for rightmost > 4816 under a truncated condition (presumably the
--     Super Mario World counterpart).
--   * A fitness of exactly 0 is stored as -1, because
--     fitnessAlreadyMeasured() treats 0 as "not yet evaluated".
--   * A new maximum updates the form label and is backed up to
--     "backup.<generation>.<Save/Load file name>".
--   * The first genome without a recorded fitness is then selected and a new
--     run started.
-- The banner shows generation/species/genome with the percentage already
-- measured, a live fitness estimate, max fitness and the restart count.
  918. while true do
  919.     local backgroundColor = 0xD0FFFFFF
  920.     if not forms.ischecked(hideBanner) then
  921.         gui.drawBox(0, 0, 300, 29, backgroundColor, backgroundColor)
  922.     end
  923.  
  924.     local species = pool.species[pool.currentSpecies]
  925.     local genome = species.genomes[pool.currentGenome]
  926.    
  927.     if forms.ischecked(showNetwork) then
  928.         displayGenome(genome)
  929.     end
  930.    
  931.     if pool.currentFrame%5 == 0 then
  932.         evaluateCurrent()
  933.     end
  934.  
  935.     joypad.set(controller)
  936.  
  937.     getPositions()
  938.     if marioX > rightmost then
  939.         rightmost = marioX
  940.         timeout = TimeoutConstant
  941.     end
  942.    
  943.     timeout = timeout - 1
  944.    
  945.     local timeoutBonus = pool.currentFrame / 4
  946.     if timeout + timeoutBonus <0> 4816 then -- NOTE(review): corrupted line, see the note above
  947.             fitness = fitness + 1000
  948.         end
  949.         if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
  950.             fitness = fitness + 1000
  951.         end
  952.         if fitness == 0 then
  953.             fitness = -1
  954.         end
  955.         genome.fitness = fitness
  956.        
  957.         if fitness > pool.maxFitness then
  958.             pool.maxFitness = fitness
  959.             forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  960.             writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  961.         end
  962.        
  963.         console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  964.         pool.currentSpecies = 1
  965.         pool.currentGenome = 1
  966.         while fitnessAlreadyMeasured() do
  967.             nextGenome()
  968.         end
  969.         initializeRun()
  970.     end
  971.  
  972.     local measured = 0
  973.     local total = 0
  974.     for _,species in pairs(pool.species) do
  975.         for _,genome in pairs(species.genomes) do
  976.             total = total + 1
  977.             if genome.fitness ~= 0 then
  978.                 measured = measured + 1
  979.             end
  980.         end
  981.     end
  982.     if not forms.ischecked(hideBanner) then
  983.         gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 10)
  984.         gui.drawText(0, 10, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000, 10)
  985.         gui.drawText(100, 10, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 10)
  986.         gui.drawText(0, 20, "Restarts: " .. restart, 0xFF000000, 10)
  987.     end
  988.        
  989.     pool.currentFrame = pool.currentFrame + 1
  990.  
  991.     emu.frameadvance();
  992. end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement