daily pastebin goal
11%
SHARE
TWEET

NEATEvolve.lua

SethBling Jun 13th, 2015 (edited) 493,655 Never
Upgrade to PRO!
ENDING IN00days00hours00mins00secs
  1. -- MarI/O by SethBling
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
  4. -- For SMW, make sure you have a save state named "DP1.state" at the beginning of a level,
  5. -- and put a copy in both the Lua folder and the root directory of BizHawk.
  6.  
  7. if gameinfo.getromname() == "Super Mario World (USA)" then
  8.         Filename = "DP1.state"
  9.         ButtonNames = {
  10.                 "A",
  11.                 "B",
  12.                 "X",
  13.                 "Y",
  14.                 "Up",
  15.                 "Down",
  16.                 "Left",
  17.                 "Right",
  18.         }
  19. elseif gameinfo.getromname() == "Super Mario Bros." then
  20.         Filename = "SMB1-1.state"
  21.         ButtonNames = {
  22.                 "A",
  23.                 "B",
  24.                 "Up",
  25.                 "Down",
  26.                 "Left",
  27.                 "Right",
  28.         }
  29. end
  30.  
  31. BoxRadius = 6
  32. InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  33.  
  34. Inputs = InputSize+1
  35. Outputs = #ButtonNames
  36.  
  37. Population = 300
  38. DeltaDisjoint = 2.0
  39. DeltaWeights = 0.4
  40. DeltaThreshold = 1.0
  41.  
  42. StaleSpecies = 15
  43.  
  44. MutateConnectionsChance = 0.25
  45. PerturbChance = 0.90
  46. CrossoverChance = 0.75
  47. LinkMutationChance = 2.0
  48. NodeMutationChance = 0.50
  49. BiasMutationChance = 0.40
  50. StepSize = 0.1
  51. DisableMutationChance = 0.4
  52. EnableMutationChance = 0.2
  53.  
  54. TimeoutConstant = 20
  55.  
  56. MaxNodes = 1000000
  57.  
  58. function getPositions()
  59.         if gameinfo.getromname() == "Super Mario World (USA)" then
  60.                 marioX = memory.read_s16_le(0x94)
  61.                 marioY = memory.read_s16_le(0x96)
  62.                
  63.                 local layer1x = memory.read_s16_le(0x1A);
  64.                 local layer1y = memory.read_s16_le(0x1C);
  65.                
  66.                 screenX = marioX-layer1x
  67.                 screenY = marioY-layer1y
  68.         elseif gameinfo.getromname() == "Super Mario Bros." then
  69.                 marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
  70.                 marioY = memory.readbyte(0x03B8)+16
  71.        
  72.                 screenX = memory.readbyte(0x03AD)
  73.                 screenY = memory.readbyte(0x03B8)
  74.         end
  75. end
  76.  
  77. function getTile(dx, dy)
  78.         if gameinfo.getromname() == "Super Mario World (USA)" then
  79.                 x = math.floor((marioX+dx+8)/16)
  80.                 y = math.floor((marioY+dy)/16)
  81.                
  82.                 return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  83.         elseif gameinfo.getromname() == "Super Mario Bros." then
  84.                 local x = marioX + dx + 8
  85.                 local y = marioY + dy - 16
  86.                 local page = math.floor(x/256)%2
  87.  
  88.                 local subx = math.floor((x%256)/16)
  89.                 local suby = math.floor((y - 32)/16)
  90.                 local addr = 0x500 + page*13*16+suby*16+subx
  91.                
  92.                 if suby >= 13 or suby < 0 then
  93.                         return 0
  94.                 end
  95.                
  96.                 if memory.readbyte(addr) ~= 0 then
  97.                         return 1
  98.                 else
  99.                         return 0
  100.                 end
  101.         end
  102. end
  103.  
  104. function getSprites()
  105.         if gameinfo.getromname() == "Super Mario World (USA)" then
  106.                 local sprites = {}
  107.                 for slot=0,11 do
  108.                         local status = memory.readbyte(0x14C8+slot)
  109.                         if status ~= 0 then
  110.                                 spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
  111.                                 spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
  112.                                 sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
  113.                         end
  114.                 end            
  115.                
  116.                 return sprites
  117.         elseif gameinfo.getromname() == "Super Mario Bros." then
  118.                 local sprites = {}
  119.                 for slot=0,4 do
  120.                         local enemy = memory.readbyte(0xF+slot)
  121.                         if enemy ~= 0 then
  122.                                 local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
  123.                                 local ey = memory.readbyte(0xCF + slot)+24
  124.                                 sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
  125.                         end
  126.                 end
  127.                
  128.                 return sprites
  129.         end
  130. end
  131.  
  132. function getExtendedSprites()
  133.         if gameinfo.getromname() == "Super Mario World (USA)" then
  134.                 local extended = {}
  135.                 for slot=0,11 do
  136.                         local number = memory.readbyte(0x170B+slot)
  137.                         if number ~= 0 then
  138.                                 spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
  139.                                 spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
  140.                                 extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
  141.                         end
  142.                 end            
  143.                
  144.                 return extended
  145.         elseif gameinfo.getromname() == "Super Mario Bros." then
  146.                 return {}
  147.         end
  148. end
  149.  
  150. function getInputs()
  151.         getPositions()
  152.        
  153.         sprites = getSprites()
  154.         extended = getExtendedSprites()
  155.        
  156.         local inputs = {}
  157.        
  158.         for dy=-BoxRadius*16,BoxRadius*16,16 do
  159.                 for dx=-BoxRadius*16,BoxRadius*16,16 do
  160.                         inputs[#inputs+1] = 0
  161.                        
  162.                         tile = getTile(dx, dy)
  163.                         if tile == 1 and marioY+dy < 0x1B0 then
  164.                                 inputs[#inputs] = 1
  165.                         end
  166.                        
  167.                         for i = 1,#sprites do
  168.                                 distx = math.abs(sprites[i]["x"] - (marioX+dx))
  169.                                 disty = math.abs(sprites[i]["y"] - (marioY+dy))
  170.                                 if distx <= 8 and disty <= 8 then
  171.                                         inputs[#inputs] = -1
  172.                                 end
  173.                         end
  174.  
  175.                         for i = 1,#extended do
  176.                                 distx = math.abs(extended[i]["x"] - (marioX+dx))
  177.                                 disty = math.abs(extended[i]["y"] - (marioY+dy))
  178.                                 if distx < 8 and disty < 8 then
  179.                                         inputs[#inputs] = -1
  180.                                 end
  181.                         end
  182.                 end
  183.         end
  184.        
  185.         --mariovx = memory.read_s8(0x7B)
  186.         --mariovy = memory.read_s8(0x7D)
  187.        
  188.         return inputs
  189. end
  190.  
  191. function sigmoid(x)
  192.         return 2/(1+math.exp(-4.9*x))-1
  193. end
  194.  
  195. function newInnovation()
  196.         pool.innovation = pool.innovation + 1
  197.         return pool.innovation
  198. end
  199.  
  200. function newPool()
  201.         local pool = {}
  202.         pool.species = {}
  203.         pool.generation = 0
  204.         pool.innovation = Outputs
  205.         pool.currentSpecies = 1
  206.         pool.currentGenome = 1
  207.         pool.currentFrame = 0
  208.         pool.maxFitness = 0
  209.        
  210.         return pool
  211. end
  212.  
  213. function newSpecies()
  214.         local species = {}
  215.         species.topFitness = 0
  216.         species.staleness = 0
  217.         species.genomes = {}
  218.         species.averageFitness = 0
  219.        
  220.         return species
  221. end
  222.  
  223. function newGenome()
  224.         local genome = {}
  225.         genome.genes = {}
  226.         genome.fitness = 0
  227.         genome.adjustedFitness = 0
  228.         genome.network = {}
  229.         genome.maxneuron = 0
  230.         genome.globalRank = 0
  231.         genome.mutationRates = {}
  232.         genome.mutationRates["connections"] = MutateConnectionsChance
  233.         genome.mutationRates["link"] = LinkMutationChance
  234.         genome.mutationRates["bias"] = BiasMutationChance
  235.         genome.mutationRates["node"] = NodeMutationChance
  236.         genome.mutationRates["enable"] = EnableMutationChance
  237.         genome.mutationRates["disable"] = DisableMutationChance
  238.         genome.mutationRates["step"] = StepSize
  239.        
  240.         return genome
  241. end
  242.  
  243. function copyGenome(genome)
  244.         local genome2 = newGenome()
  245.         for g=1,#genome.genes do
  246.                 table.insert(genome2.genes, copyGene(genome.genes[g]))
  247.         end
  248.         genome2.maxneuron = genome.maxneuron
  249.         genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  250.         genome2.mutationRates["link"] = genome.mutationRates["link"]
  251.         genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  252.         genome2.mutationRates["node"] = genome.mutationRates["node"]
  253.         genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  254.         genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  255.        
  256.         return genome2
  257. end
  258.  
  259. function basicGenome()
  260.         local genome = newGenome()
  261.         local innovation = 1
  262.  
  263.         genome.maxneuron = Inputs
  264.         mutate(genome)
  265.        
  266.         return genome
  267. end
  268.  
  269. function newGene()
  270.         local gene = {}
  271.         gene.into = 0
  272.         gene.out = 0
  273.         gene.weight = 0.0
  274.         gene.enabled = true
  275.         gene.innovation = 0
  276.        
  277.         return gene
  278. end
  279.  
  280. function copyGene(gene)
  281.         local gene2 = newGene()
  282.         gene2.into = gene.into
  283.         gene2.out = gene.out
  284.         gene2.weight = gene.weight
  285.         gene2.enabled = gene.enabled
  286.         gene2.innovation = gene.innovation
  287.        
  288.         return gene2
  289. end
  290.  
  291. function newNeuron()
  292.         local neuron = {}
  293.         neuron.incoming = {}
  294.         neuron.value = 0.0
  295.        
  296.         return neuron
  297. end
  298.  
  299. function generateNetwork(genome)
  300.         local network = {}
  301.         network.neurons = {}
  302.        
  303.         for i=1,Inputs do
  304.                 network.neurons[i] = newNeuron()
  305.         end
  306.        
  307.         for o=1,Outputs do
  308.                 network.neurons[MaxNodes+o] = newNeuron()
  309.         end
  310.        
  311.         table.sort(genome.genes, function (a,b)
  312.                 return (a.out < b.out)
  313.         end)
  314.         for i=1,#genome.genes do
  315.                 local gene = genome.genes[i]
  316.                 if gene.enabled then
  317.                         if network.neurons[gene.out] == nil then
  318.                                 network.neurons[gene.out] = newNeuron()
  319.                         end
  320.                         local neuron = network.neurons[gene.out]
  321.                         table.insert(neuron.incoming, gene)
  322.                         if network.neurons[gene.into] == nil then
  323.                                 network.neurons[gene.into] = newNeuron()
  324.                         end
  325.                 end
  326.         end
  327.        
  328.         genome.network = network
  329. end
  330.  
  331. function evaluateNetwork(network, inputs)
  332.         table.insert(inputs, 1)
  333.         if #inputs ~= Inputs then
  334.                 console.writeline("Incorrect number of neural network inputs.")
  335.                 return {}
  336.         end
  337.        
  338.         for i=1,Inputs do
  339.                 network.neurons[i].value = inputs[i]
  340.         end
  341.        
  342.         for _,neuron in pairs(network.neurons) do
  343.                 local sum = 0
  344.                 for j = 1,#neuron.incoming do
  345.                         local incoming = neuron.incoming[j]
  346.                         local other = network.neurons[incoming.into]
  347.                         sum = sum + incoming.weight * other.value
  348.                 end
  349.                
  350.                 if #neuron.incoming > 0 then
  351.                         neuron.value = sigmoid(sum)
  352.                 end
  353.         end
  354.        
  355.         local outputs = {}
  356.         for o=1,Outputs do
  357.                 local button = "P1 " .. ButtonNames[o]
  358.                 if network.neurons[MaxNodes+o].value > 0 then
  359.                         outputs[button] = true
  360.                 else
  361.                         outputs[button] = false
  362.                 end
  363.         end
  364.        
  365.         return outputs
  366. end
  367.  
  368. function crossover(g1, g2)
  369.         -- Make sure g1 is the higher fitness genome
  370.         if g2.fitness > g1.fitness then
  371.                 tempg = g1
  372.                 g1 = g2
  373.                 g2 = tempg
  374.         end
  375.  
  376.         local child = newGenome()
  377.        
  378.         local innovations2 = {}
  379.         for i=1,#g2.genes do
  380.                 local gene = g2.genes[i]
  381.                 innovations2[gene.innovation] = gene
  382.         end
  383.        
  384.         for i=1,#g1.genes do
  385.                 local gene1 = g1.genes[i]
  386.                 local gene2 = innovations2[gene1.innovation]
  387.                 if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  388.                         table.insert(child.genes, copyGene(gene2))
  389.                 else
  390.                         table.insert(child.genes, copyGene(gene1))
  391.                 end
  392.         end
  393.        
  394.         child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  395.        
  396.         for mutation,rate in pairs(g1.mutationRates) do
  397.                 child.mutationRates[mutation] = rate
  398.         end
  399.        
  400.         return child
  401. end
  402.  
  403. function randomNeuron(genes, nonInput)
  404.         local neurons = {}
  405.         if not nonInput then
  406.                 for i=1,Inputs do
  407.                         neurons[i] = true
  408.                 end
  409.         end
  410.         for o=1,Outputs do
  411.                 neurons[MaxNodes+o] = true
  412.         end
  413.         for i=1,#genes do
  414.                 if (not nonInput) or genes[i].into > Inputs then
  415.                         neurons[genes[i].into] = true
  416.                 end
  417.                 if (not nonInput) or genes[i].out > Inputs then
  418.                         neurons[genes[i].out] = true
  419.                 end
  420.         end
  421.  
  422.         local count = 0
  423.         for _,_ in pairs(neurons) do
  424.                 count = count + 1
  425.         end
  426.         local n = math.random(1, count)
  427.        
  428.         for k,v in pairs(neurons) do
  429.                 n = n-1
  430.                 if n == 0 then
  431.                         return k
  432.                 end
  433.         end
  434.        
  435.         return 0
  436. end
  437.  
  438. function containsLink(genes, link)
  439.         for i=1,#genes do
  440.                 local gene = genes[i]
  441.                 if gene.into == link.into and gene.out == link.out then
  442.                         return true
  443.                 end
  444.         end
  445. end
  446.  
  447. function pointMutate(genome)
  448.         local step = genome.mutationRates["step"]
  449.        
  450.         for i=1,#genome.genes do
  451.                 local gene = genome.genes[i]
  452.                 if math.random() < PerturbChance then
  453.                         gene.weight = gene.weight + math.random() * step*2 - step
  454.                 else
  455.                         gene.weight = math.random()*4-2
  456.                 end
  457.         end
  458. end
  459.  
  460. function linkMutate(genome, forceBias)
  461.         local neuron1 = randomNeuron(genome.genes, false)
  462.         local neuron2 = randomNeuron(genome.genes, true)
  463.          
  464.         local newLink = newGene()
  465.         if neuron1 <= Inputs and neuron2 <= Inputs then
  466.                 --Both input nodes
  467.                 return
  468.         end
  469.         if neuron2 <= Inputs then
  470.                 -- Swap output and input
  471.                 local temp = neuron1
  472.                 neuron1 = neuron2
  473.                 neuron2 = temp
  474.         end
  475.  
  476.         newLink.into = neuron1
  477.         newLink.out = neuron2
  478.         if forceBias then
  479.                 newLink.into = Inputs
  480.         end
  481.        
  482.         if containsLink(genome.genes, newLink) then
  483.                 return
  484.         end
  485.         newLink.innovation = newInnovation()
  486.         newLink.weight = math.random()*4-2
  487.        
  488.         table.insert(genome.genes, newLink)
  489. end
  490.  
  491. function nodeMutate(genome)
  492.         if #genome.genes == 0 then
  493.                 return
  494.         end
  495.  
  496.         genome.maxneuron = genome.maxneuron + 1
  497.  
  498.         local gene = genome.genes[math.random(1,#genome.genes)]
  499.         if not gene.enabled then
  500.                 return
  501.         end
  502.         gene.enabled = false
  503.        
  504.         local gene1 = copyGene(gene)
  505.         gene1.out = genome.maxneuron
  506.         gene1.weight = 1.0
  507.         gene1.innovation = newInnovation()
  508.         gene1.enabled = true
  509.         table.insert(genome.genes, gene1)
  510.        
  511.         local gene2 = copyGene(gene)
  512.         gene2.into = genome.maxneuron
  513.         gene2.innovation = newInnovation()
  514.         gene2.enabled = true
  515.         table.insert(genome.genes, gene2)
  516. end
  517.  
  518. function enableDisableMutate(genome, enable)
  519.         local candidates = {}
  520.         for _,gene in pairs(genome.genes) do
  521.                 if gene.enabled == not enable then
  522.                         table.insert(candidates, gene)
  523.                 end
  524.         end
  525.        
  526.         if #candidates == 0 then
  527.                 return
  528.         end
  529.        
  530.         local gene = candidates[math.random(1,#candidates)]
  531.         gene.enabled = not gene.enabled
  532. end
  533.  
  534. function mutate(genome)
  535.         for mutation,rate in pairs(genome.mutationRates) do
  536.                 if math.random(1,2) == 1 then
  537.                         genome.mutationRates[mutation] = 0.95*rate
  538.                 else
  539.                         genome.mutationRates[mutation] = 1.05263*rate
  540.                 end
  541.         end
  542.  
  543.         if math.random() < genome.mutationRates["connections"] then
  544.                 pointMutate(genome)
  545.         end
  546.        
  547.         local p = genome.mutationRates["link"]
  548.         while p > 0 do
  549.                 if math.random() < p then
  550.                         linkMutate(genome, false)
  551.                 end
  552.                 p = p - 1
  553.         end
  554.  
  555.         p = genome.mutationRates["bias"]
  556.         while p > 0 do
  557.                 if math.random() < p then
  558.                         linkMutate(genome, true)
  559.                 end
  560.                 p = p - 1
  561.         end
  562.        
  563.         p = genome.mutationRates["node"]
  564.         while p > 0 do
  565.                 if math.random() < p then
  566.                         nodeMutate(genome)
  567.                 end
  568.                 p = p - 1
  569.         end
  570.        
  571.         p = genome.mutationRates["enable"]
  572.         while p > 0 do
  573.                 if math.random() < p then
  574.                         enableDisableMutate(genome, true)
  575.                 end
  576.                 p = p - 1
  577.         end
  578.  
  579.         p = genome.mutationRates["disable"]
  580.         while p > 0 do
  581.                 if math.random() < p then
  582.                         enableDisableMutate(genome, false)
  583.                 end
  584.                 p = p - 1
  585.         end
  586. end
  587.  
  588. function disjoint(genes1, genes2)
  589.         local i1 = {}
  590.         for i = 1,#genes1 do
  591.                 local gene = genes1[i]
  592.                 i1[gene.innovation] = true
  593.         end
  594.  
  595.         local i2 = {}
  596.         for i = 1,#genes2 do
  597.                 local gene = genes2[i]
  598.                 i2[gene.innovation] = true
  599.         end
  600.        
  601.         local disjointGenes = 0
  602.         for i = 1,#genes1 do
  603.                 local gene = genes1[i]
  604.                 if not i2[gene.innovation] then
  605.                         disjointGenes = disjointGenes+1
  606.                 end
  607.         end
  608.        
  609.         for i = 1,#genes2 do
  610.                 local gene = genes2[i]
  611.                 if not i1[gene.innovation] then
  612.                         disjointGenes = disjointGenes+1
  613.                 end
  614.         end
  615.        
  616.         local n = math.max(#genes1, #genes2)
  617.        
  618.         return disjointGenes / n
  619. end
  620.  
  621. function weights(genes1, genes2)
  622.         local i2 = {}
  623.         for i = 1,#genes2 do
  624.                 local gene = genes2[i]
  625.                 i2[gene.innovation] = gene
  626.         end
  627.  
  628.         local sum = 0
  629.         local coincident = 0
  630.         for i = 1,#genes1 do
  631.                 local gene = genes1[i]
  632.                 if i2[gene.innovation] ~= nil then
  633.                         local gene2 = i2[gene.innovation]
  634.                         sum = sum + math.abs(gene.weight - gene2.weight)
  635.                         coincident = coincident + 1
  636.                 end
  637.         end
  638.        
  639.         return sum / coincident
  640. end
  641.        
  642. function sameSpecies(genome1, genome2)
  643.         local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  644.         local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  645.         return dd + dw < DeltaThreshold
  646. end
  647.  
  648. function rankGlobally()
  649.         local global = {}
  650.         for s = 1,#pool.species do
  651.                 local species = pool.species[s]
  652.                 for g = 1,#species.genomes do
  653.                         table.insert(global, species.genomes[g])
  654.                 end
  655.         end
  656.         table.sort(global, function (a,b)
  657.                 return (a.fitness < b.fitness)
  658.         end)
  659.        
  660.         for g=1,#global do
  661.                 global[g].globalRank = g
  662.         end
  663. end
  664.  
  665. function calculateAverageFitness(species)
  666.         local total = 0
  667.        
  668.         for g=1,#species.genomes do
  669.                 local genome = species.genomes[g]
  670.                 total = total + genome.globalRank
  671.         end
  672.        
  673.         species.averageFitness = total / #species.genomes
  674. end
  675.  
  676. function totalAverageFitness()
  677.         local total = 0
  678.         for s = 1,#pool.species do
  679.                 local species = pool.species[s]
  680.                 total = total + species.averageFitness
  681.         end
  682.  
  683.         return total
  684. end
  685.  
  686. function cullSpecies(cutToOne)
  687.         for s = 1,#pool.species do
  688.                 local species = pool.species[s]
  689.                
  690.                 table.sort(species.genomes, function (a,b)
  691.                         return (a.fitness > b.fitness)
  692.                 end)
  693.                
  694.                 local remaining = math.ceil(#species.genomes/2)
  695.                 if cutToOne then
  696.                         remaining = 1
  697.                 end
  698.                 while #species.genomes > remaining do
  699.                         table.remove(species.genomes)
  700.                 end
  701.         end
  702. end
  703.  
  704. function breedChild(species)
  705.         local child = {}
  706.         if math.random() < CrossoverChance then
  707.                 g1 = species.genomes[math.random(1, #species.genomes)]
  708.                 g2 = species.genomes[math.random(1, #species.genomes)]
  709.                 child = crossover(g1, g2)
  710.         else
  711.                 g = species.genomes[math.random(1, #species.genomes)]
  712.                 child = copyGenome(g)
  713.         end
  714.        
  715.         mutate(child)
  716.        
  717.         return child
  718. end
  719.  
  720. function removeStaleSpecies()
  721.         local survived = {}
  722.  
  723.         for s = 1,#pool.species do
  724.                 local species = pool.species[s]
  725.                
  726.                 table.sort(species.genomes, function (a,b)
  727.                         return (a.fitness > b.fitness)
  728.                 end)
  729.                
  730.                 if species.genomes[1].fitness > species.topFitness then
  731.                         species.topFitness = species.genomes[1].fitness
  732.                         species.staleness = 0
  733.                 else
  734.                         species.staleness = species.staleness + 1
  735.                 end
  736.                 if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness then
  737.                         table.insert(survived, species)
  738.                 end
  739.         end
  740.  
  741.         pool.species = survived
  742. end
  743.  
  744. function removeWeakSpecies()
  745.         local survived = {}
  746.  
  747.         local sum = totalAverageFitness()
  748.         for s = 1,#pool.species do
  749.                 local species = pool.species[s]
  750.                 breed = math.floor(species.averageFitness / sum * Population)
  751.                 if breed >= 1 then
  752.                         table.insert(survived, species)
  753.                 end
  754.         end
  755.  
  756.         pool.species = survived
  757. end
  758.  
  759.  
  760. function addToSpecies(child)
  761.         local foundSpecies = false
  762.         for s=1,#pool.species do
  763.                 local species = pool.species[s]
  764.                 if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  765.                         table.insert(species.genomes, child)
  766.                         foundSpecies = true
  767.                 end
  768.         end
  769.        
  770.         if not foundSpecies then
  771.                 local childSpecies = newSpecies()
  772.                 table.insert(childSpecies.genomes, child)
  773.                 table.insert(pool.species, childSpecies)
  774.         end
  775. end
  776.  
  777. function newGeneration()
  778.         cullSpecies(false) -- Cull the bottom half of each species
  779.         rankGlobally()
  780.         removeStaleSpecies()
  781.         rankGlobally()
  782.         for s = 1,#pool.species do
  783.                 local species = pool.species[s]
  784.                 calculateAverageFitness(species)
  785.         end
  786.         removeWeakSpecies()
  787.         local sum = totalAverageFitness()
  788.         local children = {}
  789.         for s = 1,#pool.species do
  790.                 local species = pool.species[s]
  791.                 breed = math.floor(species.averageFitness / sum * Population) - 1
  792.                 for i=1,breed do
  793.                         table.insert(children, breedChild(species))
  794.                 end
  795.         end
  796.         cullSpecies(true) -- Cull all but the top member of each species
  797.         while #children + #pool.species < Population do
  798.                 local species = pool.species[math.random(1, #pool.species)]
  799.                 table.insert(children, breedChild(species))
  800.         end
  801.         for c=1,#children do
  802.                 local child = children[c]
  803.                 addToSpecies(child)
  804.         end
  805.        
  806.         pool.generation = pool.generation + 1
  807.        
  808.         writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  809. end
  810.        
  811. function initializePool()
  812.         pool = newPool()
  813.  
  814.         for i=1,Population do
  815.                 basic = basicGenome()
  816.                 addToSpecies(basic)
  817.         end
  818.  
  819.         initializeRun()
  820. end
  821.  
  822. function clearJoypad()
  823.         controller = {}
  824.         for b = 1,#ButtonNames do
  825.                 controller["P1 " .. ButtonNames[b]] = false
  826.         end
  827.         joypad.set(controller)
  828. end
  829.  
  830. function initializeRun()
  831.         savestate.load(Filename);
  832.         rightmost = 0
  833.         pool.currentFrame = 0
  834.         timeout = TimeoutConstant
  835.         clearJoypad()
  836.        
  837.         local species = pool.species[pool.currentSpecies]
  838.         local genome = species.genomes[pool.currentGenome]
  839.         generateNetwork(genome)
  840.         evaluateCurrent()
  841. end
  842.  
  843. function evaluateCurrent()
  844.         local species = pool.species[pool.currentSpecies]
  845.         local genome = species.genomes[pool.currentGenome]
  846.  
  847.         inputs = getInputs()
  848.         controller = evaluateNetwork(genome.network, inputs)
  849.        
  850.         if controller["P1 Left"] and controller["P1 Right"] then
  851.                 controller["P1 Left"] = false
  852.                 controller["P1 Right"] = false
  853.         end
  854.         if controller["P1 Up"] and controller["P1 Down"] then
  855.                 controller["P1 Up"] = false
  856.                 controller["P1 Down"] = false
  857.         end
  858.  
  859.         joypad.set(controller)
  860. end
  861.  
  862. if pool == nil then
  863.         initializePool()
  864. end
  865.  
  866.  
  867. function nextGenome()
  868.         pool.currentGenome = pool.currentGenome + 1
  869.         if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
  870.                 pool.currentGenome = 1
  871.                 pool.currentSpecies = pool.currentSpecies+1
  872.                 if pool.currentSpecies > #pool.species then
  873.                         newGeneration()
  874.                         pool.currentSpecies = 1
  875.                 end
  876.         end
  877. end
  878.  
  879. function fitnessAlreadyMeasured()
  880.         local species = pool.species[pool.currentSpecies]
  881.         local genome = species.genomes[pool.currentGenome]
  882.        
  883.         return genome.fitness ~= 0
  884. end
  885.  
  886. function displayGenome(genome)
  887.         local network = genome.network
  888.         local cells = {}
  889.         local i = 1
  890.         local cell = {}
  891.         for dy=-BoxRadius,BoxRadius do
  892.                 for dx=-BoxRadius,BoxRadius do
  893.                         cell = {}
  894.                         cell.x = 50+5*dx
  895.                         cell.y = 70+5*dy
  896.                         cell.value = network.neurons[i].value
  897.                         cells[i] = cell
  898.                         i = i + 1
  899.                 end
  900.         end
  901.         local biasCell = {}
  902.         biasCell.x = 80
  903.         biasCell.y = 110
  904.         biasCell.value = network.neurons[Inputs].value
  905.         cells[Inputs] = biasCell
  906.        
  907.         for o = 1,Outputs do
  908.                 cell = {}
  909.                 cell.x = 220
  910.                 cell.y = 30 + 8 * o
  911.                 cell.value = network.neurons[MaxNodes + o].value
  912.                 cells[MaxNodes+o] = cell
  913.                 local color
  914.                 if cell.value > 0 then
  915.                         color = 0xFF0000FF
  916.                 else
  917.                         color = 0xFF000000
  918.                 end
  919.                 gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
  920.         end
  921.        
  922.         for n,neuron in pairs(network.neurons) do
  923.                 cell = {}
  924.                 if n > Inputs and n <= MaxNodes then
  925.                         cell.x = 140
  926.                         cell.y = 40
  927.                         cell.value = neuron.value
  928.                         cells[n] = cell
  929.                 end
  930.         end
  931.        
  932.         for n=1,4 do
  933.                 for _,gene in pairs(genome.genes) do
  934.                         if gene.enabled then
  935.                                 local c1 = cells[gene.into]
  936.                                 local c2 = cells[gene.out]
  937.                                 if gene.into > Inputs and gene.into <= MaxNodes then
  938.                                         c1.x = 0.75*c1.x + 0.25*c2.x
  939.                                         if c1.x >= c2.x then
  940.                                                 c1.x = c1.x - 40
  941.                                         end
  942.                                         if c1.x < 90 then
  943.                                                 c1.x = 90
  944.                                         end
  945.                                        
  946.                                         if c1.x > 220 then
  947.                                                 c1.x = 220
  948.                                         end
  949.                                         c1.y = 0.75*c1.y + 0.25*c2.y
  950.                                        
  951.                                 end
  952.                                 if gene.out > Inputs and gene.out <= MaxNodes then
  953.                                         c2.x = 0.25*c1.x + 0.75*c2.x
  954.                                         if c1.x >= c2.x then
  955.                                                 c2.x = c2.x + 40
  956.                                         end
  957.                                         if c2.x < 90 then
  958.                                                 c2.x = 90
  959.                                         end
  960.                                         if c2.x > 220 then
  961.                                                 c2.x = 220
  962.                                         end
  963.                                         c2.y = 0.25*c1.y + 0.75*c2.y
  964.                                 end
  965.                         end
  966.                 end
  967.         end
  968.        
  969.         gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
  970.         for n,cell in pairs(cells) do
  971.                 if n > Inputs or cell.value ~= 0 then
  972.                         local color = math.floor((cell.value+1)/2*256)
  973.                         if color > 255 then color = 255 end
  974.                         if color < 0 then color = 0 end
  975.                         local opacity = 0xFF000000
  976.                         if cell.value == 0 then
  977.                                 opacity = 0x50000000
  978.                         end
  979.                         color = opacity + color*0x10000 + color*0x100 + color
  980.                         gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color)
  981.                 end
  982.         end
  983.         for _,gene in pairs(genome.genes) do
  984.                 if gene.enabled then
  985.                         local c1 = cells[gene.into]
  986.                         local c2 = cells[gene.out]
  987.                         local opacity = 0xA0000000
  988.                         if c1.value == 0 then
  989.                                 opacity = 0x20000000
  990.                         end
  991.                        
  992.                         local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
  993.                         if gene.weight > 0 then
  994.                                 color = opacity + 0x8000 + 0x10000*color
  995.                         else
  996.                                 color = opacity + 0x800000 + 0x100*color
  997.                         end
  998.                         gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
  999.                 end
  1000.         end
  1001.        
  1002.         gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)
  1003.        
  1004.         if forms.ischecked(showMutationRates) then
  1005.                 local pos = 100
  1006.                 for mutation,rate in pairs(genome.mutationRates) do
  1007.                         gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
  1008.                         pos = pos + 8
  1009.                 end
  1010.         end
  1011. end
  1012.  
  1013. function writeFile(filename)
  1014.         local file = io.open(filename, "w")
  1015.         file:write(pool.generation .. "\n")
  1016.         file:write(pool.maxFitness .. "\n")
  1017.         file:write(#pool.species .. "\n")
  1018.         for n,species in pairs(pool.species) do
  1019.                 file:write(species.topFitness .. "\n")
  1020.                 file:write(species.staleness .. "\n")
  1021.                 file:write(#species.genomes .. "\n")
  1022.                 for m,genome in pairs(species.genomes) do
  1023.                         file:write(genome.fitness .. "\n")
  1024.                         file:write(genome.maxneuron .. "\n")
  1025.                         for mutation,rate in pairs(genome.mutationRates) do
  1026.                                 file:write(mutation .. "\n")
  1027.                                 file:write(rate .. "\n")
  1028.                         end
  1029.                         file:write("done\n")
  1030.                        
  1031.                         file:write(#genome.genes .. "\n")
  1032.                         for l,gene in pairs(genome.genes) do
  1033.                                 file:write(gene.into .. " ")
  1034.                                 file:write(gene.out .. " ")
  1035.                                 file:write(gene.weight .. " ")
  1036.                                 file:write(gene.innovation .. " ")
  1037.                                 if(gene.enabled) then
  1038.                                         file:write("1\n")
  1039.                                 else
  1040.                                         file:write("0\n")
  1041.                                 end
  1042.                         end
  1043.                 end
  1044.         end
  1045.         file:close()
  1046. end
  1047.  
  1048. function savePool()
  1049.         local filename = forms.gettext(saveLoadFile)
  1050.         writeFile(filename)
  1051. end
  1052.  
  1053. function loadFile(filename)
  1054.         local file = io.open(filename, "r")
  1055.         pool = newPool()
  1056.         pool.generation = file:read("*number")
  1057.         pool.maxFitness = file:read("*number")
  1058.         forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1059.         local numSpecies = file:read("*number")
  1060.         for s=1,numSpecies do
  1061.                 local species = newSpecies()
  1062.                 table.insert(pool.species, species)
  1063.                 species.topFitness = file:read("*number")
  1064.                 species.staleness = file:read("*number")
  1065.                 local numGenomes = file:read("*number")
  1066.                 for g=1,numGenomes do
  1067.                         local genome = newGenome()
  1068.                         table.insert(species.genomes, genome)
  1069.                         genome.fitness = file:read("*number")
  1070.                         genome.maxneuron = file:read("*number")
  1071.                         local line = file:read("*line")
  1072.                         while line ~= "done" do
  1073.                                 genome.mutationRates[line] = file:read("*number")
  1074.                                 line = file:read("*line")
  1075.                         end
  1076.                         local numGenes = file:read("*number")
  1077.                         for n=1,numGenes do
  1078.                                 local gene = newGene()
  1079.                                 table.insert(genome.genes, gene)
  1080.                                 local enabled
  1081.                                 gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
  1082.                                 if enabled == 0 then
  1083.                                         gene.enabled = false
  1084.                                 else
  1085.                                         gene.enabled = true
  1086.                                 end
  1087.                                
  1088.                         end
  1089.                 end
  1090.         end
  1091.         file:close()
  1092.        
  1093.         while fitnessAlreadyMeasured() do
  1094.                 nextGenome()
  1095.         end
  1096.         initializeRun()
  1097.         pool.currentFrame = pool.currentFrame + 1
  1098. end
  1099.  
  1100. function loadPool()
  1101.         local filename = forms.gettext(saveLoadFile)
  1102.         loadFile(filename)
  1103. end
  1104.  
  1105. function playTop()
  1106.         local maxfitness = 0
  1107.         local maxs, maxg
  1108.         for s,species in pairs(pool.species) do
  1109.                 for g,genome in pairs(species.genomes) do
  1110.                         if genome.fitness > maxfitness then
  1111.                                 maxfitness = genome.fitness
  1112.                                 maxs = s
  1113.                                 maxg = g
  1114.                         end
  1115.                 end
  1116.         end
  1117.        
  1118.         pool.currentSpecies = maxs
  1119.         pool.currentGenome = maxg
  1120.         pool.maxFitness = maxfitness
  1121.         forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1122.         initializeRun()
  1123.         pool.currentFrame = pool.currentFrame + 1
  1124.         return
  1125. end
  1126.  
  1127. function onExit()
  1128.         forms.destroy(form)
  1129. end
  1130.  
  1131. writeFile("temp.pool")
  1132.  
  1133. event.onexit(onExit)
  1134.  
  1135. form = forms.newform(200, 260, "Fitness")
  1136. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  1137. showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  1138. showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  1139. restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  1140. saveButton = forms.button(form, "Save", savePool, 5, 102)
  1141. loadButton = forms.button(form, "Load", loadPool, 80, 102)
  1142. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  1143. saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  1144. playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  1145. hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  1146.  
  1147.  
  1148. while true do
  1149.         local backgroundColor = 0xD0FFFFFF
  1150.         if not forms.ischecked(hideBanner) then
  1151.                 gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
  1152.         end
  1153.  
  1154.         local species = pool.species[pool.currentSpecies]
  1155.         local genome = species.genomes[pool.currentGenome]
  1156.        
  1157.         if forms.ischecked(showNetwork) then
  1158.                 displayGenome(genome)
  1159.         end
  1160.        
  1161.         if pool.currentFrame%5 == 0 then
  1162.                 evaluateCurrent()
  1163.         end
  1164.  
  1165.         joypad.set(controller)
  1166.  
  1167.         getPositions()
  1168.         if marioX > rightmost then
  1169.                 rightmost = marioX
  1170.                 timeout = TimeoutConstant
  1171.         end
  1172.        
  1173.         timeout = timeout - 1
  1174.        
  1175.        
  1176.         local timeoutBonus = pool.currentFrame / 4
  1177.         if timeout + timeoutBonus <= 0 then
  1178.                 local fitness = rightmost - pool.currentFrame / 2
  1179.                 if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 then
  1180.                         fitness = fitness + 1000
  1181.                 end
  1182.                 if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
  1183.                         fitness = fitness + 1000
  1184.                 end
  1185.                 if fitness == 0 then
  1186.                         fitness = -1
  1187.                 end
  1188.                 genome.fitness = fitness
  1189.                
  1190.                 if fitness > pool.maxFitness then
  1191.                         pool.maxFitness = fitness
  1192.                         forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  1193.                         writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  1194.                 end
  1195.                
  1196.                 console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  1197.                 pool.currentSpecies = 1
  1198.                 pool.currentGenome = 1
  1199.                 while fitnessAlreadyMeasured() do
  1200.                         nextGenome()
  1201.                 end
  1202.                 initializeRun()
  1203.         end
  1204.  
  1205.         local measured = 0
  1206.         local total = 0
  1207.         for _,species in pairs(pool.species) do
  1208.                 for _,genome in pairs(species.genomes) do
  1209.                         total = total + 1
  1210.                         if genome.fitness ~= 0 then
  1211.                                 measured = measured + 1
  1212.                         end
  1213.                 end
  1214.         end
  1215.         if not forms.ischecked(hideBanner) then
  1216.                 gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 11)
  1217.                 gui.drawText(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000, 11)
  1218.                 gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11)
  1219.         end
  1220.                
  1221.         pool.currentFrame = pool.currentFrame + 1
  1222.  
  1223.         emu.frameadvance();
  1224. end
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top