Advertisement
Guest User

Untitled

a guest
Jul 16th, 2018
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 33.74 KB | None | 0 0
  1. if gameinfo.getromname() == "Super Mario World (USA)" then
  2.     Filename = "DP1.state"
  3.     ButtonNames = {
  4.         "A",
  5.         "B",
  6.         "X",
  7.         "Y",
  8.         "Up",
  9.         "Down",
  10.         "Left",
  11.         "Right",
  12.     }
  13. elseif gameinfo.getromname() == "Super Mario Bros." then
  14.     Filename = "SMB1-1.state"
  15.     ButtonNames = {
  16.         "A",
  17.         "B",
  18.         "Up",
  19.         "Down",
  20.         "Left",
  21.         "Right",
  22.     }
  23. end
  24.  
  25. BoxRadius = 6
  26. InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  27.  
  28. Inputs = InputSize+1
  29. Outputs = #ButtonNames
  30.  
  31. Population = 300
  32. DeltaDisjoint = 2.0
  33. DeltaWeights = 0.4
  34. DeltaThreshold = 1.0
  35.  
  36. StaleSpecies = 15
  37.  
  38. MutateConnectionsChance = 0.25
  39. PerturbChance = 0.90
  40. CrossoverChance = 0.75
  41. LinkMutationChance = 2.0
  42. NodeMutationChance = 0.50
  43. BiasMutationChance = 0.40
  44. StepSize = 0.1
  45. DisableMutationChance = 0.4
  46. EnableMutationChance = 0.2
  47.  
  48. TimeoutConstant = 20
  49.  
  50. MaxNodes = 1000000
  51.  
  52. function getPositions()
  53.     if gameinfo.getromname() == "Super Mario World (USA)" then
  54.         marioX = memory.read_s16_le(0x94)
  55.         marioY = memory.read_s16_le(0x96)
  56.        
  57.         local layer1x = memory.read_s16_le(0x1A);
  58.         local layer1y = memory.read_s16_le(0x1C);
  59.        
  60.         screenX = marioX-layer1x
  61.         screenY = marioY-layer1y
  62.     elseif gameinfo.getromname() == "Super Mario Bros." then
  63.         marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
  64.         marioY = memory.readbyte(0x03B8)+16
  65.    
  66.         screenX = memory.readbyte(0x03AD)
  67.         screenY = memory.readbyte(0x03B8)
  68.     end
  69. end
  70.  
  71. function getTile(dx, dy)
  72.     if gameinfo.getromname() == "Super Mario World (USA)" then
  73.         x = math.floor((marioX+dx+8)/16)
  74.         y = math.floor((marioY+dy)/16)
  75.        
  76.         return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  77.     elseif gameinfo.getromname() == "Super Mario Bros." then
  78.         local x = marioX + dx + 8
  79.         local y = marioY + dy - 16
  80.         local page = math.floor(x/256)%2
  81.  
  82.         local subx = math.floor((x%256)/16)
  83.         local suby = math.floor((y - 32)/16)
  84.         local addr = 0x500 + page*13*16+suby*16+subx
  85.        
  86.         if suby >= 13 or suby < 0 then
  87.             return 0
  88.         end
  89.        
  90.         if memory.readbyte(addr) ~= 0 then
  91.             return 1
  92.         else
  93.             return 0
  94.         end
  95.     end
  96. end
  97.  
  98. function getSprites()
  99.     if gameinfo.getromname() == "Super Mario World (USA)" then
  100.         local sprites = {}
  101.         for slot=0,11 do
  102.             local status = memory.readbyte(0x14C8+slot)
  103.             if status ~= 0 then
  104.                 spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
  105.                 spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
  106.                 sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
  107.             end
  108.         end    
  109.        
  110.         return sprites
  111.     elseif gameinfo.getromname() == "Super Mario Bros." then
  112.         local sprites = {}
  113.         for slot=0,4 do
  114.             local enemy = memory.readbyte(0xF+slot)
  115.             if enemy ~= 0 then
  116.                 local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
  117.                 local ey = memory.readbyte(0xCF + slot)+24
  118.                 sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
  119.             end
  120.         end
  121.        
  122.         return sprites
  123.     end
  124. end
  125.  
  126. function getExtendedSprites()
  127.     if gameinfo.getromname() == "Super Mario World (USA)" then
  128.         local extended = {}
  129.         for slot=0,11 do
  130.             local number = memory.readbyte(0x170B+slot)
  131.             if number ~= 0 then
  132.                 spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
  133.                 spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
  134.                 extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
  135.             end
  136.         end    
  137.        
  138.         return extended
  139.     elseif gameinfo.getromname() == "Super Mario Bros." then
  140.         return {}
  141.     end
  142. end
  143.  
  144. function getInputs()
  145.     getPositions()
  146.    
  147.     sprites = getSprites()
  148.     extended = getExtendedSprites()
  149.    
  150.     local inputs = {}
  151.    
  152.     for dy=-BoxRadius*16,BoxRadius*16,16 do
  153.         for dx=-BoxRadius*16,BoxRadius*16,16 do
  154.             inputs[#inputs+1] = 0
  155.            
  156.             tile = getTile(dx, dy)
  157.             if tile == 1 and marioY+dy < 0x1B0 then
  158.                 inputs[#inputs] = 1
  159.             end
  160.            
  161.             for i = 1,#sprites do
  162.                 distx = math.abs(sprites[i]["x"] - (marioX+dx))
  163.                 disty = math.abs(sprites[i]["y"] - (marioY+dy))
  164.                 if distx <= 8 and disty <= 8 then
  165.                     inputs[#inputs] = -1
  166.                 end
  167.             end
  168.  
  169.             for i = 1,#extended do
  170.                 distx = math.abs(extended[i]["x"] - (marioX+dx))
  171.                 disty = math.abs(extended[i]["y"] - (marioY+dy))
  172.                 if distx < 8 and disty < 8 then
  173.                     inputs[#inputs] = -1
  174.                 end
  175.             end
  176.         end
  177.     end
  178.    
  179.     --mariovx = memory.read_s8(0x7B)
  180.     --mariovy = memory.read_s8(0x7D)
  181.    
  182.     return inputs
  183. end
  184.  
  185. function sigmoid(x)
  186.     return 2/(1+math.exp(-4.9*x))-1
  187. end
  188.  
  189. function newInnovation()
  190.     pool.innovation = pool.innovation + 1
  191.     return pool.innovation
  192. end
  193.  
  194. function newPool()
  195.     local pool = {}
  196.     pool.species = {}
  197.     pool.generation = 0
  198.     pool.innovation = Outputs
  199.     pool.currentSpecies = 1
  200.     pool.currentGenome = 1
  201.     pool.currentFrame = 0
  202.     pool.maxFitness = 0
  203.    
  204.     return pool
  205. end
  206.  
  207. function newSpecies()
  208.     local species = {}
  209.     species.topFitness = 0
  210.     species.staleness = 0
  211.     species.genomes = {}
  212.     species.averageFitness = 0
  213.    
  214.     return species
  215. end
  216.  
  217. function newGenome()
  218.     local genome = {}
  219.     genome.genes = {}
  220.     genome.fitness = 0
  221.     genome.adjustedFitness = 0
  222.     genome.network = {}
  223.     genome.maxneuron = 0
  224.     genome.globalRank = 0
  225.     genome.mutationRates = {}
  226.     genome.mutationRates["connections"] = MutateConnectionsChance
  227.     genome.mutationRates["link"] = LinkMutationChance
  228.     genome.mutationRates["bias"] = BiasMutationChance
  229.     genome.mutationRates["node"] = NodeMutationChance
  230.     genome.mutationRates["enable"] = EnableMutationChance
  231.     genome.mutationRates["disable"] = DisableMutationChance
  232.     genome.mutationRates["step"] = StepSize
  233.    
  234.     return genome
  235. end
  236.  
  237. function copyGenome(genome)
  238.     local genome2 = newGenome()
  239.     for g=1,#genome.genes do
  240.         table.insert(genome2.genes, copyGene(genome.genes[g]))
  241.     end
  242.     genome2.maxneuron = genome.maxneuron
  243.     genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  244.     genome2.mutationRates["link"] = genome.mutationRates["link"]
  245.     genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  246.     genome2.mutationRates["node"] = genome.mutationRates["node"]
  247.     genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  248.     genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  249.    
  250.     return genome2
  251. end
  252.  
  253. function basicGenome()
  254.     local genome = newGenome()
  255.     local innovation = 1
  256.  
  257.     genome.maxneuron = Inputs
  258.     mutate(genome)
  259.    
  260.     return genome
  261. end
  262.  
  263. function newGene()
  264.     local gene = {}
  265.     gene.into = 0
  266.     gene.out = 0
  267.     gene.weight = 0.0
  268.     gene.enabled = true
  269.     gene.innovation = 0
  270.    
  271.     return gene
  272. end
  273.  
  274. function copyGene(gene)
  275.     local gene2 = newGene()
  276.     gene2.into = gene.into
  277.     gene2.out = gene.out
  278.     gene2.weight = gene.weight
  279.     gene2.enabled = gene.enabled
  280.     gene2.innovation = gene.innovation
  281.    
  282.     return gene2
  283. end
  284.  
  285. function newNeuron()
  286.     local neuron = {}
  287.     neuron.incoming = {}
  288.     neuron.value = 0.0
  289.    
  290.     return neuron
  291. end
  292.  
  293. function generateNetwork(genome)
  294.     local network = {}
  295.     network.neurons = {}
  296.    
  297.     for i=1,Inputs do
  298.         network.neurons[i] = newNeuron()
  299.     end
  300.    
  301.     for o=1,Outputs do
  302.         network.neurons[MaxNodes+o] = newNeuron()
  303.     end
  304.    
  305.     table.sort(genome.genes, function (a,b)
  306.         return (a.out < b.out)
  307.     end)
  308.     for i=1,#genome.genes do
  309.         local gene = genome.genes[i]
  310.         if gene.enabled then
  311.             if network.neurons[gene.out] == nil then
  312.                 network.neurons[gene.out] = newNeuron()
  313.             end
  314.             local neuron = network.neurons[gene.out]
  315.             table.insert(neuron.incoming, gene)
  316.             if network.neurons[gene.into] == nil then
  317.                 network.neurons[gene.into] = newNeuron()
  318.             end
  319.         end
  320.     end
  321.    
  322.     genome.network = network
  323. end
  324.  
  325. function evaluateNetwork(network, inputs)
  326.     table.insert(inputs, 1)
  327.     if #inputs ~= Inputs then
  328.         console.writeline("Incorrect number of neural network inputs.")
  329.         return {}
  330.     end
  331.    
  332.     for i=1,Inputs do
  333.         network.neurons[i].value = inputs[i]
  334.     end
  335.    
  336.     for _,neuron in pairs(network.neurons) do
  337.         local sum = 0
  338.         for j = 1,#neuron.incoming do
  339.             local incoming = neuron.incoming[j]
  340.             local other = network.neurons[incoming.into]
  341.             sum = sum + incoming.weight * other.value
  342.         end
  343.        
  344.         if #neuron.incoming > 0 then
  345.             neuron.value = sigmoid(sum)
  346.         end
  347.     end
  348.    
  349.     local outputs = {}
  350.     for o=1,Outputs do
  351.         local button = "P1 " .. ButtonNames[o]
  352.         if network.neurons[MaxNodes+o].value > 0 then
  353.             outputs[button] = true
  354.         else
  355.             outputs[button] = false
  356.         end
  357.     end
  358.    
  359.     return outputs
  360. end
  361.  
  362. function crossover(g1, g2)
  363.     -- Make sure g1 is the higher fitness genome
  364.     if g2.fitness > g1.fitness then
  365.         tempg = g1
  366.         g1 = g2
  367.         g2 = tempg
  368.     end
  369.  
  370.     local child = newGenome()
  371.    
  372.     local innovations2 = {}
  373.     for i=1,#g2.genes do
  374.         local gene = g2.genes[i]
  375.         innovations2[gene.innovation] = gene
  376.     end
  377.    
  378.     for i=1,#g1.genes do
  379.         local gene1 = g1.genes[i]
  380.         local gene2 = innovations2[gene1.innovation]
  381.         if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  382.             table.insert(child.genes, copyGene(gene2))
  383.         else
  384.             table.insert(child.genes, copyGene(gene1))
  385.         end
  386.     end
  387.    
  388.     child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  389.    
  390.     for mutation,rate in pairs(g1.mutationRates) do
  391.         child.mutationRates[mutation] = rate
  392.     end
  393.    
  394.     return child
  395. end
  396.  
  397. function randomNeuron(genes, nonInput)
  398.     local neurons = {}
  399.     if not nonInput then
  400.         for i=1,Inputs do
  401.             neurons[i] = true
  402.         end
  403.     end
  404.     for o=1,Outputs do
  405.         neurons[MaxNodes+o] = true
  406.     end
  407.     for i=1,#genes do
  408.         if (not nonInput) or genes[i].into > Inputs then
  409.             neurons[genes[i].into] = true
  410.         end
  411.         if (not nonInput) or genes[i].out > Inputs then
  412.             neurons[genes[i].out] = true
  413.         end
  414.     end
  415.  
  416.     local count = 0
  417.     for _,_ in pairs(neurons) do
  418.         count = count + 1
  419.     end
  420.     local n = math.random(1, count)
  421.    
  422.     for k,v in pairs(neurons) do
  423.         n = n-1
  424.         if n == 0 then
  425.             return k
  426.         end
  427.     end
  428.    
  429.     return 0
  430. end
  431.  
  432. function containsLink(genes, link)
  433.     for i=1,#genes do
  434.         local gene = genes[i]
  435.         if gene.into == link.into and gene.out == link.out then
  436.             return true
  437.         end
  438.     end
  439. end
  440.  
  441. function pointMutate(genome)
  442.     local step = genome.mutationRates["step"]
  443.    
  444.     for i=1,#genome.genes do
  445.         local gene = genome.genes[i]
  446.         if math.random() < PerturbChance then
  447.             gene.weight = gene.weight + math.random() * step*2 - step
  448.         else
  449.             gene.weight = math.random()*4-2
  450.         end
  451.     end
  452. end
  453.  
  454. function linkMutate(genome, forceBias)
  455.     local neuron1 = randomNeuron(genome.genes, false)
  456.     local neuron2 = randomNeuron(genome.genes, true)
  457.      
  458.     local newLink = newGene()
  459.     if neuron1 <= Inputs and neuron2 <= Inputs then
  460.         --Both input nodes
  461.         return
  462.     end
  463.     if neuron2 <= Inputs then
  464.         -- Swap output and input
  465.         local temp = neuron1
  466.         neuron1 = neuron2
  467.         neuron2 = temp
  468.     end
  469.  
  470.     newLink.into = neuron1
  471.     newLink.out = neuron2
  472.     if forceBias then
  473.         newLink.into = Inputs
  474.     end
  475.    
  476.     if containsLink(genome.genes, newLink) then
  477.         return
  478.     end
  479.     newLink.innovation = newInnovation()
  480.     newLink.weight = math.random()*4-2
  481.    
  482.     table.insert(genome.genes, newLink)
  483. end
  484.  
  485. function nodeMutate(genome)
  486.     if #genome.genes == 0 then
  487.         return
  488.     end
  489.  
  490.     genome.maxneuron = genome.maxneuron + 1
  491.  
  492.     local gene = genome.genes[math.random(1,#genome.genes)]
  493.     if not gene.enabled then
  494.         return
  495.     end
  496.     gene.enabled = false
  497.    
  498.     local gene1 = copyGene(gene)
  499.     gene1.out = genome.maxneuron
  500.     gene1.weight = 1.0
  501.     gene1.innovation = newInnovation()
  502.     gene1.enabled = true
  503.     table.insert(genome.genes, gene1)
  504.    
  505.     local gene2 = copyGene(gene)
  506.     gene2.into = genome.maxneuron
  507.     gene2.innovation = newInnovation()
  508.     gene2.enabled = true
  509.     table.insert(genome.genes, gene2)
  510. end
  511.  
  512. function enableDisableMutate(genome, enable)
  513.     local candidates = {}
  514.     for _,gene in pairs(genome.genes) do
  515.         if gene.enabled == not enable then
  516.             table.insert(candidates, gene)
  517.         end
  518.     end
  519.    
  520.     if #candidates == 0 then
  521.         return
  522.     end
  523.    
  524.     local gene = candidates[math.random(1,#candidates)]
  525.     gene.enabled = not gene.enabled
  526. end
  527.  
  528. function mutate(genome)
  529.     for mutation,rate in pairs(genome.mutationRates) do
  530.         if math.random(1,2) == 1 then
  531.             genome.mutationRates[mutation] = 0.95*rate
  532.         else
  533.             genome.mutationRates[mutation] = 1.05263*rate
  534.         end
  535.     end
  536.  
  537.     if math.random() < genome.mutationRates["connections"] then
  538.         pointMutate(genome)
  539.     end
  540.    
  541.     local p = genome.mutationRates["link"]
  542.     while p > 0 do
  543.         if math.random() < p then
  544.             linkMutate(genome, false)
  545.         end
  546.         p = p - 1
  547.     end
  548.  
  549.     p = genome.mutationRates["bias"]
  550.     while p > 0 do
  551.         if math.random() < p then
  552.             linkMutate(genome, true)
  553.         end
  554.         p = p - 1
  555.     end
  556.    
  557.     p = genome.mutationRates["node"]
  558.     while p > 0 do
  559.         if math.random() < p then
  560.             nodeMutate(genome)
  561.         end
  562.         p = p - 1
  563.     end
  564.    
  565.     p = genome.mutationRates["enable"]
  566.     while p > 0 do
  567.         if math.random() < p then
  568.             enableDisableMutate(genome, true)
  569.         end
  570.         p = p - 1
  571.     end
  572.  
  573.     p = genome.mutationRates["disable"]
  574.     while p > 0 do
  575.         if math.random() < p then
  576.             enableDisableMutate(genome, false)
  577.         end
  578.         p = p - 1
  579.     end
  580. end
  581.  
  582. function disjoint(genes1, genes2)
  583.     local i1 = {}
  584.     for i = 1,#genes1 do
  585.         local gene = genes1[i]
  586.         i1[gene.innovation] = true
  587.     end
  588.  
  589.     local i2 = {}
  590.     for i = 1,#genes2 do
  591.         local gene = genes2[i]
  592.         i2[gene.innovation] = true
  593.     end
  594.    
  595.     local disjointGenes = 0
  596.     for i = 1,#genes1 do
  597.         local gene = genes1[i]
  598.         if not i2[gene.innovation] then
  599.             disjointGenes = disjointGenes+1
  600.         end
  601.     end
  602.    
  603.     for i = 1,#genes2 do
  604.         local gene = genes2[i]
  605.         if not i1[gene.innovation] then
  606.             disjointGenes = disjointGenes+1
  607.         end
  608.     end
  609.    
  610.     local n = math.max(#genes1, #genes2)
  611.    
  612.     return disjointGenes / n
  613. end
  614.  
  615. function weights(genes1, genes2)
  616.     local i2 = {}
  617.     for i = 1,#genes2 do
  618.         local gene = genes2[i]
  619.         i2[gene.innovation] = gene
  620.     end
  621.  
  622.     local sum = 0
  623.     local coincident = 0
  624.     for i = 1,#genes1 do
  625.         local gene = genes1[i]
  626.         if i2[gene.innovation] ~= nil then
  627.             local gene2 = i2[gene.innovation]
  628.             sum = sum + math.abs(gene.weight - gene2.weight)
  629.             coincident = coincident + 1
  630.         end
  631.     end
  632.    
  633.     return sum / coincident
  634. end
  635.    
  636. function sameSpecies(genome1, genome2)
  637.     local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  638.     local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  639.     return dd + dw < DeltaThreshold
  640. end
  641.  
  642. function rankGlobally()
  643.     local global = {}
  644.     for s = 1,#pool.species do
  645.         local species = pool.species[s]
  646.         for g = 1,#species.genomes do
  647.             table.insert(global, species.genomes[g])
  648.         end
  649.     end
  650.     table.sort(global, function (a,b)
  651.         return (a.fitness < b.fitness)
  652.     end)
  653.    
  654.     for g=1,#global do
  655.         global[g].globalRank = g
  656.     end
  657. end
  658.  
  659. function calculateAverageFitness(species)
  660.     local total = 0
  661.    
  662.     for g=1,#species.genomes do
  663.         local genome = species.genomes[g]
  664.         total = total + genome.globalRank
  665.     end
  666.    
  667.     species.averageFitness = total / #species.genomes
  668. end
  669.  
  670. function totalAverageFitness()
  671.     local total = 0
  672.     for s = 1,#pool.species do
  673.         local species = pool.species[s]
  674.         total = total + species.averageFitness
  675.     end
  676.  
  677.     return total
  678. end
  679.  
  680. function cullSpecies(cutToOne)
  681.     for s = 1,#pool.species do
  682.         local species = pool.species[s]
  683.        
  684.         table.sort(species.genomes, function (a,b)
  685.             return (a.fitness > b.fitness)
  686.         end)
  687.        
  688.         local remaining = math.ceil(#species.genomes/2)
  689.         if cutToOne then
  690.             remaining = 1
  691.         end
  692.         while #species.genomes > remaining do
  693.             table.remove(species.genomes)
  694.         end
  695.     end
  696. end
  697.  
  698. function breedChild(species)
  699.     local child = {}
  700.     if math.random() < CrossoverChance then
  701.         g1 = species.genomes[math.random(1, #species.genomes)]
  702.         g2 = species.genomes[math.random(1, #species.genomes)]
  703.         child = crossover(g1, g2)
  704.     else
  705.         g = species.genomes[math.random(1, #species.genomes)]
  706.         child = copyGenome(g)
  707.     end
  708.    
  709.     mutate(child)
  710.    
  711.     return child
  712. end
  713.  
  714. function removeStaleSpecies()
  715.     local survived = {}
  716.  
  717.     for s = 1,#pool.species do
  718.         local species = pool.species[s]
  719.        
  720.         table.sort(species.genomes, function (a,b)
  721.             return (a.fitness > b.fitness)
  722.         end)
  723.        
  724.         if species.genomes[1].fitness > species.topFitness then
  725.             species.topFitness = species.genomes[1].fitness
  726.             species.staleness = 0
  727.         else
  728.             species.staleness = species.staleness + 1
  729.         end
  730.         if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness then
  731.             table.insert(survived, species)
  732.         end
  733.     end
  734.  
  735.     pool.species = survived
  736. end
  737.  
  738. function removeWeakSpecies()
  739.     local survived = {}
  740.  
  741.     local sum = totalAverageFitness()
  742.     for s = 1,#pool.species do
  743.         local species = pool.species[s]
  744.         breed = math.floor(species.averageFitness / sum * Population)
  745.         if breed >= 1 then
  746.             table.insert(survived, species)
  747.         end
  748.     end
  749.  
  750.     pool.species = survived
  751. end
  752.  
  753.  
  754. function addToSpecies(child)
  755.     local foundSpecies = false
  756.     for s=1,#pool.species do
  757.         local species = pool.species[s]
  758.         if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  759.             table.insert(species.genomes, child)
  760.             foundSpecies = true
  761.         end
  762.     end
  763.    
  764.     if not foundSpecies then
  765.         local childSpecies = newSpecies()
  766.         table.insert(childSpecies.genomes, child)
  767.         table.insert(pool.species, childSpecies)
  768.     end
  769. end
  770.  
  771. function newGeneration()
  772.     cullSpecies(false) -- Cull the bottom half of each species
  773.     rankGlobally()
  774.     removeStaleSpecies()
  775.     rankGlobally()
  776.     for s = 1,#pool.species do
  777.         local species = pool.species[s]
  778.         calculateAverageFitness(species)
  779.     end
  780.     removeWeakSpecies()
  781.     local sum = totalAverageFitness()
  782.     local children = {}
  783.     for s = 1,#pool.species do
  784.         local species = pool.species[s]
  785.         breed = math.floor(species.averageFitness / sum * Population) - 1
  786.         for i=1,breed do
  787.             table.insert(children, breedChild(species))
  788.         end
  789.     end
  790.     cullSpecies(true) -- Cull all but the top member of each species
  791.     while #children + #pool.species < Population do
  792.         local species = pool.species[math.random(1, #pool.species)]
  793.         table.insert(children, breedChild(species))
  794.     end
  795.     for c=1,#children do
  796.         local child = children[c]
  797.         addToSpecies(child)
  798.     end
  799.    
  800.     pool.generation = pool.generation + 1
  801.    
  802.     writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  803. end
  804.    
  805. function initializePool()
  806.     pool = newPool()
  807.  
  808.     for i=1,Population do
  809.         basic = basicGenome()
  810.         addToSpecies(basic)
  811.     end
  812.  
  813.     initializeRun()
  814. end
  815.  
  816. function clearJoypad()
  817.     controller = {}
  818.     for b = 1,#ButtonNames do
  819.         controller["P1 " .. ButtonNames[b]] = false
  820.     end
  821.     joypad.set(controller)
  822. end
  823.  
  824. function initializeRun()
  825.     savestate.load(Filename);
  826.     rightmost = 0
  827.     pool.currentFrame = 0
  828.     timeout = TimeoutConstant
  829.     clearJoypad()
  830.    
  831.     local species = pool.species[pool.currentSpecies]
  832.     local genome = species.genomes[pool.currentGenome]
  833.     generateNetwork(genome)
  834.     evaluateCurrent()
  835. end
  836.  
  837. function evaluateCurrent()
  838.     local species = pool.species[pool.currentSpecies]
  839.     local genome = species.genomes[pool.currentGenome]
  840.  
  841.     inputs = getInputs()
  842.     controller = evaluateNetwork(genome.network, inputs)
  843.    
  844.     if controller["P1 Left"] and controller["P1 Right"] then
  845.         controller["P1 Left"] = false
  846.         controller["P1 Right"] = false
  847.     end
  848.     if controller["P1 Up"] and controller["P1 Down"] then
  849.         controller["P1 Up"] = false
  850.         controller["P1 Down"] = false
  851.     end
  852.  
  853.     joypad.set(controller)
  854. end
  855.  
  856. if pool == nil then
  857.     initializePool()
  858. end
  859.  
  860.  
  861. function nextGenome()
  862.     pool.currentGenome = pool.currentGenome + 1
  863.     if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
  864.         pool.currentGenome = 1
  865.         pool.currentSpecies = pool.currentSpecies+1
  866.         if pool.currentSpecies > #pool.species then
  867.             newGeneration()
  868.             pool.currentSpecies = 1
  869.         end
  870.     end
  871. end
  872.  
  873. function fitnessAlreadyMeasured()
  874.     local species = pool.species[pool.currentSpecies]
  875.     local genome = species.genomes[pool.currentGenome]
  876.    
  877.     return genome.fitness ~= 0
  878. end
  879.  
-- Draw a visualisation of `genome`'s network with the gui.* drawing calls:
--   * the input grid, centred on (50, 70) with 5px cells;
--   * the bias cell;
--   * hidden neurons;
--   * one labelled cell per output button at x = 220;
--   * a line for every enabled gene.
-- Reads neuron values from genome.network, so generateNetwork(genome) must
-- have run first. When the `showMutationRates` form checkbox (created
-- elsewhere) is checked, the genome's mutation rates are also listed.
function displayGenome(genome)
    local network = genome.network
    -- cells[neuronId] = { x, y, value }: screen position and current value.
    local cells = {}
    local i = 1
    local cell = {}
    -- Input grid: one cell per input neuron, row by row (dy outer, dx inner).
    for dy=-BoxRadius,BoxRadius do
        for dx=-BoxRadius,BoxRadius do
            cell = {}
            cell.x = 50+5*dx
            cell.y = 70+5*dy
            cell.value = network.neurons[i].value
            cells[i] = cell
            i = i + 1
        end
    end
    -- The bias neuron (id Inputs) is drawn on its own, below and to the right
    -- of the grid.
    local biasCell = {}
    biasCell.x = 80
    biasCell.y = 110
    biasCell.value = network.neurons[Inputs].value
    cells[Inputs] = biasCell

    -- Output neurons: a column at x = 220. Each button name is drawn beside
    -- it in colour 0xFF0000FF when the output is > 0, else 0xFF000000.
    for o = 1,Outputs do
        cell = {}
        cell.x = 220
        cell.y = 30 + 8 * o
        cell.value = network.neurons[MaxNodes + o].value
        cells[MaxNodes+o] = cell
        local color
        if cell.value > 0 then
            color = 0xFF0000FF
        else
            color = 0xFF000000
        end
        gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
    end

    -- Hidden neurons (Inputs < id <= MaxNodes) all start at (140, 40).
    for n,neuron in pairs(network.neurons) do
        cell = {}
        if n > Inputs and n <= MaxNodes then
            cell.x = 140
            cell.y = 40
            cell.value = neuron.value
            cells[n] = cell
        end
    end

    -- Crude layout: 4 relaxation passes pull each hidden endpoint of an
    -- enabled gene toward the cell at the other end. A hidden cell is pushed
    -- 40px further away if it ends up on the wrong side, and x is clamped to
    -- [90, 220].
    for n=1,4 do
        for _,gene in pairs(genome.genes) do
            if gene.enabled then
                local c1 = cells[gene.into]
                local c2 = cells[gene.out]
                if gene.into > Inputs and gene.into <= MaxNodes then
                    c1.x = 0.75*c1.x + 0.25*c2.x
                    if c1.x >= c2.x then
                        c1.x = c1.x - 40
                    end
                    if c1.x < 90 then
                        c1.x = 90
                    end

                    if c1.x > 220 then
                        c1.x = 220
                    end
                    c1.y = 0.75*c1.y + 0.25*c2.y

                end
                if gene.out > Inputs and gene.out <= MaxNodes then
                    c2.x = 0.25*c1.x + 0.75*c2.x
                    if c1.x >= c2.x then
                        c2.x = c2.x + 40
                    end
                    if c2.x < 90 then
                        c2.x = 90
                    end
                    if c2.x > 220 then
                        c2.x = 220
                    end
                    c2.y = 0.25*c1.y + 0.75*c2.y
                end
            end
        end
    end

    -- Background box behind the input grid.
    gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
    -- Cells: grey level maps value -1..1 to 0..255; inputs with value 0 are
    -- skipped, and other zero-valued cells are drawn with a fainter border.
    for n,cell in pairs(cells) do
        if n > Inputs or cell.value ~= 0 then
            local color = math.floor((cell.value+1)/2*256)
            if color > 255 then color = 255 end
            if color < 0 then color = 0 end
            local opacity = 0xFF000000
            if cell.value == 0 then
                opacity = 0x50000000
            end
            color = opacity + color*0x10000 + color*0x100 + color
            gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color)
        end
    end
    -- Connections: the colour depends on the weight's sign and on its
    -- magnitude through sigmoid(). Lines from a source neuron whose value is 0
    -- are drawn fainter.
    for _,gene in pairs(genome.genes) do
        if gene.enabled then
            local c1 = cells[gene.into]
            local c2 = cells[gene.out]
            local opacity = 0xA0000000
            if c1.value == 0 then
                opacity = 0x20000000
            end

            local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
            if gene.weight > 0 then
                color = opacity + 0x8000 + 0x10000*color
            else
                color = opacity + 0x800000 + 0x100*color
            end
            gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
        end
    end

    -- Small marker box at the grid centre, presumably Mario's own cell(s).
    gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)

    if forms.ischecked(showMutationRates) then
        local pos = 100
        for mutation,rate in pairs(genome.mutationRates) do
            gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
            pos = pos + 8
        end
    end
end
  1006.  
  -- Format a number for the pool file. tostring() / ".." use "%.14g", which can
  -- drop the last significant digits of a double (e.g. gene weights); "%.17g"
  -- is always enough for file:read("*number") to recover the exact value.
  local function formatPoolNumber(x)
      return string.format("%.17g", x)
  end

  -- Serialize the global `pool` to `filename` as plain text, one value per line:
  --   generation, maxFitness, species count; then for each species:
  --   topFitness, staleness, genome count; then for each genome: fitness,
  --   maxneuron, "<mutation name>\n<rate>" pairs terminated by a "done" line,
  --   gene count, and one "into out weight innovation enabled(1|0)" line per gene.
  -- This layout is what loadFile() parses, so the two must stay in sync.
  -- Raises an error if the file cannot be opened; the file handle is closed
  -- even if serialization fails part-way.
  function writeFile(filename)
      local file, openErr = io.open(filename, "w")
      if not file then
          error("writeFile: cannot open '" .. tostring(filename) .. "': " .. tostring(openErr))
      end

      local function writeNumberLine(x)
          file:write(formatPoolNumber(x), "\n")
      end

      local ok, writeErr = pcall(function()
          writeNumberLine(pool.generation)
          writeNumberLine(pool.maxFitness)
          writeNumberLine(#pool.species)
          -- ipairs: records must appear in index order and match the "#" counts written before them.
          for _, species in ipairs(pool.species) do
              writeNumberLine(species.topFitness)
              writeNumberLine(species.staleness)
              writeNumberLine(#species.genomes)
              for _, genome in ipairs(species.genomes) do
                  writeNumberLine(genome.fitness)
                  writeNumberLine(genome.maxneuron)
                  -- Rates are keyed by name, so their order in the file does not matter.
                  for mutation, rate in pairs(genome.mutationRates) do
                      file:write(mutation, "\n", formatPoolNumber(rate), "\n")
                  end
                  file:write("done\n")

                  writeNumberLine(#genome.genes)
                  for _, gene in ipairs(genome.genes) do
                      file:write(
                          formatPoolNumber(gene.into), " ",
                          formatPoolNumber(gene.out), " ",
                          formatPoolNumber(gene.weight), " ",
                          formatPoolNumber(gene.innovation), " ",
                          gene.enabled and "1" or "0", "\n")
                  end
              end
          end
      end)

      file:close()
      if not ok then
          error(writeErr, 0)
      end
  end
  1041.  
  -- "Save" button handler: writes the current pool to the file whose name is
  -- typed into the Save/Load textbox.
  function savePool()
      writeFile(forms.gettext(saveLoadFile))
  end
  1046.  
  -- Load a pool written by writeFile() from `filename` into the global `pool`.
  -- It then refreshes the max-fitness label, advances to the first genome whose
  -- fitness has not been measured (via fitnessAlreadyMeasured/nextGenome), and
  -- starts a run for it.
  -- Raises an error if the file cannot be opened, or if it ends in the middle
  -- of a record; the original looped forever in that case. The file handle is
  -- always closed.
  function loadFile(filename)
      local file, openErr = io.open(filename, "r")
      if not file then
          error("loadFile: cannot open '" .. tostring(filename) .. "': " .. tostring(openErr))
      end

      -- Next line without its newline (and without a trailing "\r" from CRLF
      -- files). A nil read means the file is truncated, which is an error.
      local function readLine()
          local text = file:read("*line")
          if text == nil then
              error("loadFile: unexpected end of file in '" .. tostring(filename) .. "'", 0)
          end
          return (text:gsub("\r$", ""))
      end

      local ok, parseErr = pcall(function()
          pool = newPool()
          pool.generation = file:read("*number")
          pool.maxFitness = file:read("*number")
          forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
          local numSpecies = file:read("*number")
          for _ = 1, numSpecies do
              local species = newSpecies()
              table.insert(pool.species, species)
              species.topFitness = file:read("*number")
              species.staleness = file:read("*number")
              local numGenomes = file:read("*number")
              for _ = 1, numGenomes do
                  local genome = newGenome()
                  table.insert(species.genomes, genome)
                  genome.fitness = file:read("*number")
                  genome.maxneuron = file:read("*number")
                  -- "*number" leaves the newline unread; discard the rest of that
                  -- line so mutation names and rates are read as whole lines
                  -- rather than by probing non-numeric text with "*number".
                  readLine()
                  local name = readLine()
                  while name ~= "done" do
                      genome.mutationRates[name] = tonumber(readLine())
                      name = readLine()
                  end
                  local numGenes = file:read("*number")
                  for _ = 1, numGenes do
                      local gene = newGene()
                      table.insert(genome.genes, gene)
                      local enabled
                      gene.into, gene.out, gene.weight, gene.innovation, enabled =
                          file:read("*number", "*number", "*number", "*number", "*number")
                      -- Only an explicit 0 disables the gene (a missing value counts as enabled).
                      gene.enabled = (enabled ~= 0)
                  end
              end
          end
      end)

      file:close()
      if not ok then
          error(parseErr, 0)
      end

      while fitnessAlreadyMeasured() do
          nextGenome()
      end
      initializeRun()
      pool.currentFrame = pool.currentFrame + 1
  end
  1093.  
  -- "Load" button handler: reads the pool back from the file whose name is
  -- typed into the Save/Load textbox.
  function loadPool()
      loadFile(forms.gettext(saveLoadFile))
  end
  1098.  
  -- "Play Top" button handler: select the genome with the highest recorded
  -- fitness (first one wins on ties), update pool.maxFitness and its label, and
  -- start a run with it.
  -- If no genome has a fitness above 0 there is nothing to replay. The original
  -- then stored nil into pool.currentSpecies/currentGenome, and the later
  -- pool.species[...] lookup crashed. This version reports it and leaves the
  -- pool untouched.
  function playTop()
      local maxfitness = 0
      local maxs, maxg
      for s, species in ipairs(pool.species) do
          for g, genome in ipairs(species.genomes) do
              if genome.fitness > maxfitness then
                  maxfitness = genome.fitness
                  maxs, maxg = s, g
              end
          end
      end

      if maxs == nil then
          console.writeline("Play Top: no genome has a positive fitness yet")
          return
      end

      pool.currentSpecies = maxs
      pool.currentGenome = maxg
      pool.maxFitness = maxfitness
      forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
      initializeRun()
      pool.currentFrame = pool.currentFrame + 1
  end
  1120.  
  -- Exit hook (registered through event.onexit): closes the control window
  -- held in the global `form`.
  function onExit()
      local window = form
      forms.destroy(window)
  end
  1124.  
  -- Start-up: dump the pool as it exists at script start to "temp.pool",
  -- register the exit hook, then build the "Fitness" control window. The widget
  -- handles are deliberately globals; other functions in the script read them.
  writeFile("temp.pool")

  event.onexit(onExit)

  do
      local left = 5 -- x position shared by most controls

      form = forms.newform(200, 260, "Fitness")
      maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), left, 8)
      showNetwork = forms.checkbox(form, "Show Map", left, 30)
      showMutationRates = forms.checkbox(form, "Show M-Rates", left, 52)
      restartButton = forms.button(form, "Restart", initializePool, left, 77)
      saveButton = forms.button(form, "Save", savePool, left, 102)
      loadButton = forms.button(form, "Load", loadPool, 80, 102)
      saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, left, 148)
      saveLoadLabel = forms.label(form, "Save/Load:", left, 129)
      playTopButton = forms.button(form, "Play Top", playTop, left, 170)
      hideBanner = forms.checkbox(form, "Hide Banner", left, 190)
  end
  1140.  
  1141.  
  -- Main per-frame loop. Each iteration draws the overlays and applies the
  -- controller state. It tracks forward progress; once the run has timed out it
  -- records the genome's fitness and moves on to the first genome that is still
  -- unmeasured. It then draws the status banner and advances one emulator frame.

  -- Rightmost-X threshold per ROM beyond which a run earns the +1000 bonus.
  local finishLineX = {
      ["Super Mario World (USA)"] = 4816,
      ["Super Mario Bros."] = 3186,
  }

  -- Returns (number of genomes with non-zero fitness, total number of genomes).
  local function countMeasuredGenomes()
      local measuredCount, genomeCount = 0, 0
      for _, sp in pairs(pool.species) do
          for _, g in pairs(sp.genomes) do
              genomeCount = genomeCount + 1
              if g.fitness ~= 0 then
                  measuredCount = measuredCount + 1
              end
          end
      end
      return measuredCount, genomeCount
  end

  while true do
      local bannerColor = 0xD0FFFFFF
      if not forms.ischecked(hideBanner) then
          gui.drawBox(0, 0, 300, 26, bannerColor, bannerColor)
      end

      local activeGenome = pool.species[pool.currentSpecies].genomes[pool.currentGenome]
      if forms.ischecked(showNetwork) then
          displayGenome(activeGenome)
      end

      -- evaluateCurrent() runs only every 5th frame; the current `controller`
      -- table is pushed to the joypad on every frame.
      if pool.currentFrame % 5 == 0 then
          evaluateCurrent()
      end
      joypad.set(controller)

      getPositions()
      if marioX > rightmost then
          -- New forward progress refills the idle timeout.
          rightmost = marioX
          timeout = TimeoutConstant
      end
      timeout = timeout - 1

      -- Longer runs get proportionally more slack before they are cut off.
      local timeoutBonus = pool.currentFrame / 4
      if timeout + timeoutBonus <= 0 then
          local fitness = rightmost - pool.currentFrame / 2
          local finishX = finishLineX[gameinfo.getromname()]
          if finishX ~= nil and rightmost > finishX then
              fitness = fitness + 1000
          end
          -- A fitness of 0 is what the progress counter treats as "not measured".
          if fitness == 0 then
              fitness = -1
          end
          activeGenome.fitness = fitness

          if fitness > pool.maxFitness then
              pool.maxFitness = fitness
              forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
              writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
          end

          console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)

          -- Rescan from the first genome for the next one still lacking a fitness.
          pool.currentSpecies = 1
          pool.currentGenome = 1
          while fitnessAlreadyMeasured() do
              nextGenome()
          end
          initializeRun()
      end

      local measured, total = countMeasuredGenomes()
      if not forms.ischecked(hideBanner) then
          local percentMeasured = math.floor(measured / total * 100)
          gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. percentMeasured .. "%)", 0xFF000000, 11)
          local shownFitness = math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3)
          gui.drawText(0, 12, "Fitness: " .. shownFitness, 0xFF000000, 11)
          gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11)
      end

      pool.currentFrame = pool.currentFrame + 1
      emu.frameadvance()
  end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement