Advertisement
Guest User

NESTEvolution w/restart counter saved in pool data

a guest
Jun 16th, 2015
282
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 29.17 KB | None | 0 0
  1. -- MarI/O by SethBling
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  4.  
  -- ROM-specific setup: the savestate every run starts from and the controller
  -- buttons the network can press (one output neuron per entry in ButtonNames).
  local romName = gameinfo.getromname()
  if romName == "Super Mario World (USA)" then
      Filename = "YI1.state"
      ButtonNames = {
          "A",
          "B",
          "X",
          "Y",
          "Up",
          "Down",
          "Left",
          "Right",
      }
  elseif romName == "Super Mario Bros." then
      Filename = "SMB1-1.state"
      ButtonNames = {
          "A",
          "B",
          "Up",
          "Down",
          "Left",
          "Right",
      }
  else
      -- Fail here with a clear message instead of later on `#ButtonNames`.
      error("Unsupported ROM: " .. tostring(romName), 0)
  end

  -- The network sees a (2*BoxRadius+1) x (2*BoxRadius+1) grid of cells centred
  -- on the player; getInputs samples one cell every 16 pixels.
  BoxRadius = 6
  InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)

  -- One extra input: evaluateNetwork appends a constant 1 as the bias.
  Inputs = InputSize+1
  Outputs = #ButtonNames

  -- Target number of genomes per generation (see newGeneration).
  Population = 300
  -- Coefficients and threshold used by sameSpecies.
  DeltaDisjoint = 2.0
  DeltaWeights = 0.4
  DeltaThreshold = 1.0

  -- Staleness limit checked in removeStaleSpecies.
  StaleSpecies = 15

  -- Initial per-genome mutation rates (copied into each genome by newGenome).
  -- Rates above 1 are valid: mutate() treats the integer part as guaranteed
  -- attempts and the fraction as the chance of one more.
  MutateConnectionsChance = 0.25
  PerturbChance = 0.90
  CrossoverChance = 0.75
  LinkMutationChance = 2.0
  NodeMutationChance = 0.50
  BiasMutationChance = 0.40
  StepSize = 0.1
  DisableMutationChance = 0.4
  EnableMutationChance = 0.2

  -- Initial value of `timeout` at the start of each run (see initializeRun).
  TimeoutConstant = 20

  -- Output neurons are stored at indices MaxNodes+1 .. MaxNodes+Outputs, so
  -- hidden neuron ids must stay at or below this value.
  MaxNodes = 1000000
  55.  
  -- Increments the restart counter kept on the pool and shows the new value
  -- in the form label referenced by `restartLabel`.
  function countLoad()
      local restarts = pool.restart + 1
      pool.restart = restarts
      forms.settext(restartLabel, "Restarts: " .. restarts)
  end
  60.  
  -- Refreshes the globals marioX/marioY (level position) and screenX/screenY
  -- (on-screen position) from emulator memory for the running ROM.
  -- Leaves all four untouched when the ROM is not one of the two handled here.
  function getPositions()
      local rom = gameinfo.getromname()

      if rom == "Super Mario World (USA)" then
          marioX = memory.read_s16_le(0x94)
          marioY = memory.read_s16_le(0x96)

          local cameraX = memory.read_s16_le(0x1A)
          local cameraY = memory.read_s16_le(0x1C)

          -- Screen position is the level position relative to the layer-1 scroll.
          screenX = marioX - cameraX
          screenY = marioY - cameraY
          return
      end

      if rom == "Super Mario Bros." then
          -- High byte (0x6D) and low byte (0x86) are combined into one X value.
          marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
          marioY = memory.readbyte(0x03B8) + 16

          screenX = memory.readbyte(0x03AD)
          screenY = memory.readbyte(0x03B8)
      end
  end
  79.  
  -- Looks up the level tile at pixel offset (dx, dy) from the player position
  -- held in the globals marioX/marioY (set by getPositions).
  --   Super Mario World: returns the raw byte read from the tile table.
  --   Super Mario Bros.: returns 1 for a non-zero tile, 0 for an empty tile or
  --                      a row outside the 13 rows covered by the table.
  -- Returns nothing for any other ROM.
  function getTile(dx, dy)
      local rom = gameinfo.getromname()

      if rom == "Super Mario World (USA)" then
          -- Declared local so the tile coordinates no longer leak into globals.
          local x = math.floor((marioX+dx+8)/16)
          local y = math.floor((marioY+dy)/16)

          return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
      elseif rom == "Super Mario Bros." then
          local x = marioX + dx + 8
          local y = marioY + dy - 16
          -- Two 16x13 pages alternate every 256 pixels.
          local page = math.floor(x/256)%2

          local subx = math.floor((x%256)/16)
          local suby = math.floor((y - 32)/16)

          if suby >= 13 or suby < 0 then
              return 0
          end

          local addr = 0x500 + page*13*16+suby*16+subx
          if memory.readbyte(addr) ~= 0 then
              return 1
          else
              return 0
          end
      end
  end
  106.  
  -- Returns a list of {x=..., y=...} tables, one per sprite slot whose first
  -- byte (status / enemy flag) is non-zero. Scans slots 0-11 for Super Mario
  -- World and slots 0-4 for Super Mario Bros. Returns nothing for other ROMs.
  function getSprites()
      local rom = gameinfo.getromname()

      if rom == "Super Mario World (USA)" then
          local sprites = {}
          for slot=0,11 do
              local status = memory.readbyte(0x14C8+slot)
              if status ~= 0 then
                  -- Low byte plus high byte * 256; kept local so they do not
                  -- leak into the global environment.
                  local spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
                  local spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
                  sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
              end
          end

          return sprites
      elseif rom == "Super Mario Bros." then
          local sprites = {}
          for slot=0,4 do
              local enemy = memory.readbyte(0xF+slot)
              if enemy ~= 0 then
                  local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
                  local ey = memory.readbyte(0xCF + slot)+24
                  sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
              end
          end

          return sprites
      end
  end
  134.  
  -- Returns a list of {x=..., y=...} tables for the Super Mario World extended
  -- sprite slots (0-11) whose number byte is non-zero. Super Mario Bros. always
  -- yields an empty list; any other ROM yields nothing.
  function getExtendedSprites()
      local rom = gameinfo.getromname()

      if rom == "Super Mario World (USA)" then
          local extended = {}
          for slot=0,11 do
              local number = memory.readbyte(0x170B+slot)
              if number ~= 0 then
                  -- Low byte plus high byte * 256; kept local so they do not
                  -- leak into the global environment.
                  local spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
                  local spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
                  extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
              end
          end

          return extended
      elseif rom == "Super Mario Bros." then
          return {}
      end
  end
  152.  
  -- Builds the network input vector: one value per cell of the
  -- (2*BoxRadius+1)^2 grid around the player, row by row (dy outer, dx inner).
  --    0  nothing detected
  --    1  getTile reported 1 and the cell lies above y = 0x1B0
  --   -1  a sprite is within 8 pixels of the cell (overrides a tile)
  -- The bias input is NOT included here; evaluateNetwork appends it.
  function getInputs()
      getPositions()

      -- Locals: these used to be assigned as globals by accident.
      local sprites = getSprites()
      local extended = getExtendedSprites()

      local inputs = {}

      for dy=-BoxRadius*16,BoxRadius*16,16 do
          for dx=-BoxRadius*16,BoxRadius*16,16 do
              inputs[#inputs+1] = 0

              local tile = getTile(dx, dy)
              if tile == 1 and marioY+dy < 0x1B0 then
                  inputs[#inputs] = 1
              end

              for i = 1,#sprites do
                  local distx = math.abs(sprites[i]["x"] - (marioX+dx))
                  local disty = math.abs(sprites[i]["y"] - (marioY+dy))
                  if distx <= 8 and disty <= 8 then
                      inputs[#inputs] = -1
                  end
              end

              -- Extended sprites use a strict (< 8) distance test.
              for i = 1,#extended do
                  local distx = math.abs(extended[i]["x"] - (marioX+dx))
                  local disty = math.abs(extended[i]["y"] - (marioY+dy))
                  if distx < 8 and disty < 8 then
                      inputs[#inputs] = -1
                  end
              end
          end
      end

      --mariovx = memory.read_s8(0x7B)
      --mariovy = memory.read_s8(0x7D)

      return inputs
  end
  193.  
  -- Steepened sigmoid rescaled to the open interval (-1, 1); sigmoid(0) == 0.
  function sigmoid(x)
      local denominator = 1 + math.exp(-4.9 * x)
      return 2 / denominator - 1
  end
  197.  
  -- Advances the pool-wide innovation counter and returns the new number.
  function newInnovation()
      local nextNumber = pool.innovation + 1
      pool.innovation = nextNumber
      return nextNumber
  end
  202.  
  -- Creates an empty pool record. The innovation counter starts at Outputs;
  -- the species/genome cursors start at the first entry.
  function newPool()
      return {
          species = {},
          generation = 0,
          innovation = Outputs,
          currentSpecies = 1,
          currentGenome = 1,
          currentFrame = 0,
          maxFitness = 0,
          restart = 0,
      }
  end
  216.  
  -- Creates an empty species record with zeroed fitness statistics.
  function newSpecies()
      return {
          topFitness = 0,
          staleness = 0,
          genomes = {},
          averageFitness = 0,
      }
  end
  226.  
  -- Creates an empty genome whose mutation rates start at the global defaults.
  function newGenome()
      -- Filled key by key, in this order, because the table is later walked
      -- with pairs() (see mutate).
      local rates = {}
      rates.connections = MutateConnectionsChance
      rates.link = LinkMutationChance
      rates.bias = BiasMutationChance
      rates.node = NodeMutationChance
      rates.enable = EnableMutationChance
      rates.disable = DisableMutationChance
      rates.step = StepSize

      return {
          genes = {},
          fitness = 0,
          adjustedFitness = 0,
          network = {},
          maxneuron = 0,
          globalRank = 0,
          mutationRates = rates,
      }
  end
  246.  
  -- Returns a copy of `genome`: genes are duplicated with copyGene, and
  -- maxneuron plus every mutation rate are carried over. Fitness, rank and
  -- network are left at the fresh values from newGenome.
  function copyGenome(genome)
      local genome2 = newGenome()
      for g=1,#genome.genes do
          table.insert(genome2.genes, copyGene(genome.genes[g]))
      end
      genome2.maxneuron = genome.maxneuron

      -- Copy every rate, including "step", which the hand-written list
      -- used to omit (crossover already copies all of them this way).
      for mutation,rate in pairs(genome.mutationRates) do
          genome2.mutationRates[mutation] = rate
      end

      return genome2
  end
  262.  
  -- Creates a starting genome: no genes, neuron ids reserved for all inputs,
  -- then one round of mutate() applied.
  function basicGenome()
      local genome = newGenome()

      genome.maxneuron = Inputs
      mutate(genome)

      return genome
  end
  272.  
  -- Creates a blank, enabled connection gene (into -> out) with zero weight.
  function newGene()
      return {
          into = 0,
          out = 0,
          weight = 0.0,
          enabled = true,
          innovation = 0,
      }
  end
  283.  
  -- Returns a new gene table carrying the same five fields as `gene`.
  function copyGene(gene)
      return {
          into = gene.into,
          out = gene.out,
          weight = gene.weight,
          enabled = gene.enabled,
          innovation = gene.innovation,
      }
  end
  294.  
  -- Creates a neuron with no incoming genes and a value of 0.
  function newNeuron()
      return {incoming = {}, value = 0.0}
  end
  302.  
  -- Builds genome.network from genome.genes.
  -- Neurons 1..Inputs are inputs, MaxNodes+1..MaxNodes+Outputs are outputs;
  -- any other id mentioned by an enabled gene gets a neuron on demand.
  -- Side effect: genome.genes is sorted in place by target neuron (`out`).
  function generateNetwork(genome)
      local neurons = {}

      for i = 1, Inputs do
          neurons[i] = newNeuron()
      end

      for o = 1, Outputs do
          neurons[MaxNodes+o] = newNeuron()
      end

      table.sort(genome.genes, function (a,b)
          return (a.out < b.out)
      end)

      for _, gene in ipairs(genome.genes) do
          if gene.enabled then
              -- Target first, then source, matching the creation order the
              -- neuron table is later traversed in.
              if neurons[gene.out] == nil then
                  neurons[gene.out] = newNeuron()
              end
              table.insert(neurons[gene.out].incoming, gene)
              if neurons[gene.into] == nil then
                  neurons[gene.into] = newNeuron()
              end
          end
      end

      genome.network = {neurons = neurons}
  end
  334.  
  -- Feeds `inputs` through `network` and returns a table mapping
  -- "P1 <button>" to true/false (true when the output neuron's value is > 0).
  -- NOTE: appends the bias value 1 to the caller's `inputs` table.
  -- Returns {} after logging a message when the input count does not match.
  function evaluateNetwork(network, inputs)
      table.insert(inputs, 1)
      if #inputs ~= Inputs then
          console.writeline("Incorrect number of neural network inputs.")
          return {}
      end

      local neurons = network.neurons

      for i = 1, Inputs do
          neurons[i].value = inputs[i]
      end

      -- Single pass over all neurons in pairs() order; neurons without
      -- incoming genes keep their current value.
      for _, neuron in pairs(neurons) do
          local links = neuron.incoming
          local total = 0
          for j = 1, #links do
              local link = links[j]
              total = total + link.weight * neurons[link.into].value
          end

          if #links > 0 then
              neuron.value = sigmoid(total)
          end
      end

      local pressed = {}
      for o = 1, Outputs do
          pressed["P1 " .. ButtonNames[o]] = neurons[MaxNodes+o].value > 0
      end

      return pressed
  end
  371.  
  -- Produces a child genome from two parents.
  -- The fitter parent supplies the gene list; for each of its genes that the
  -- other parent also has (same innovation number) and has enabled, a coin
  -- flip decides which parent's copy is used. Mutation rates come from the
  -- fitter parent; maxneuron is the larger of the two.
  function crossover(g1, g2)
      -- Make sure g1 is the higher fitness genome
      -- (swapped with multiple assignment so no temporary leaks into globals).
      if g2.fitness > g1.fitness then
          g1, g2 = g2, g1
      end

      local child = newGenome()

      local innovations2 = {}
      for i=1,#g2.genes do
          local gene = g2.genes[i]
          innovations2[gene.innovation] = gene
      end

      for i=1,#g1.genes do
          local gene1 = g1.genes[i]
          local gene2 = innovations2[gene1.innovation]
          if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
              table.insert(child.genes, copyGene(gene2))
          else
              table.insert(child.genes, copyGene(gene1))
          end
      end

      child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)

      for mutation,rate in pairs(g1.mutationRates) do
          child.mutationRates[mutation] = rate
      end

      return child
  end
  406.  
  -- Picks a random neuron id from: the inputs (unless `nonInput` is true),
  -- the outputs, and every neuron referenced by `genes`. With `nonInput`
  -- set, ids <= Inputs referenced by genes are excluded as well.
  -- Returns 0 only if the candidate walk ends without a pick.
  function randomNeuron(genes, nonInput)
      local candidates = {}
      local allowInputs = not nonInput

      if allowInputs then
          for i = 1, Inputs do
              candidates[i] = true
          end
      end
      for o = 1, Outputs do
          candidates[MaxNodes+o] = true
      end
      for i = 1, #genes do
          local gene = genes[i]
          if allowInputs or gene.into > Inputs then
              candidates[gene.into] = true
          end
          if allowInputs or gene.out > Inputs then
              candidates[gene.out] = true
          end
      end

      local total = 0
      for _ in pairs(candidates) do
          total = total + 1
      end

      -- Walk the set and stop at the randomly chosen position.
      local remaining = math.random(1, total)
      for id in pairs(candidates) do
          remaining = remaining - 1
          if remaining == 0 then
              return id
          end
      end

      return 0
  end
  441.  
  -- Returns true when `genes` already holds a gene with the same into/out
  -- pair as `link`; returns nothing (nil) otherwise.
  function containsLink(genes, link)
      for _, existing in ipairs(genes) do
          local sameSource = existing.into == link.into
          if sameSource and existing.out == link.out then
              return true
          end
      end
  end
  450.  
  -- Mutates every gene weight: with probability PerturbChance the weight is
  -- nudged by a random amount in [-step, step); otherwise it is replaced by
  -- a fresh random weight in [-2, 2).
  function pointMutate(genome)
      local step = genome.mutationRates["step"]

      for _, gene in ipairs(genome.genes) do
          local perturb = math.random() < PerturbChance
          if perturb then
              gene.weight = gene.weight + math.random() * step*2 - step
          else
              gene.weight = math.random()*4-2
          end
      end
  end
  463.  
  -- Tries to add one new connection gene between two randomly chosen neurons.
  -- Nothing is added when both picks are inputs or the link already exists.
  -- With `forceBias` set, the source is replaced by the bias neuron (Inputs).
  function linkMutate(genome, forceBias)
      local source = randomNeuron(genome.genes, false)
      local target = randomNeuron(genome.genes, true)

      local link = newGene()

      local targetIsInput = target <= Inputs
      if source <= Inputs and targetIsInput then
          --Both input nodes
          return
      end
      if targetIsInput then
          -- Swap output and input
          source, target = target, source
      end

      link.into = forceBias and Inputs or source
      link.out = target

      if containsLink(genome.genes, link) then
          return
      end

      link.innovation = newInnovation()
      link.weight = math.random()*4-2

      table.insert(genome.genes, link)
  end
  494.  
  -- Splits one randomly chosen enabled gene A->B by inserting a new neuron N:
  -- the original gene is disabled and replaced by A->N (weight 1.0) and
  -- N->B (original weight). Does nothing when the genome has no genes or the
  -- chosen gene is already disabled.
  function nodeMutate(genome)
      if #genome.genes == 0 then
          return
      end

      local gene = genome.genes[math.random(1,#genome.genes)]
      if not gene.enabled then
          return
      end

      -- Reserve the new neuron id only once the mutation is certain to happen,
      -- so aborted attempts no longer burn ids.
      genome.maxneuron = genome.maxneuron + 1
      gene.enabled = false

      local gene1 = copyGene(gene)
      gene1.out = genome.maxneuron
      gene1.weight = 1.0
      gene1.innovation = newInnovation()
      gene1.enabled = true
      table.insert(genome.genes, gene1)

      local gene2 = copyGene(gene)
      gene2.into = genome.maxneuron
      gene2.innovation = newInnovation()
      gene2.enabled = true
      table.insert(genome.genes, gene2)
  end
  521.  
  -- Flips the enabled flag of one random gene. With `enable` true a disabled
  -- gene is picked and enabled; with `enable` false an enabled gene is picked
  -- and disabled. Does nothing when no gene qualifies.
  function enableDisableMutate(genome, enable)
      local wantedState = not enable
      local pickable = {}

      for _, gene in pairs(genome.genes) do
          if gene.enabled == wantedState then
              pickable[#pickable+1] = gene
          end
      end

      if #pickable > 0 then
          local chosen = pickable[math.random(1, #pickable)]
          chosen.enabled = not chosen.enabled
      end
  end
  537.  
  -- Applies one round of mutation to `genome`:
  --   1. every mutation rate is scaled by 0.95 or 1.05263 (coin flip each);
  --   2. weights are point-mutated with probability rates.connections;
  --   3. link, bias-link, node, enable and disable mutations are attempted,
  --      in that order, according to their rates.
  function mutate(genome)
      local rates = genome.mutationRates

      for name, rate in pairs(rates) do
          local factor
          if math.random(1,2) == 1 then
              factor = 0.95
          else
              factor = 1.05263
          end
          rates[name] = factor*rate
      end

      if math.random() < rates["connections"] then
          pointMutate(genome)
      end

      -- A rate p yields one attempt per whole unit of p, each gated by
      -- math.random() < (remaining p); so p = 2.3 gives two certain attempts
      -- plus a 30% chance of a third.
      local function attempt(p, action)
          while p > 0 do
              if math.random() < p then
                  action()
              end
              p = p - 1
          end
      end

      attempt(rates["link"], function() linkMutate(genome, false) end)
      attempt(rates["bias"], function() linkMutate(genome, true) end)
      attempt(rates["node"], function() nodeMutate(genome) end)
      attempt(rates["enable"], function() enableDisableMutate(genome, true) end)
      attempt(rates["disable"], function() enableDisableMutate(genome, false) end)
  end
  591.  
  -- Returns the number of genes present in only one of the two gene lists
  -- (matched by innovation number), divided by the size of the larger list.
  -- Two empty lists have nothing disjoint, so the result is 0 (not 0/0 = NaN).
  function disjoint(genes1, genes2)
      local n = math.max(#genes1, #genes2)
      if n == 0 then
          return 0
      end

      local i1 = {}
      for i = 1,#genes1 do
          i1[genes1[i].innovation] = true
      end

      local i2 = {}
      for i = 1,#genes2 do
          i2[genes2[i].innovation] = true
      end

      local disjointGenes = 0
      for i = 1,#genes1 do
          if not i2[genes1[i].innovation] then
              disjointGenes = disjointGenes+1
          end
      end

      for i = 1,#genes2 do
          if not i1[genes2[i].innovation] then
              disjointGenes = disjointGenes+1
          end
      end

      return disjointGenes / n
  end
  624.  
  -- Returns the mean absolute weight difference over the genes the two lists
  -- share (same innovation number). With no shared genes there is no
  -- difference to average, so the result is 0 (not 0/0 = NaN).
  function weights(genes1, genes2)
      local i2 = {}
      for i = 1,#genes2 do
          local gene = genes2[i]
          i2[gene.innovation] = gene
      end

      local sum = 0
      local coincident = 0
      for i = 1,#genes1 do
          local gene = genes1[i]
          local gene2 = i2[gene.innovation]
          if gene2 ~= nil then
              sum = sum + math.abs(gene.weight - gene2.weight)
              coincident = coincident + 1
          end
      end

      if coincident == 0 then
          return 0
      end

      return sum / coincident
  end
  645.    
  -- True when the weighted sum of the disjoint-gene ratio and the average
  -- weight difference of the two genomes is below DeltaThreshold.
  function sameSpecies(genome1, genome2)
      local genesA, genesB = genome1.genes, genome2.genes
      local distance = DeltaDisjoint*disjoint(genesA, genesB)
                     + DeltaWeights*weights(genesA, genesB)
      return distance < DeltaThreshold
  end
  651.  
  -- Assigns every genome in the pool a globalRank: 1 for the lowest fitness
  -- up to N for the highest, across all species.
  function rankGlobally()
      local everyone = {}
      for _, species in ipairs(pool.species) do
          for _, genome in ipairs(species.genomes) do
              everyone[#everyone+1] = genome
          end
      end

      table.sort(everyone, function (a,b)
          return (a.fitness < b.fitness)
      end)

      for rank, genome in ipairs(everyone) do
          genome.globalRank = rank
      end
  end
  668.  
  -- Stores in species.averageFitness the mean globalRank of its genomes
  -- (ranks, not raw fitness values).
  function calculateAverageFitness(species)
      local genomes = species.genomes
      local rankSum = 0

      for _, genome in ipairs(genomes) do
          rankSum = rankSum + genome.globalRank
      end

      species.averageFitness = rankSum / #genomes
  end
  679.  
  -- Returns the sum of averageFitness over all species in the pool.
  function totalAverageFitness()
      local sum = 0
      for _, species in ipairs(pool.species) do
          sum = sum + species.averageFitness
      end
      return sum
  end
  689.  
  -- Sorts each species by descending fitness and drops its weakest genomes:
  -- all but the best one when `cutToOne` is set, otherwise the bottom half
  -- (the top half is rounded up).
  function cullSpecies(cutToOne)
      for _, species in ipairs(pool.species) do
          local genomes = species.genomes

          table.sort(genomes, function (a,b)
              return (a.fitness > b.fitness)
          end)

          local keep = cutToOne and 1 or math.ceil(#genomes/2)

          -- table.remove with no index pops the last (weakest) entry.
          for _ = #genomes, keep + 1, -1 do
              table.remove(genomes)
          end
      end
  end
  707.  
  -- Produces one mutated child from `species`: with probability
  -- CrossoverChance it is a crossover of two randomly chosen members (which
  -- may be the same genome), otherwise a copy of one random member.
  function breedChild(species)
      local child
      if math.random() < CrossoverChance then
          -- Parents are locals now; they used to be written to globals.
          local g1 = species.genomes[math.random(1, #species.genomes)]
          local g2 = species.genomes[math.random(1, #species.genomes)]
          child = crossover(g1, g2)
      else
          local g = species.genomes[math.random(1, #species.genomes)]
          child = copyGenome(g)
      end

      mutate(child)

      return child
  end
  723.  
  -- Updates each species' topFitness/staleness from its best genome and keeps
  -- only species that are not yet stale (staleness < StaleSpecies) or whose
  -- topFitness is at least the pool's maxFitness.
  -- Side effect: every species' genomes are sorted by descending fitness.
  function removeStaleSpecies()
      local kept = {}

      for _, species in ipairs(pool.species) do
          table.sort(species.genomes, function (a,b)
              return (a.fitness > b.fitness)
          end)

          local best = species.genomes[1].fitness
          if best > species.topFitness then
              species.topFitness = best
              species.staleness = 0
          else
              species.staleness = species.staleness + 1
          end

          local fresh = species.staleness < StaleSpecies
          if fresh or species.topFitness >= pool.maxFitness then
              kept[#kept+1] = species
          end
      end

      pool.species = kept
  end
  747.  
  -- Drops every species whose share of the total average fitness would earn
  -- it less than one slot in a population of size Population.
  function removeWeakSpecies()
      local survived = {}

      local sum = totalAverageFitness()
      for s = 1,#pool.species do
          local species = pool.species[s]
          -- Local: this used to overwrite a global named `breed`.
          local breed = math.floor(species.averageFitness / sum * Population)
          if breed >= 1 then
              table.insert(survived, species)
          end
      end

      pool.species = survived
  end
  762.  
  763.  
  -- Places `child` in the first species whose first genome it is compatible
  -- with (per sameSpecies); if none matches, a new species is created for it
  -- and appended to the pool.
  function addToSpecies(child)
      for _, species in ipairs(pool.species) do
          if sameSpecies(child, species.genomes[1]) then
              table.insert(species.genomes, child)
              return
          end
      end

      local fresh = newSpecies()
      table.insert(fresh.genomes, child)
      table.insert(pool.species, fresh)
  end
  780.  
  -- Advances the pool by one generation: culls, ranks, removes stale and weak
  -- species, breeds children in proportion to each species' average rank,
  -- keeps only each species' champion, tops the population up with children
  -- of random species, re-speciates the children, and writes a backup file.
  function newGeneration()
      cullSpecies(false) -- Cull the bottom half of each species
      rankGlobally()
      removeStaleSpecies()
      rankGlobally()
      for s = 1,#pool.species do
          local species = pool.species[s]
          calculateAverageFitness(species)
      end
      removeWeakSpecies()
      local sum = totalAverageFitness()
      local children = {}
      for s = 1,#pool.species do
          local species = pool.species[s]
          -- One slot per species is reserved for its surviving champion.
          -- Local: this used to overwrite a global named `breed`.
          local breed = math.floor(species.averageFitness / sum * Population) - 1
          for i=1,breed do
              table.insert(children, breedChild(species))
          end
      end
      cullSpecies(true) -- Cull all but the top member of each species
      while #children + #pool.species < Population do
          local species = pool.species[math.random(1, #pool.species)]
          table.insert(children, breedChild(species))
      end
      for c=1,#children do
          local child = children[c]
          addToSpecies(child)
      end

      pool.generation = pool.generation + 1

      writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  end
  814.    
  -- Replaces the global pool with a fresh one holding Population basic
  -- genomes sorted into species, then starts the first run.
  function initializePool()
      pool = newPool()

      for i=1,Population do
          -- Local: this used to be written to a global named `basic`.
          local basic = basicGenome()
          addToSpecies(basic)
      end

      initializeRun()
  end
  825.  
  -- Resets the global `controller` to "every button released" and sends that
  -- state to the emulator.
  function clearJoypad()
      local released = {}
      for _, name in ipairs(ButtonNames) do
          released["P1 " .. name] = false
      end

      controller = released
      joypad.set(released)
  end
  833.  
  -- Starts a fresh run for the current genome: reloads the savestate, resets
  -- the run globals (rightmost, timeout, frame counter, joypad), builds the
  -- genome's network and evaluates it once.
  function initializeRun()
      savestate.load(Filename)

      rightmost = 0
      pool.currentFrame = 0
      timeout = TimeoutConstant
      clearJoypad()

      local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
      generateNetwork(current)
      evaluateCurrent()
  end
  846.  
  -- Runs the current genome's network on fresh inputs, stores the result in
  -- the global `controller`, cancels contradictory directions (Left+Right,
  -- Up+Down) and sends the buttons to the emulator.
  function evaluateCurrent()
      local genome = pool.species[pool.currentSpecies].genomes[pool.currentGenome]

      inputs = getInputs()
      controller = evaluateNetwork(genome.network, inputs)

      local pad = controller
      for _, pair in ipairs({{"P1 Left", "P1 Right"}, {"P1 Up", "P1 Down"}}) do
          local first, second = pair[1], pair[2]
          if pad[first] and pad[second] then
              pad[first] = false
              pad[second] = false
          end
      end

      joypad.set(pad)
  end
  865.  
  -- Build a fresh pool only when the global `pool` has not been set yet.
  if pool == nil then initializePool() end
  869.  
  870.  
  -- Moves the pool cursor to the next genome, rolling over to the next
  -- species and, after the last species, triggering newGeneration().
  function nextGenome()
      pool.currentGenome = pool.currentGenome + 1
      if pool.currentGenome <= #pool.species[pool.currentSpecies].genomes then
          return
      end

      -- Past the end of this species: start at the first genome of the next.
      pool.currentGenome = 1
      pool.currentSpecies = pool.currentSpecies + 1
      if pool.currentSpecies <= #pool.species then
          return
      end

      newGeneration()
      pool.currentSpecies = 1
  end
  882.  
  -- True when the genome under the pool cursor has a non-zero fitness.
  function fitnessAlreadyMeasured()
      local current = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
      return current.fitness ~= 0
  end
  889.  
  -- Draws the genome's network on the emulator overlay: the input grid and
  -- bias cell on the left, output buttons on the right, hidden neurons laid
  -- out in between, and one line per enabled gene. Optionally lists the
  -- genome's mutation rates when the showMutationRates checkbox is ticked.
  function displayGenome(genome)
      local neurons = genome.network.neurons
      local cells = {}

      -- Input grid: one 5px cell per input, in the same row-major order the
      -- inputs are produced in.
      local index = 1
      for dy = -BoxRadius, BoxRadius do
          for dx = -BoxRadius, BoxRadius do
              cells[index] = {x = 50+5*dx, y = 70+5*dy, value = neurons[index].value}
              index = index + 1
          end
      end

      -- Bias input gets its own cell below the grid.
      cells[Inputs] = {x = 80, y = 110, value = neurons[Inputs].value}

      -- Output cells plus their button labels (blue when pressed, else black).
      for o = 1, Outputs do
          local outCell = {x = 220, y = 30 + 8 * o, value = neurons[MaxNodes + o].value}
          cells[MaxNodes+o] = outCell

          local labelColor = 0xFF000000
          if outCell.value > 0 then
              labelColor = 0xFF0000FF
          end
          gui.drawText(223, 24+8*o, ButtonNames[o], labelColor, 9)
      end

      -- Hidden neurons all start at the same spot and are spread out below.
      for n, neuron in pairs(neurons) do
          if n > Inputs and n <= MaxNodes then
              cells[n] = {x = 140, y = 40, value = neuron.value}
          end
      end

      local function isHidden(id)
          return id > Inputs and id <= MaxNodes
      end

      local function clampX(x)
          if x < 90 then
              x = 90
          end
          if x > 220 then
              x = 220
          end
          return x
      end

      -- Four relaxation passes: each enabled gene pulls its hidden endpoints
      -- a quarter of the way toward the other end, keeping sources left of
      -- targets. Fields are updated in place because both ends of a gene may
      -- be the same cell.
      for _ = 1, 4 do
          for _, gene in pairs(genome.genes) do
              if gene.enabled then
                  local from = cells[gene.into]
                  local to = cells[gene.out]

                  if isHidden(gene.into) then
                      from.x = 0.75*from.x + 0.25*to.x
                      if from.x >= to.x then
                          from.x = from.x - 40
                      end
                      from.x = clampX(from.x)
                      from.y = 0.75*from.y + 0.25*to.y
                  end

                  if isHidden(gene.out) then
                      to.x = 0.25*from.x + 0.75*to.x
                      if from.x >= to.x then
                          to.x = to.x + 40
                      end
                      to.x = clampX(to.x)
                      to.y = 0.25*from.y + 0.75*to.y
                  end
              end
          end
      end

      -- Frame around the input grid.
      gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)

      -- Cells: grey level follows the value (-1 black .. 1 white); input cells
      -- with value 0 are skipped, other zero-valued cells are drawn faint.
      for n, cell in pairs(cells) do
          if n > Inputs or cell.value ~= 0 then
              local shade = math.floor((cell.value+1)/2*256)
              if shade > 255 then shade = 255 end
              if shade < 0 then shade = 0 end

              local opacity = 0xFF000000
              if cell.value == 0 then
                  opacity = 0x50000000
              end

              local fill = opacity + shade*0x10000 + shade*0x100 + shade
              gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,fill)
          end
      end

      -- Genes: green-ish for positive weights, red-ish otherwise; fainter when
      -- the source cell currently has value 0.
      for _, gene in pairs(genome.genes) do
          if gene.enabled then
              local from = cells[gene.into]
              local to = cells[gene.out]

              local opacity = 0xA0000000
              if from.value == 0 then
                  opacity = 0x20000000
              end

              local strength = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
              local lineColor
              if gene.weight > 0 then
                  lineColor = opacity + 0x8000 + 0x10000*strength
              else
                  lineColor = opacity + 0x800000 + 0x100*strength
              end
              gui.drawLine(from.x+1, from.y, to.x-3, to.y, lineColor)
          end
      end

      -- Small marker box drawn just below the centre of the input grid.
      gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)

      if forms.ischecked(showMutationRates) then
          local textY = 100
          for mutation, rate in pairs(genome.mutationRates) do
              gui.drawText(100, textY, mutation .. ": " .. rate, 0xFF000000, 10)
              textY = textY + 8
          end
      end
  end
  1016.  
  1017. function writeFile(filename)
  1018.         local file = io.open(filename, "w")
  1019.     file:write(pool.generation .. "\n")
  1020.     file:write(pool.maxFitness .. "\n")
  1021.     file:write(pool.restart .. "\n")
  1022.     file:write(#pool.species .. "\n")
  1023.         for n,species in pairs(pool.species) do
  1024.         file:write(species.topFitness .. "\n")
  1025.         file:write(species.staleness .. "\n")
  1026.         file:write(#species.genomes .. "\n")
  1027.         for m,genome in pairs(species.genomes) do
  1028.             file:write(genome.fitness .. "\n")
  1029.             file:write(genome.maxneuron .. "\n")
  1030.             for mutation,rate in pairs(genome.mutationRates) do
  1031.                 file:write(mutation .. "\n")
  1032.                 file:write(rate .. "\n")
  1033.             end
  1034.             file:write("done\n")
  1035.            
  1036.             file:write(#genome.genes .. "\n")
  1037.             for l,gene in pairs(genome.genes) do
  1038.                 file:write(gene.into .. " ")
  1039.                 file:write(gene.out .. " ")
  1040.                 file:write(gene.weight .. " ")
  1041.                 file:write(gene.innovation .. " ")
  1042.                 if(gene.enabled) then
  1043.                     file:write("1\n")
  1044.                 else
  1045.                     file:write("0\n")
  1046.                 end
  1047.             end
  1048.         end
  1049.         end
  1050.         file:close()
  1051. end
  1052.  
  1053. function savePool()
  1054.     local filename = forms.gettext(saveLoadFile)
  1055.     writeFile(filename)
  1056. end
  1057.  
  1058. function loadFile(filename)
  1059.         local file = io.open(filename, "r")
  1060.     pool = newPool()
  1061.     pool.generation = file:read("*number")
  1062.     pool.maxFitness = file:read("*number")
  1063.     pool.restart = file:read("*number")
  1064.     forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1065.     forms.settext(restartLabel, "Restarts: " .. pool.restart)
  1066.         local numSpecies = file:read("*number")
  1067.         for s=1,numSpecies do
  1068.         local species = newSpecies()
  1069.         table.insert(pool.species, species)
  1070.         species.topFitness = file:read("*number")
  1071.         species.staleness = file:read("*number")
  1072.         local numGenomes = file:read("*number")
  1073.         for g=1,numGenomes do
  1074.             local genome = newGenome()
  1075.             table.insert(species.genomes, genome)
  1076.             genome.fitness = file:read("*number")
  1077.             genome.maxneuron = file:read("*number")
  1078.             local line = file:read("*line")
  1079.             while line ~= "done" do
  1080.                 genome.mutationRates[line] = file:read("*number")
  1081.                 line = file:read("*line")
  1082.             end
  1083.             local numGenes = file:read("*number")
  1084.             for n=1,numGenes do
  1085.                 local gene = newGene()
  1086.                 table.insert(genome.genes, gene)
  1087.                 local enabled
  1088.                 gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
  1089.                 if enabled == 0 then
  1090.                     gene.enabled = false
  1091.                 else
  1092.                     gene.enabled = true
  1093.                 end
  1094.                
  1095.             end
  1096.         end
  1097.     end
  1098.         file:close()
  1099.    
  1100.     while fitnessAlreadyMeasured() do
  1101.         nextGenome()
  1102.     end
  1103.     initializeRun()
  1104.     pool.currentFrame = pool.currentFrame + 1
  1105. end
  1106.  
  1107. function loadPool()
  1108.     local filename = forms.gettext(saveLoadFile)
  1109.     loadFile(filename)
  1110. end
  1111.  
  1112. function playTop()
  1113.     local maxfitness = 0
  1114.     local maxs, maxg
  1115.     for s,species in pairs(pool.species) do
  1116.         for g,genome in pairs(species.genomes) do
  1117.             if genome.fitness > maxfitness then
  1118.                 maxfitness = genome.fitness
  1119.                 maxs = s
  1120.                 maxg = g
  1121.             end
  1122.         end
  1123.     end
  1124.    
  1125.     pool.currentSpecies = maxs
  1126.     pool.currentGenome = maxg
  1127.     pool.maxFitness = maxfitness
  1128.     forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1129.     initializeRun()
  1130.     pool.currentFrame = pool.currentFrame + 1
  1131.     return
  1132. end
  1133.  
  1134.  
  1135. function onExit()
  1136.     forms.destroy(form)
  1137. end
  1138.  
  1139. writeFile("temp.pool")
  1140.  
  1141. event.onexit(onExit)
  1142. event.onloadstate(countLoad)
  1143.  
  1144. form = forms.newform(200, 260, "Fitness")
  1145. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  1146. restartLabel = forms.label(form, "Restarts: " .. pool.restart, 110, 8)
  1147. showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  1148. showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  1149. restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  1150. saveButton = forms.button(form, "Save", savePool, 5, 102)
  1151. loadButton = forms.button(form, "Load", loadPool, 80, 102)
  1152. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  1153. saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  1154. playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  1155. hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  1156.  
  1157.  
  1158. while true do
  1159.     local backgroundColor = 0xD0FFFFFF
  1160.     if not forms.ischecked(hideBanner) then
  1161.         gui.drawBox(0, 0, 300, 32, backgroundColor, backgroundColor)
  1162.     end
  1163.    
  1164.     local species = pool.species[pool.currentSpecies]
  1165.     local genome = species.genomes[pool.currentGenome]
  1166.    
  1167.     if forms.ischecked(showNetwork) then
  1168.         displayGenome(genome)
  1169.     end
  1170.    
  1171.     if pool.currentFrame%5 == 0 then
  1172.         evaluateCurrent()
  1173.     end
  1174.  
  1175.     joypad.set(controller)
  1176.  
  1177.     getPositions()
  1178.     if marioX > rightmost then
  1179.         rightmost = marioX
  1180.         timeout = TimeoutConstant
  1181.     end
  1182.    
  1183.     timeout = timeout - 1
  1184.    
  1185.     local timeoutBonus = pool.currentFrame / 4
  1186.     if timeout + timeoutBonus <= 0 then
  1187.         local fitness = rightmost - pool.currentFrame / 2
  1188.         if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 then
  1189.             fitness = fitness + 1000
  1190.         end
  1191.         if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
  1192.             fitness = fitness + 1000
  1193.         end
  1194.         if fitness == 0 then
  1195.             fitness = -1
  1196.         end
  1197.         genome.fitness = fitness
  1198.        
  1199.         if fitness > pool.maxFitness then
  1200.             pool.maxFitness = fitness
  1201.             forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1202.             writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  1203.         end
  1204.        
  1205.         console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  1206.         pool.currentSpecies = 1
  1207.         pool.currentGenome = 1
  1208.         while fitnessAlreadyMeasured() do
  1209.             nextGenome()
  1210.         end
  1211.         initializeRun()
  1212.     end
  1213.  
  1214.     local measured = 0
  1215.     local total = 0
  1216.     for _,species in pairs(pool.species) do
  1217.         for _,genome in pairs(species.genomes) do
  1218.             total = total + 1
  1219.             if genome.fitness ~= 0 then
  1220.                 measured = measured + 1
  1221.             end
  1222.         end
  1223.     end
  1224.     if not forms.ischecked(hideBanner) then
  1225.         gui.drawText(0, 0, "Gen " .. pool.generation, 0xFF000000, 10, "Arial", "Bold")
  1226.         gui.drawText(50, 0, "Species " .. pool.currentSpecies, 0xFF000000, 10, "Arial", "Bold")
  1227.         gui.drawText(125, 0, "Genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 10, "Arial", "Bold")
  1228.         gui.drawText(0, 10, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000, 10, "Arial", "Bold")
  1229.         gui.drawText(100, 10, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 10, "Arial", "Bold")
  1230.         gui.drawText(0, 20, "Restarts: " .. pool.restart, 0xFF000000, 10, "Arial", "Bold")
  1231.     end
  1232.        
  1233.     pool.currentFrame = pool.currentFrame + 1
  1234.  
  1235.     emu.frameadvance();
  1236. end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement