SHARE
TWEET

NEATEvolve.lua

SethBling Jun 13th, 2015 (edited) 186,262 Never
  1. -- MarI/O by SethBling
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  4. -- For SMW, make sure you have a save state named "DP1.state" at the beginning of a level,
  5. -- and put a copy in both the Lua folder and the root directory of BizHawk.
  6.  
  7. if gameinfo.getromname() == "Super Mario World (USA)" then
  8.         Filename = "DP1.state"
  9.         ButtonNames = {
  10.                 "A",
  11.                 "B",
  12.                 "X",
  13.                 "Y",
  14.                 "Up",
  15.                 "Down",
  16.                 "Left",
  17.                 "Right",
  18.         }
  19. elseif gameinfo.getromname() == "Super Mario Bros." then
  20.         Filename = "SMB1-1.state"
  21.         ButtonNames = {
  22.                 "A",
  23.                 "B",
  24.                 "Up",
  25.                 "Down",
  26.                 "Left",
  27.                 "Right",
  28.         }
  29. end
  30.  
  31. BoxRadius = 6
  32. InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  33.  
  34. Inputs = InputSize+1
  35. Outputs = #ButtonNames
  36.  
  37. Population = 300
  38. DeltaDisjoint = 2.0
  39. DeltaWeights = 0.4
  40. DeltaThreshold = 1.0
  41.  
  42. StaleSpecies = 15
  43.  
  44. MutateConnectionsChance = 0.25
  45. PerturbChance = 0.90
  46. CrossoverChance = 0.75
  47. LinkMutationChance = 2.0
  48. NodeMutationChance = 0.50
  49. BiasMutationChance = 0.40
  50. StepSize = 0.1
  51. DisableMutationChance = 0.4
  52. EnableMutationChance = 0.2
  53.  
  54. TimeoutConstant = 20
  55.  
  56. MaxNodes = 1000000
  57.  
  58. function getPositions()
  59.         if gameinfo.getromname() == "Super Mario World (USA)" then
  60.                 marioX = memory.read_s16_le(0x94)
  61.                 marioY = memory.read_s16_le(0x96)
  62.                
  63.                 local layer1x = memory.read_s16_le(0x1A);
  64.                 local layer1y = memory.read_s16_le(0x1C);
  65.                
  66.                 screenX = marioX-layer1x
  67.                 screenY = marioY-layer1y
  68.         elseif gameinfo.getromname() == "Super Mario Bros." then
  69.                 marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
  70.                 marioY = memory.readbyte(0x03B8)+16
  71.        
  72.                 screenX = memory.readbyte(0x03AD)
  73.                 screenY = memory.readbyte(0x03B8)
  74.         end
  75. end
  76.  
  77. function getTile(dx, dy)
  78.         if gameinfo.getromname() == "Super Mario World (USA)" then
  79.                 x = math.floor((marioX+dx+8)/16)
  80.                 y = math.floor((marioY+dy)/16)
  81.                
  82.                 return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  83.         elseif gameinfo.getromname() == "Super Mario Bros." then
  84.                 local x = marioX + dx + 8
  85.                 local y = marioY + dy - 16
  86.                 local page = math.floor(x/256)%2
  87.  
  88.                 local subx = math.floor((x%256)/16)
  89.                 local suby = math.floor((y - 32)/16)
  90.                 local addr = 0x500 + page*13*16+suby*16+subx
  91.                
  92.                 if suby >= 13 or suby < 0 then
  93.                         return 0
  94.                 end
  95.                
  96.                 if memory.readbyte(addr) ~= 0 then
  97.                         return 1
  98.                 else
  99.                         return 0
  100.                 end
  101.         end
  102. end
  103.  
  104. function getSprites()
  105.         if gameinfo.getromname() == "Super Mario World (USA)" then
  106.                 local sprites = {}
  107.                 for slot=0,11 do
  108.                         local status = memory.readbyte(0x14C8+slot)
  109.                         if status ~= 0 then
  110.                                 spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
  111.                                 spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
  112.                                 sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
  113.                         end
  114.                 end            
  115.                
  116.                 return sprites
  117.         elseif gameinfo.getromname() == "Super Mario Bros." then
  118.                 local sprites = {}
  119.                 for slot=0,4 do
  120.                         local enemy = memory.readbyte(0xF+slot)
  121.                         if enemy ~= 0 then
  122.                                 local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
  123.                                 local ey = memory.readbyte(0xCF + slot)+24
  124.                                 sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
  125.                         end
  126.                 end
  127.                
  128.                 return sprites
  129.         end
  130. end
  131.  
  132. function getExtendedSprites()
  133.         if gameinfo.getromname() == "Super Mario World (USA)" then
  134.                 local extended = {}
  135.                 for slot=0,11 do
  136.                         local number = memory.readbyte(0x170B+slot)
  137.                         if number ~= 0 then
  138.                                 spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
  139.                                 spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
  140.                                 extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
  141.                         end
  142.                 end            
  143.                
  144.                 return extended
  145.         elseif gameinfo.getromname() == "Super Mario Bros." then
  146.                 return {}
  147.         end
  148. end
  149.  
  150. function getInputs()
  151.         getPositions()
  152.        
  153.         sprites = getSprites()
  154.         extended = getExtendedSprites()
  155.        
  156.         local inputs = {}
  157.        
  158.         for dy=-BoxRadius*16,BoxRadius*16,16 do
  159.                 for dx=-BoxRadius*16,BoxRadius*16,16 do
  160.                         inputs[#inputs+1] = 0
  161.                        
  162.                         tile = getTile(dx, dy)
  163.                         if tile == 1 and marioY+dy < 0x1B0 then
  164.                                 inputs[#inputs] = 1
  165.                         end
  166.                        
  167.                         for i = 1,#sprites do
  168.                                 distx = math.abs(sprites[i]["x"] - (marioX+dx))
  169.                                 disty = math.abs(sprites[i]["y"] - (marioY+dy))
  170.                                 if distx <= 8 and disty <= 8 then
  171.                                         inputs[#inputs] = -1
  172.                                 end
  173.                         end
  174.  
  175.                         for i = 1,#extended do
  176.                                 distx = math.abs(extended[i]["x"] - (marioX+dx))
  177.                                 disty = math.abs(extended[i]["y"] - (marioY+dy))
  178.                                 if distx < 8 and disty < 8 then
  179.                                         inputs[#inputs] = -1
  180.                                 end
  181.                         end
  182.                 end
  183.         end
  184.        
  185.         --mariovx = memory.read_s8(0x7B)
  186.         --mariovy = memory.read_s8(0x7D)
  187.        
  188.         return inputs
  189. end
  190.  
  191. function sigmoid(x)
  192.         return 2/(1+math.exp(-4.9*x))-1
  193. end
  194.  
  195. function newInnovation()
  196.         pool.innovation = pool.innovation + 1
  197.         return pool.innovation
  198. end
  199.  
  200. function newPool()
  201.         local pool = {}
  202.         pool.species = {}
  203.         pool.generation = 0
  204.         pool.innovation = Outputs
  205.         pool.currentSpecies = 1
  206.         pool.currentGenome = 1
  207.         pool.currentFrame = 0
  208.         pool.maxFitness = 0
  209.        
  210.         return pool
  211. end
  212.  
  213. function newSpecies()
  214.         local species = {}
  215.         species.topFitness = 0
  216.         species.staleness = 0
  217.         species.genomes = {}
  218.         species.averageFitness = 0
  219.        
  220.         return species
  221. end
  222.  
  223. function newGenome()
  224.         local genome = {}
  225.         genome.genes = {}
  226.         genome.fitness = 0
  227.         genome.adjustedFitness = 0
  228.         genome.network = {}
  229.         genome.maxneuron = 0
  230.         genome.globalRank = 0
  231.         genome.mutationRates = {}
  232.         genome.mutationRates["connections"] = MutateConnectionsChance
  233.         genome.mutationRates["link"] = LinkMutationChance
  234.         genome.mutationRates["bias"] = BiasMutationChance
  235.         genome.mutationRates["node"] = NodeMutationChance
  236.         genome.mutationRates["enable"] = EnableMutationChance
  237.         genome.mutationRates["disable"] = DisableMutationChance
  238.         genome.mutationRates["step"] = StepSize
  239.        
  240.         return genome
  241. end
  242.  
  243. function copyGenome(genome)
  244.         local genome2 = newGenome()
  245.         for g=1,#genome.genes do
  246.                 table.insert(genome2.genes, copyGene(genome.genes[g]))
  247.         end
  248.         genome2.maxneuron = genome.maxneuron
  249.         genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  250.         genome2.mutationRates["link"] = genome.mutationRates["link"]
  251.         genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  252.         genome2.mutationRates["node"] = genome.mutationRates["node"]
  253.         genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  254.         genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  255.        
  256.         return genome2
  257. end
  258.  
  259. function basicGenome()
  260.         local genome = newGenome()
  261.         local innovation = 1
  262.  
  263.         genome.maxneuron = Inputs
  264.         mutate(genome)
  265.        
  266.         return genome
  267. end
  268.  
  269. function newGene()
  270.         local gene = {}
  271.         gene.into = 0
  272.         gene.out = 0
  273.         gene.weight = 0.0
  274.         gene.enabled = true
  275.         gene.innovation = 0
  276.        
  277.         return gene
  278. end
  279.  
  280. function copyGene(gene)
  281.         local gene2 = newGene()
  282.         gene2.into = gene.into
  283.         gene2.out = gene.out
  284.         gene2.weight = gene.weight
  285.         gene2.enabled = gene.enabled
  286.         gene2.innovation = gene.innovation
  287.        
  288.         return gene2
  289. end
  290.  
  291. function newNeuron()
  292.         local neuron = {}
  293.         neuron.incoming = {}
  294.         neuron.value = 0.0
  295.        
  296.         return neuron
  297. end
  298.  
  299. function generateNetwork(genome)
  300.         local network = {}
  301.         network.neurons = {}
  302.        
  303.         for i=1,Inputs do
  304.                 network.neurons[i] = newNeuron()
  305.         end
  306.        
  307.         for o=1,Outputs do
  308.                 network.neurons[MaxNodes+o] = newNeuron()
  309.         end
  310.        
  311.         table.sort(genome.genes, function (a,b)
  312.                 return (a.out < b.out)
  313.         end)
  314.         for i=1,#genome.genes do
  315.                 local gene = genome.genes[i]
  316.                 if gene.enabled then
  317.                         if network.neurons[gene.out] == nil then
  318.                                 network.neurons[gene.out] = newNeuron()
  319.                         end
  320.                         local neuron = network.neurons[gene.out]
  321.                         table.insert(neuron.incoming, gene)
  322.                         if network.neurons[gene.into] == nil then
  323.                                 network.neurons[gene.into] = newNeuron()
  324.                         end
  325.                 end
  326.         end
  327.        
  328.         genome.network = network
  329. end
  330.  
  331. function evaluateNetwork(network, inputs)
  332.         table.insert(inputs, 1)
  333.         if #inputs ~= Inputs then
  334.                 console.writeline("Incorrect number of neural network inputs.")
  335.                 return {}
  336.         end
  337.        
  338.         for i=1,Inputs do
  339.                 network.neurons[i].value = inputs[i]
  340.         end
  341.        
  342.         for _,neuron in pairs(network.neurons) do
  343.                 local sum = 0
  344.                 for j = 1,#neuron.incoming do
  345.                         local incoming = neuron.incoming[j]
  346.                         local other = network.neurons[incoming.into]
  347.                         sum = sum + incoming.weight * other.value
  348.                 end
  349.                
  350.                 if #neuron.incoming > 0 then
  351.                         neuron.value = sigmoid(sum)
  352.                 end
  353.         end
  354.        
  355.         local outputs = {}
  356.         for o=1,Outputs do
  357.                 local button = "P1 " .. ButtonNames[o]
  358.                 if network.neurons[MaxNodes+o].value > 0 then
  359.                         outputs[button] = true
  360.                 else
  361.                         outputs[button] = false
  362.                 end
  363.         end
  364.        
  365.         return outputs
  366. end
  367.  
  368. function crossover(g1, g2)
  369.         -- Make sure g1 is the higher fitness genome
  370.         if g2.fitness > g1.fitness then
  371.                 tempg = g1
  372.                 g1 = g2
  373.                 g2 = tempg
  374.         end
  375.  
  376.         local child = newGenome()
  377.        
  378.         local innovations2 = {}
  379.         for i=1,#g2.genes do
  380.                 local gene = g2.genes[i]
  381.                 innovations2[gene.innovation] = gene
  382.         end
  383.        
  384.         for i=1,#g1.genes do
  385.                 local gene1 = g1.genes[i]
  386.                 local gene2 = innovations2[gene1.innovation]
  387.                 if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  388.                         table.insert(child.genes, copyGene(gene2))
  389.                 else
  390.                         table.insert(child.genes, copyGene(gene1))
  391.                 end
  392.         end
  393.        
  394.         child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  395.        
  396.         for mutation,rate in pairs(g1.mutationRates) do
  397.                 child.mutationRates[mutation] = rate
  398.         end
  399.        
  400.         return child
  401. end
  402.  
  403. function randomNeuron(genes, nonInput)
  404.         local neurons = {}
  405.         if not nonInput then
  406.                 for i=1,Inputs do
  407.                         neurons[i] = true
  408.                 end
  409.         end
  410.         for o=1,Outputs do
  411.                 neurons[MaxNodes+o] = true
  412.         end
  413.         for i=1,#genes do
  414.                 if (not nonInput) or genes[i].into > Inputs then
  415.                         neurons[genes[i].into] = true
  416.                 end
  417.                 if (not nonInput) or genes[i].out > Inputs then
  418.                         neurons[genes[i].out] = true
  419.                 end
  420.         end
  421.  
  422.         local count = 0
  423.         for _,_ in pairs(neurons) do
  424.                 count = count + 1
  425.         end
  426.         local n = math.random(1, count)
  427.        
  428.         for k,v in pairs(neurons) do
  429.                 n = n-1
  430.                 if n == 0 then
  431.                         return k
  432.                 end
  433.         end
  434.        
  435.         return 0
  436. end
  437.  
  438. function containsLink(genes, link)
  439.         for i=1,#genes do
  440.                 local gene = genes[i]
  441.                 if gene.into == link.into and gene.out == link.out then
  442.                         return true
  443.                 end
  444.         end
  445. end
  446.  
  447. function pointMutate(genome)
  448.         local step = genome.mutationRates["step"]
  449.        
  450.         for i=1,#genome.genes do
  451.                 local gene = genome.genes[i]
  452.                 if math.random() < PerturbChance then
  453.                         gene.weight = gene.weight + math.random() * step*2 - step
  454.                 else
  455.                         gene.weight = math.random()*4-2
  456.                 end
  457.         end
  458. end
  459.  
  460. function linkMutate(genome, forceBias)
  461.         local neuron1 = randomNeuron(genome.genes, false)
  462.         local neuron2 = randomNeuron(genome.genes, true)
  463.          
  464.         local newLink = newGene()
  465.         if neuron1 <= Inputs and neuron2 <= Inputs then
  466.                 --Both input nodes
  467.                 return
  468.         end
  469.         if neuron2 <= Inputs then
  470.                 -- Swap output and input
  471.                 local temp = neuron1
  472.                 neuron1 = neuron2
  473.                 neuron2 = temp
  474.         end
  475.  
  476.         newLink.into = neuron1
  477.         newLink.out = neuron2
  478.         if forceBias then
  479.                 newLink.into = Inputs
  480.         end
  481.        
  482.         if containsLink(genome.genes, newLink) then
  483.                 return
  484.         end
  485.         newLink.innovation = newInnovation()
  486.         newLink.weight = math.random()*4-2
  487.        
  488.         table.insert(genome.genes, newLink)
  489. end
  490.  
  491. function nodeMutate(genome)
  492.         if #genome.genes == 0 then
  493.                 return
  494.         end
  495.  
  496.         genome.maxneuron = genome.maxneuron + 1
  497.  
  498.         local gene = genome.genes[math.random(1,#genome.genes)]
  499.         if not gene.enabled then
  500.                 return
  501.         end
  502.         gene.enabled = false
  503.        
  504.         local gene1 = copyGene(gene)
  505.         gene1.out = genome.maxneuron
  506.         gene1.weight = 1.0
  507.         gene1.innovation = newInnovation()
  508.         gene1.enabled = true
  509.         table.insert(genome.genes, gene1)
  510.        
  511.         local gene2 = copyGene(gene)
  512.         gene2.into = genome.maxneuron
  513.         gene2.innovation = newInnovation()
  514.         gene2.enabled = true
  515.         table.insert(genome.genes, gene2)
  516. end
  517.  
  518. function enableDisableMutate(genome, enable)
  519.         local candidates = {}
  520.         for _,gene in pairs(genome.genes) do
  521.                 if gene.enabled == not enable then
  522.                         table.insert(candidates, gene)
  523.                 end
  524.         end
  525.        
  526.         if #candidates == 0 then
  527.                 return
  528.         end
  529.        
  530.         local gene = candidates[math.random(1,#candidates)]
  531.         gene.enabled = not gene.enabled
  532. end
  533.  
  534. function mutate(genome)
  535.         for mutation,rate in pairs(genome.mutationRates) do
  536.                 if math.random(1,2) == 1 then
  537.                         genome.mutationRates[mutation] = 0.95*rate
  538.                 else
  539.                         genome.mutationRates[mutation] = 1.05263*rate
  540.                 end
  541.         end
  542.  
  543.         if math.random() < genome.mutationRates["connections"] then
  544.                 pointMutate(genome)
  545.         end
  546.        
  547.         local p = genome.mutationRates["link"]
  548.         while p > 0 do
  549.                 if math.random() < p then
  550.                         linkMutate(genome, false)
  551.                 end
  552.                 p = p - 1
  553.         end
  554.  
  555.         p = genome.mutationRates["bias"]
  556.         while p > 0 do
  557.                 if math.random() < p then
  558.                         linkMutate(genome, true)
  559.                 end
  560.                 p = p - 1
  561.         end
  562.        
  563.         p = genome.mutationRates["node"]
  564.         while p > 0 do
  565.                 if math.random() < p then
  566.                         nodeMutate(genome)
  567.                 end
  568.                 p = p - 1
  569.         end
  570.        
  571.         p = genome.mutationRates["enable"]
  572.         while p > 0 do
  573.                 if math.random() < p then
  574.                         enableDisableMutate(genome, true)
  575.                 end
  576.                 p = p - 1
  577.         end
  578.  
  579.         p = genome.mutationRates["disable"]
  580.         while p > 0 do
  581.                 if math.random() < p then
  582.                         enableDisableMutate(genome, false)
  583.                 end
  584.                 p = p - 1
  585.         end
  586. end
  587.  
  588. function disjoint(genes1, genes2)
  589.         local i1 = {}
  590.         for i = 1,#genes1 do
  591.                 local gene = genes1[i]
  592.                 i1[gene.innovation] = true
  593.         end
  594.  
  595.         local i2 = {}
  596.         for i = 1,#genes2 do
  597.                 local gene = genes2[i]
  598.                 i2[gene.innovation] = true
  599.         end
  600.        
  601.         local disjointGenes = 0
  602.         for i = 1,#genes1 do
  603.                 local gene = genes1[i]
  604.                 if not i2[gene.innovation] then
  605.                         disjointGenes = disjointGenes+1
  606.                 end
  607.         end
  608.        
  609.         for i = 1,#genes2 do
  610.                 local gene = genes2[i]
  611.                 if not i1[gene.innovation] then
  612.                         disjointGenes = disjointGenes+1
  613.                 end
  614.         end
  615.        
  616.         local n = math.max(#genes1, #genes2)
  617.        
  618.         return disjointGenes / n
  619. end
  620.  
  621. function weights(genes1, genes2)
  622.         local i2 = {}
  623.         for i = 1,#genes2 do
  624.                 local gene = genes2[i]
  625.                 i2[gene.innovation] = gene
  626.         end
  627.  
  628.         local sum = 0
  629.         local coincident = 0
  630.         for i = 1,#genes1 do
  631.                 local gene = genes1[i]
  632.                 if i2[gene.innovation] ~= nil then
  633.                         local gene2 = i2[gene.innovation]
  634.                         sum = sum + math.abs(gene.weight - gene2.weight)
  635.                         coincident = coincident + 1
  636.                 end
  637.         end
  638.        
  639.         return sum / coincident
  640. end
  641.        
  642. function sameSpecies(genome1, genome2)
  643.         local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  644.         local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  645.         return dd + dw < DeltaThreshold
  646. end
  647.  
  648. function rankGlobally()
  649.         local global = {}
  650.         for s = 1,#pool.species do
  651.                 local species = pool.species[s]
  652.                 for g = 1,#species.genomes do
  653.                         table.insert(global, species.genomes[g])
  654.                 end
  655.         end
  656.         table.sort(global, function (a,b)
  657.                 return (a.fitness < b.fitness)
  658.         end)
  659.        
  660.         for g=1,#global do
  661.                 global[g].globalRank = g
  662.         end
  663. end
  664.  
  665. function calculateAverageFitness(species)
  666.         local total = 0
  667.        
  668.         for g=1,#species.genomes do
  669.                 local genome = species.genomes[g]
  670.                 total = total + genome.globalRank
  671.         end
  672.        
  673.         species.averageFitness = total / #species.genomes
  674. end
  675.  
  676. function totalAverageFitness()
  677.         local total = 0
  678.         for s = 1,#pool.species do
  679.                 local species = pool.species[s]
  680.                 total = total + species.averageFitness
  681.         end
  682.  
  683.         return total
  684. end
  685.  
  686. function cullSpecies(cutToOne)
  687.         for s = 1,#pool.species do
  688.                 local species = pool.species[s]
  689.                
  690.                 table.sort(species.genomes, function (a,b)
  691.                         return (a.fitness > b.fitness)
  692.                 end)
  693.                
  694.                 local remaining = math.ceil(#species.genomes/2)
  695.                 if cutToOne then
  696.                         remaining = 1
  697.                 end
  698.                 while #species.genomes > remaining do
  699.                         table.remove(species.genomes)
  700.                 end
  701.         end
  702. end
  703.  
  704. function breedChild(species)
  705.         local child = {}
  706.         if math.random() < CrossoverChance then
  707.                 g1 = species.genomes[math.random(1, #species.genomes)]
  708.                 g2 = species.genomes[math.random(1, #species.genomes)]
  709.                 child = crossover(g1, g2)
  710.         else
  711.                 g = species.genomes[math.random(1, #species.genomes)]
  712.                 child = copyGenome(g)
  713.         end
  714.        
  715.         mutate(child)
  716.        
  717.         return child
  718. end
  719.  
  720. function removeStaleSpecies()
  721.         local survived = {}
  722.  
  723.         for s = 1,#pool.species do
  724.                 local species = pool.species[s]
  725.                
  726.                 table.sort(species.genomes, function (a,b)
  727.                         return (a.fitness > b.fitness)
  728.                 end)
  729.                
  730.                 if species.genomes[1].fitness > species.topFitness then
  731.                         species.topFitness = species.genomes[1].fitness
  732.                         species.staleness = 0
  733.                 else
  734.                         species.staleness = species.staleness + 1
  735.                 end
  736.                 if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness then
  737.                         table.insert(survived, species)
  738.                 end
  739.         end
  740.  
  741.         pool.species = survived
  742. end
  743.  
  744. function removeWeakSpecies()
  745.         local survived = {}
  746.  
  747.         local sum = totalAverageFitness()
  748.         for s = 1,#pool.species do
  749.                 local species = pool.species[s]
  750.                 breed = math.floor(species.averageFitness / sum * Population)
  751.                 if breed >= 1 then
  752.                         table.insert(survived, species)
  753.                 end
  754.         end
  755.  
  756.         pool.species = survived
  757. end
  758.  
  759.  
  760. function addToSpecies(child)
  761.         local foundSpecies = false
  762.         for s=1,#pool.species do
  763.                 local species = pool.species[s]
  764.                 if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  765.                         table.insert(species.genomes, child)
  766.                         foundSpecies = true
  767.                 end
  768.         end
  769.        
  770.         if not foundSpecies then
  771.                 local childSpecies = newSpecies()
  772.                 table.insert(childSpecies.genomes, child)
  773.                 table.insert(pool.species, childSpecies)
  774.         end
  775. end
  776.  
  777. function newGeneration()
  778.         cullSpecies(false) -- Cull the bottom half of each species
  779.         rankGlobally()
  780.         removeStaleSpecies()
  781.         rankGlobally()
  782.         for s = 1,#pool.species do
  783.                 local species = pool.species[s]
  784.                 calculateAverageFitness(species)
  785.         end
  786.         removeWeakSpecies()
  787.         local sum = totalAverageFitness()
  788.         local children = {}
  789.         for s = 1,#pool.species do
  790.                 local species = pool.species[s]
  791.                 breed = math.floor(species.averageFitness / sum * Population) - 1
  792.                 for i=1,breed do
  793.                         table.insert(children, breedChild(species))
  794.                 end
  795.         end
  796.         cullSpecies(true) -- Cull all but the top member of each species
  797.         while #children + #pool.species < Population do
  798.                 local species = pool.species[math.random(1, #pool.species)]
  799.                 table.insert(children, breedChild(species))
  800.         end
  801.         for c=1,#children do
  802.                 local child = children[c]
  803.                 addToSpecies(child)
  804.         end
  805.        
  806.         pool.generation = pool.generation + 1
  807.        
  808.         writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  809. end
  810.        
  811. function initializePool()
  812.         pool = newPool()
  813.  
  814.         for i=1,Population do
  815.                 basic = basicGenome()
  816.                 addToSpecies(basic)
  817.         end
  818.  
  819.         initializeRun()
  820. end
  821.  
  822. function clearJoypad()
  823.         controller = {}
  824.         for b = 1,#ButtonNames do
  825.                 controller["P1 " .. ButtonNames[b]] = false
  826.         end
  827.         joypad.set(controller)
  828. end
  829.  
  830. function initializeRun()
  831.         savestate.load(Filename);
  832.         rightmost = 0
  833.         pool.currentFrame = 0
  834.         timeout = TimeoutConstant
  835.         clearJoypad()
  836.        
  837.         local species = pool.species[pool.currentSpecies]
  838.         local genome = species.genomes[pool.currentGenome]
  839.         generateNetwork(genome)
  840.         evaluateCurrent()
  841. end
  842.  
  843. function evaluateCurrent()
  844.         local species = pool.species[pool.currentSpecies]
  845.         local genome = species.genomes[pool.currentGenome]
  846.  
  847.         inputs = getInputs()
  848.         controller = evaluateNetwork(genome.network, inputs)
  849.        
  850.         if controller["P1 Left"] and controller["P1 Right"] then
  851.                 controller["P1 Left"] = false
  852.                 controller["P1 Right"] = false
  853.         end
  854.         if controller["P1 Up"] and controller["P1 Down"] then
  855.                 controller["P1 Up"] = false
  856.                 controller["P1 Down"] = false
  857.         end
  858.  
  859.         joypad.set(controller)
  860. end
  861.  
  862. if pool == nil then
  863.         initializePool()
  864. end
  865.  
  866.  
  867. function nextGenome()
  868.         pool.currentGenome = pool.currentGenome + 1
  869.         if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
  870.                 pool.currentGenome = 1
  871.                 pool.currentSpecies = pool.currentSpecies+1
  872.                 if pool.currentSpecies > #pool.species then
  873.                         newGeneration()
  874.                         pool.currentSpecies = 1
  875.                 end
  876.         end
  877. end
  878.  
  879. function fitnessAlreadyMeasured()
  880.         local species = pool.species[pool.currentSpecies]
  881.         local genome = species.genomes[pool.currentGenome]
  882.        
  883.         return genome.fitness ~= 0
  884. end
  885.  
  886. function displayGenome(genome)
  887.         local network = genome.network
  888.         local cells = {}
  889.         local i = 1
  890.         local cell = {}
  891.         for dy=-BoxRadius,BoxRadius do
  892.                 for dx=-BoxRadius,BoxRadius do
  893.                         cell = {}
  894.                         cell.x = 50+5*dx
  895.                         cell.y = 70+5*dy
  896.                         cell.value = network.neurons[i].value
  897.                         cells[i] = cell
  898.                         i = i + 1
  899.                 end
  900.         end
  901.         local biasCell = {}
  902.         biasCell.x = 80
  903.         biasCell.y = 110
  904.         biasCell.value = network.neurons[Inputs].value
  905.         cells[Inputs] = biasCell
  906.        
  907.         for o = 1,Outputs do
  908.                 cell = {}
  909.                 cell.x = 220
  910.                 cell.y = 30 + 8 * o
  911.                 cell.value = network.neurons[MaxNodes + o].value
  912.                 cells[MaxNodes+o] = cell
  913.                 local color
  914.                 if cell.value > 0 then
  915.                         color = 0xFF0000FF
  916.                 else
  917.                         color = 0xFF000000
  918.                 end
  919.                 gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
  920.         end
  921.        
  922.         for n,neuron in pairs(network.neurons) do
  923.                 cell = {}
  924.                 if n > Inputs and n <= MaxNodes then
  925.                         cell.x = 140
  926.                         cell.y = 40
  927.                         cell.value = neuron.value
  928.                         cells[n] = cell
  929.                 end
  930.         end
  931.        
  932.         for n=1,4 do
  933.                 for _,gene in pairs(genome.genes) do
  934.                         if gene.enabled then
  935.                                 local c1 = cells[gene.into]
  936.                                 local c2 = cells[gene.out]
  937.                                 if gene.into > Inputs and gene.into <= MaxNodes then
  938.                                         c1.x = 0.75*c1.x + 0.25*c2.x
  939.                                         if c1.x >= c2.x then
  940.                                                 c1.x = c1.x - 40
  941.                                         end
  942.                                         if c1.x < 90 then
  943.                                                 c1.x = 90
  944.                                         end
  945.                                        
  946.                                         if c1.x > 220 then
  947.                                                 c1.x = 220
  948.                                         end
  949.                                         c1.y = 0.75*c1.y + 0.25*c2.y
  950.                                        
  951.                                 end
  952.                                 if gene.out > Inputs and gene.out <= MaxNodes then
  953.                                         c2.x = 0.25*c1.x + 0.75*c2.x
  954.                                         if c1.x >= c2.x then
  955.                                                 c2.x = c2.x + 40
  956.                                         end
  957.                                         if c2.x < 90 then
  958.                                                 c2.x = 90
  959.                                         end
  960.                                         if c2.x > 220 then
  961.                                                 c2.x = 220
  962.                                         end
  963.                                         c2.y = 0.25*c1.y + 0.75*c2.y
  964.                                 end
  965.                         end
  966.                 end
  967.         end
  968.        
  969.         gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
  970.         for n,cell in pairs(cells) do
  971.                 if n > Inputs or cell.value ~= 0 then
  972.                         local color = math.floor((cell.value+1)/2*256)
  973.                         if color > 255 then color = 255 end
  974.                         if color < 0 then color = 0 end
  975.                         local opacity = 0xFF000000
  976.                         if cell.value == 0 then
  977.                                 opacity = 0x50000000
  978.                         end
  979.                         color = opacity + color*0x10000 + color*0x100 + color
  980.                         gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color)
  981.                 end
  982.         end
  983.         for _,gene in pairs(genome.genes) do
  984.                 if gene.enabled then
  985.                         local c1 = cells[gene.into]
  986.                         local c2 = cells[gene.out]
  987.                         local opacity = 0xA0000000
  988.                         if c1.value == 0 then
  989.                                 opacity = 0x20000000
  990.                         end
  991.                        
  992.                         local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
  993.                         if gene.weight > 0 then
  994.                                 color = opacity + 0x8000 + 0x10000*color
  995.                         else
  996.                                 color = opacity + 0x800000 + 0x100*color
  997.                         end
  998.                         gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
  999.                 end
  1000.         end
  1001.        
  1002.         gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)
  1003.        
  1004.         if forms.ischecked(showMutationRates) then
  1005.                 local pos = 100
  1006.                 for mutation,rate in pairs(genome.mutationRates) do
  1007.                         gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
  1008.                         pos = pos + 8
  1009.                 end
  1010.         end
  1011. end
  1012.  
  1013. function writeFile(filename)
  1014.         local file = io.open(filename, "w")
  1015.         file:write(pool.generation .. "\n")
  1016.         file:write(pool.maxFitness .. "\n")
  1017.         file:write(#pool.species .. "\n")
  1018.         for n,species in pairs(pool.species) do
  1019.                 file:write(species.topFitness .. "\n")
  1020.                 file:write(species.staleness .. "\n")
  1021.                 file:write(#species.genomes .. "\n")
  1022.                 for m,genome in pairs(species.genomes) do
  1023.                         file:write(genome.fitness .. "\n")
  1024.                         file:write(genome.maxneuron .. "\n")
  1025.                         for mutation,rate in pairs(genome.mutationRates) do
  1026.                                 file:write(mutation .. "\n")
  1027.                                 file:write(rate .. "\n")
  1028.                         end
  1029.                         file:write("done\n")
  1030.                        
  1031.                         file:write(#genome.genes .. "\n")
  1032.                         for l,gene in pairs(genome.genes) do
  1033.                                 file:write(gene.into .. " ")
  1034.                                 file:write(gene.out .. " ")
  1035.                                 file:write(gene.weight .. " ")
  1036.                                 file:write(gene.innovation .. " ")
  1037.                                 if(gene.enabled) then
  1038.                                         file:write("1\n")
  1039.                                 else
  1040.                                         file:write("0\n")
  1041.                                 end
  1042.                         end
  1043.                 end
  1044.         end
  1045.         file:close()
  1046. end
  1047.  
  1048. function savePool()
  1049.         local filename = forms.gettext(saveLoadFile)
  1050.         writeFile(filename)
  1051. end
  1052.  
  1053. function loadFile(filename)
  1054.         local file = io.open(filename, "r")
  1055.         pool = newPool()
  1056.         pool.generation = file:read("*number")
  1057.         pool.maxFitness = file:read("*number")
  1058.         forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1059.         local numSpecies = file:read("*number")
  1060.         for s=1,numSpecies do
  1061.                 local species = newSpecies()
  1062.                 table.insert(pool.species, species)
  1063.                 species.topFitness = file:read("*number")
  1064.                 species.staleness = file:read("*number")
  1065.                 local numGenomes = file:read("*number")
  1066.                 for g=1,numGenomes do
  1067.                         local genome = newGenome()
  1068.                         table.insert(species.genomes, genome)
  1069.                         genome.fitness = file:read("*number")
  1070.                         genome.maxneuron = file:read("*number")
  1071.                         local line = file:read("*line")
  1072.                         while line ~= "done" do
  1073.                                 genome.mutationRates[line] = file:read("*number")
  1074.                                 line = file:read("*line")
  1075.                         end
  1076.                         local numGenes = file:read("*number")
  1077.                         for n=1,numGenes do
  1078.                                 local gene = newGene()
  1079.                                 table.insert(genome.genes, gene)
  1080.                                 local enabled
  1081.                                 gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
  1082.                                 if enabled == 0 then
  1083.                                         gene.enabled = false
  1084.                                 else
  1085.                                         gene.enabled = true
  1086.                                 end
  1087.                                
  1088.                         end
  1089.                 end
  1090.         end
  1091.         file:close()
  1092.        
  1093.         while fitnessAlreadyMeasured() do
  1094.                 nextGenome()
  1095.         end
  1096.         initializeRun()
  1097.         pool.currentFrame = pool.currentFrame + 1
  1098. end
  1099.  
  1100. function loadPool()
  1101.         local filename = forms.gettext(saveLoadFile)
  1102.         loadFile(filename)
  1103. end
  1104.  
  1105. function playTop()
  1106.         local maxfitness = 0
  1107.         local maxs, maxg
  1108.         for s,species in pairs(pool.species) do
  1109.                 for g,genome in pairs(species.genomes) do
  1110.                         if genome.fitness > maxfitness then
  1111.                                 maxfitness = genome.fitness
  1112.                                 maxs = s
  1113.                                 maxg = g
  1114.                         end
  1115.                 end
  1116.         end
  1117.        
  1118.         pool.currentSpecies = maxs
  1119.         pool.currentGenome = maxg
  1120.         pool.maxFitness = maxfitness
  1121.         forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1122.         initializeRun()
  1123.         pool.currentFrame = pool.currentFrame + 1
  1124.         return
  1125. end
  1126.  
  1127. function onExit()
  1128.         forms.destroy(form)
  1129. end
  1130.  
  1131. writeFile("temp.pool")
  1132.  
  1133. event.onexit(onExit)
  1134.  
  1135. form = forms.newform(200, 260, "Fitness")
  1136. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  1137. showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  1138. showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  1139. restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  1140. saveButton = forms.button(form, "Save", savePool, 5, 102)
  1141. loadButton = forms.button(form, "Load", loadPool, 80, 102)
  1142. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  1143. saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  1144. playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  1145. hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  1146.  
  1147.  
  1148. while true do
  1149.         local backgroundColor = 0xD0FFFFFF
  1150.         if not forms.ischecked(hideBanner) then
  1151.                 gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
  1152.         end
  1153.  
  1154.         local species = pool.species[pool.currentSpecies]
  1155.         local genome = species.genomes[pool.currentGenome]
  1156.        
  1157.         if forms.ischecked(showNetwork) then
  1158.                 displayGenome(genome)
  1159.         end
  1160.        
  1161.         if pool.currentFrame%5 == 0 then
  1162.                 evaluateCurrent()
  1163.         end
  1164.  
  1165.         joypad.set(controller)
  1166.  
  1167.         getPositions()
  1168.         if marioX > rightmost then
  1169.                 rightmost = marioX
  1170.                 timeout = TimeoutConstant
  1171.         end
  1172.        
  1173.         timeout = timeout - 1
  1174.        
  1175.        
  1176.         local timeoutBonus = pool.currentFrame / 4
  1177.         if timeout + timeoutBonus <= 0 then
  1178.                 local fitness = rightmost - pool.currentFrame / 2
  1179.                 if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 then
  1180.                         fitness = fitness + 1000
  1181.                 end
  1182.                 if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
  1183.                         fitness = fitness + 1000
  1184.                 end
  1185.                 if fitness == 0 then
  1186.                         fitness = -1
  1187.                 end
  1188.                 genome.fitness = fitness
  1189.                
  1190.                 if fitness > pool.maxFitness then
  1191.                         pool.maxFitness = fitness
  1192.                         forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1193.                         writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  1194.                 end
  1195.                
  1196.                 console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  1197.                 pool.currentSpecies = 1
  1198.                 pool.currentGenome = 1
  1199.                 while fitnessAlreadyMeasured() do
  1200.                         nextGenome()
  1201.                 end
  1202.                 initializeRun()
  1203.         end
  1204.  
  1205.         local measured = 0
  1206.         local total = 0
  1207.         for _,species in pairs(pool.species) do
  1208.                 for _,genome in pairs(species.genomes) do
  1209.                         total = total + 1
  1210.                         if genome.fitness ~= 0 then
  1211.                                 measured = measured + 1
  1212.                         end
  1213.                 end
  1214.         end
  1215.         if not forms.ischecked(hideBanner) then
  1216.                 gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 11)
  1217.                 gui.drawText(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000, 11)
  1218.                 gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11)
  1219.         end
  1220.                
  1221.         pool.currentFrame = pool.currentFrame + 1
  1222.  
  1223.         emu.frameadvance();
  1224. end
RAW Paste Data
Top