Advertisement
Guest User

Untitled

a guest
Jun 15th, 2015
538
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 22.10 KB | None | 0 0
  1. -- MarI/O by SethBling
  2. -- Feel free to use this code, but please do not redistribute it.
  3. -- Intended for use with the BizHawk emulator and Super Mario World.
  4. BasePath= "C:\\Users\\Henrik\\Desktop\\"
  5. Filename = "YI1.state"
  6. ButtonNames = {
  7.     "A",
  8.     "B",
  9.     "X",
  10.     "Y",
  11.     "Up",
  12.     "Down",
  13.     "Left",
  14.     "Right",
  15. }
  16.  
  17. BoxRadius = 6
  18. TileMapInputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  19. TotalInputSize=TileMapInputSize + 6
  20.  
  21. Inputs = TotalInputSize+1
  22. Outputs = #ButtonNames
  23.  
  24. Population = 300
  25. DeltaDisjoint = 2.0
  26. DeltaWeights = 0.4
  27. DeltaThreshold = 1.0
  28.  
  29. StaleSpecies = 15
  30.  
  31. MutateConnectionsChance = 0.25
  32. PerturbChance = 0.90
  33. CrossoverChance = 0.75
  34. LinkMutationChance = 2.0
  35. NodeMutationChance = 0.50
  36. BiasMutationChance = 0.40
  37. StepSize = 0.1
  38. DisableMutationChance = 0.4
  39. EnableMutationChance = 0.2
  40.  
  41. TimeoutConstant = 20
  42.  
  43. MaxNodes = 1000000
  44.  
  45. function getPositions()
  46.     marioX = memory.read_s16_le(0x94)
  47.     marioY = memory.read_s16_le(0x96)
  48.    
  49.     local layer1x = memory.read_s16_le(0x1A);
  50.     local layer1y = memory.read_s16_le(0x1C);
  51.    
  52.     screenX = marioX-layer1x
  53.     screenY = marioY-layer1y
  54. end
  55.  
  56. function getTile(dx, dy)
  57.     x = math.floor((marioX+dx+8)/16)
  58.     y = math.floor((marioY+dy)/16)
  59.        
  60.     return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  61.    
  62. end
  63.  
  64. function getSprites()
  65.     local sprites = {}
  66.     for slot=0,11 do
  67.         local status = memory.readbyte(0x14C8+slot)
  68.         if status ~= 0 then
  69.             spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
  70.             spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
  71.             sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey, ["status"]=status}
  72.         end
  73.     end    
  74.    
  75.     return sprites
  76. end
  77.  
  78. function getExtendedSprites()
  79.     local extended = {}
  80.     for slot=0,11 do
  81.         local number = memory.readbyte(0x170B+slot)
  82.         if number ~= 0 then
  83.             spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
  84.             spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
  85.             extended[#extended+1] = {["x"]=spritex, ["y"]=spritey,["number"]=number}
  86.         end
  87.     end    
  88.    
  89.     return extended
  90. end
  91.  
  92. function getInputs()
  93.     getPositions()
  94.    
  95.     sprites = getSprites()
  96.     extended = getExtendedSprites()
  97.    
  98.     local inputs = {}
  99.    
  100.     for dy=-BoxRadius*16,BoxRadius*16,16 do
  101.         for dx=-BoxRadius*16,BoxRadius*16,16 do
  102.             inputs[#inputs+1] = 0
  103.            
  104.             tile = getTile(dx, dy)
  105.             if tile == 1 and marioY+dy < 0x1B0 then
  106.                 inputs[#inputs] = tile
  107.             end
  108.            
  109.             for i = 1,#sprites do
  110.                 distx = math.abs(sprites[i]["x"] - (marioX+dx))
  111.                 disty = math.abs(sprites[i]["y"] - (marioY+dy))
  112.                 if distx <= 8 and disty <= 8 and sprites[i]["status"] then
  113.                     inputs[#inputs] = -sprites[i]["status"]
  114.                 end
  115.             end
  116.  
  117.             for i = 1,#extended do
  118.                 distx = math.abs(extended[i]["x"] - (marioX+dx))
  119.                 disty = math.abs(extended[i]["y"] - (marioY+dy))
  120.                 if distx < 8 and disty < 8 and extended[i]["number"] then
  121.                     inputs[#inputs] = -extended[i]["number"]
  122.                 end
  123.             end
  124.         end
  125.     end
  126.    
  127.     mariovx = memory.read_s8(0x7B)
  128.     mariovy = memory.read_s8(0x7D)
  129.     pmeter= memory.read_u8(0x13E4)
  130.     takeoffmeter= memory.read_u8(0x149F)
  131.     powerup=memory.read_u8(0x19)
  132.     balloontimer=memory.read_u8(0x1891)
  133.    
  134.     inputs[TileMapInputSize+1] = mariovx
  135.     inputs[TileMapInputSize+2] = mariovy
  136.     inputs[TileMapInputSize+3] = pmeter
  137.     inputs[TileMapInputSize+4] = takeoffmeter
  138.     inputs[TileMapInputSize+5] = powerup
  139.     inputs[TileMapInputSize+6] = balloontimer
  140.    
  141.     return inputs
  142. end
  143.  
  144. function sigmoid(x)
  145.     return 2/(1+math.exp(-4.9*x))-1
  146. end
  147.  
  148. function newInnovation()
  149.     pool.innovation = pool.innovation + 1
  150.     return pool.innovation
  151. end
  152.  
  153. function newPool()
  154.     local pool = {}
  155.     pool.species = {}
  156.     pool.generation = 0
  157.     pool.innovation = Outputs
  158.     pool.currentSpecies = 1
  159.     pool.currentGenome = 1
  160.     pool.currentFrame = 0
  161.     pool.maxFitness = 0
  162.    
  163.     return pool
  164. end
  165.  
  166. function newSpecies()
  167.     local species = {}
  168.     species.topFitness = 0
  169.     species.staleness = 0
  170.     species.genomes = {}
  171.     species.averageFitness = 0
  172.    
  173.     return species
  174. end
  175.  
  176. function newGenome()
  177.     local genome = {}
  178.     genome.genes = {}
  179.     genome.fitness = 0
  180.     genome.adjustedFitness = 0
  181.     genome.network = {}
  182.     genome.maxneuron = 0
  183.     genome.globalRank = 0
  184.     genome.mutationRates = {}
  185.     genome.mutationRates["connections"] = MutateConnectionsChance
  186.     genome.mutationRates["link"] = LinkMutationChance
  187.     genome.mutationRates["bias"] = BiasMutationChance
  188.     genome.mutationRates["node"] = NodeMutationChance
  189.     genome.mutationRates["enable"] = EnableMutationChance
  190.     genome.mutationRates["disable"] = DisableMutationChance
  191.     genome.mutationRates["step"] = StepSize
  192.    
  193.     return genome
  194. end
  195.  
  196. function copyGenome(genome)
  197.     local genome2 = newGenome()
  198.     for g=1,#genome.genes do
  199.         table.insert(genome2.genes, copyGene(genome.genes[g]))
  200.     end
  201.     genome2.maxneuron = genome.maxneuron
  202.     genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  203.     genome2.mutationRates["link"] = genome.mutationRates["link"]
  204.     genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  205.     genome2.mutationRates["node"] = genome.mutationRates["node"]
  206.     genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  207.     genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  208.    
  209.     return genome2
  210. end
  211.  
  212. function basicGenome()
  213.     local genome = newGenome()
  214.     local innovation = 1
  215.  
  216.     genome.maxneuron = Inputs
  217.     mutate(genome)
  218.    
  219.     return genome
  220. end
  221.  
  222. function newGene()
  223.     local gene = {}
  224.     gene.into = 0
  225.     gene.out = 0
  226.     gene.weight = 0.0
  227.     gene.enabled = true
  228.     gene.innovation = 0
  229.    
  230.     return gene
  231. end
  232.  
  233. function copyGene(gene)
  234.     local gene2 = newGene()
  235.     gene2.into = gene.into
  236.     gene2.out = gene.out
  237.     gene2.weight = gene.weight
  238.     gene2.enabled = gene.enabled
  239.     gene2.innovation = gene.innovation
  240.    
  241.     return gene2
  242. end
  243.  
  244. function newNeuron()
  245.     local neuron = {}
  246.     neuron.incoming = {}
  247.     neuron.value = 0.0
  248.    
  249.     return neuron
  250. end
  251.  
  252. function generateNetwork(genome)
  253.     local network = {}
  254.     network.neurons = {}
  255.    
  256.     for i=1,Inputs do
  257.         network.neurons[i] = newNeuron()
  258.     end
  259.    
  260.     for o=1,Outputs do
  261.         network.neurons[MaxNodes+o] = newNeuron()
  262.     end
  263.    
  264.     table.sort(genome.genes, function (a,b)
  265.         return (a.out <b> 0 then
  266.             neuron.value = sigmoid(sum)
  267.         end
  268.     end
  269.    
  270.     local outputs = {}
  271.     for o=1,Outputs do
  272.         local button = "P1 " .. ButtonNames[o]
  273.         if network.neurons[MaxNodes+o].value > 0 then
  274.             outputs[button] = true
  275.         else
  276.             outputs[button] = false
  277.         end
  278.     end
  279.    
  280.     return outputs
  281. end
  282.  
  283. function crossover(g1, g2)
  284.     -- Make sure g1 is the higher fitness genome
  285.     if g2.fitness > g1.fitness then
  286.         tempg = g1
  287.         g1 = g2
  288.         g2 = tempg
  289.     end
  290.  
  291.     local child = newGenome()
  292.    
  293.     local innovations2 = {}
  294.     for i=1,#g2.genes do
  295.         local gene = g2.genes[i]
  296.         innovations2[gene.innovation] = gene
  297.     end
  298.    
  299.     for i=1,#g1.genes do
  300.         local gene1 = g1.genes[i]
  301.         local gene2 = innovations2[gene1.innovation]
  302.         if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  303.             table.insert(child.genes, copyGene(gene2))
  304.         else
  305.             table.insert(child.genes, copyGene(gene1))
  306.         end
  307.     end
  308.    
  309.     child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  310.    
  311.     for mutation,rate in pairs(g1.mutationRates) do
  312.         child.mutationRates[mutation] = rate
  313.     end
  314.    
  315.     return child
  316. end
  317.  
  318. function randomNeuron(genes, nonInput)
  319.     local neurons = {}
  320.     if not nonInput then
  321.         for i=1,Inputs do
  322.             neurons[i] = true
  323.         end
  324.     end
  325.     for o=1,Outputs do
  326.         neurons[MaxNodes+o] = true
  327.     end
  328.     for i=1,#genes do
  329.         if (not nonInput) or genes[i].into > Inputs then
  330.             neurons[genes[i].into] = true
  331.         end
  332.         if (not nonInput) or genes[i].out > Inputs then
  333.             neurons[genes[i].out] = true
  334.         end
  335.     end
  336.  
  337.     local count = 0
  338.     for _,_ in pairs(neurons) do
  339.         count = count + 1
  340.     end
  341.     local n = math.random(1, count)
  342.    
  343.     for k,v in pairs(neurons) do
  344.         n = n-1
  345.         if n == 0 then
  346.             return k
  347.         end
  348.     end
  349.    
  350.     return 0
  351. end
  352.  
  353. function containsLink(genes, link)
  354.     for i=1,#genes do
  355.         local gene = genes[i]
  356.         if gene.into == link.into and gene.out == link.out then
  357.             return true
  358.         end
  359.     end
  360. end
  361.  
  362. function pointMutate(genome)
  363.     local step = genome.mutationRates["step"]
  364.    
  365.     for i=1,#genome.genes do
  366.         local gene = genome.genes[i]
  367.         if math.random() < PerturbChance then
  368.             gene.weight = gene.weight + math.random() * step*2 - step
  369.         else
  370.             gene.weight = math.random()*4-2
  371.         end
  372.     end
  373. end
  374.  
  375. function linkMutate(genome, forceBias)
  376.     local neuron1 = randomNeuron(genome.genes, false)
  377.     local neuron2 = randomNeuron(genome.genes, true)
  378.      
  379.     local newLink = newGene()
  380.     if neuron1 <= Inputs and neuron2 <= Inputs then
  381.         --Both input nodes
  382.         return
  383.     end
  384.     if neuron2 <= Inputs then
  385.         -- Swap output and input
  386.         local temp = neuron1
  387.         neuron1 = neuron2
  388.         neuron2 = temp
  389.     end
  390.  
  391.     newLink.into = neuron1
  392.     newLink.out = neuron2
  393.     if forceBias then
  394.         newLink.into = Inputs
  395.     end
  396.    
  397.     if containsLink(genome.genes, newLink) then
  398.         return
  399.     end
  400.     newLink.innovation = newInnovation()
  401.     newLink.weight = math.random()*4-2
  402.    
  403.     table.insert(genome.genes, newLink)
  404. end
  405.  
  406. function nodeMutate(genome)
  407.     if #genome.genes == 0 then
  408.         return
  409.     end
  410.  
  411.     genome.maxneuron = genome.maxneuron + 1
  412.  
  413.     local gene = genome.genes[math.random(1,#genome.genes)]
  414.     if not gene.enabled then
  415.         return
  416.     end
  417.     gene.enabled = false
  418.    
  419.     local gene1 = copyGene(gene)
  420.     gene1.out = genome.maxneuron
  421.     gene1.weight = 1.0
  422.     gene1.innovation = newInnovation()
  423.     gene1.enabled = true
  424.     table.insert(genome.genes, gene1)
  425.    
  426.     local gene2 = copyGene(gene)
  427.     gene2.into = genome.maxneuron
  428.     gene2.innovation = newInnovation()
  429.     gene2.enabled = true
  430.     table.insert(genome.genes, gene2)
  431. end
  432.  
  433. function enableDisableMutate(genome, enable)
  434.     local candidates = {}
  435.     for _,gene in pairs(genome.genes) do
  436.         if gene.enabled == not enable then
  437.             table.insert(candidates, gene)
  438.         end
  439.     end
  440.    
  441.     if #candidates == 0 then
  442.         return
  443.     end
  444.    
  445.     local gene = candidates[math.random(1,#candidates)]
  446.     gene.enabled = not gene.enabled
  447. end
  448.  
  449. function mutate(genome)
  450.     for mutation,rate in pairs(genome.mutationRates) do
  451.         if math.random(1,2) == 1 then
  452.             genome.mutationRates[mutation] = 0.95*rate
  453.         else
  454.             genome.mutationRates[mutation] = 1.05263*rate
  455.         end
  456.     end
  457.  
  458.     if math.random() <genome> 0 do
  459.         if math.random() <p> 0 do
  460.         if math.random() <p> 0 do
  461.         if math.random() <p> 0 do
  462.         if math.random() <p> 0 do
  463.         if math.random() < p then
  464.             enableDisableMutate(genome, false)
  465.         end
  466.         p = p - 1
  467.     end
  468. end
  469.  
  470. function disjoint(genes1, genes2)
  471.     local i1 = {}
  472.     for i = 1,#genes1 do
  473.         local gene = genes1[i]
  474.         i1[gene.innovation] = true
  475.     end
  476.  
  477.     local i2 = {}
  478.     for i = 1,#genes2 do
  479.         local gene = genes2[i]
  480.         i2[gene.innovation] = true
  481.     end
  482.    
  483.     local disjointGenes = 0
  484.     for i = 1,#genes1 do
  485.         local gene = genes1[i]
  486.         if not i2[gene.innovation] then
  487.             disjointGenes = disjointGenes+1
  488.         end
  489.     end
  490.    
  491.     for i = 1,#genes2 do
  492.         local gene = genes2[i]
  493.         if not i1[gene.innovation] then
  494.             disjointGenes = disjointGenes+1
  495.         end
  496.     end
  497.    
  498.     local n = math.max(#genes1, #genes2)
  499.    
  500.     return disjointGenes / n
  501. end
  502.  
  503. function weights(genes1, genes2)
  504.     local i2 = {}
  505.     for i = 1,#genes2 do
  506.         local gene = genes2[i]
  507.         i2[gene.innovation] = gene
  508.     end
  509.  
  510.     local sum = 0
  511.     local coincident = 0
  512.     for i = 1,#genes1 do
  513.         local gene = genes1[i]
  514.         if i2[gene.innovation] ~= nil then
  515.             local gene2 = i2[gene.innovation]
  516.             sum = sum + math.abs(gene.weight - gene2.weight)
  517.             coincident = coincident + 1
  518.         end
  519.     end
  520.    
  521.     return sum / coincident
  522. end
  523.    
  524. function sameSpecies(genome1, genome2)
  525.     local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  526.     local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  527.     return dd + dw < DeltaThreshold
  528. end
  529.  
  530. function rankGlobally()
  531.     local global = {}
  532.     for s = 1,#pool.species do
  533.         local species = pool.species[s]
  534.         for g = 1,#species.genomes do
  535.             table.insert(global, species.genomes[g])
  536.         end
  537.     end
  538.     table.sort(global, function (a,b)
  539.         return (a.fitness <b> b.fitness)
  540.         end)
  541.        
  542.         local remaining = math.ceil(#species.genomes/2)
  543.         if cutToOne then
  544.             remaining = 1
  545.         end
  546.         while #species.genomes > remaining do
  547.             table.remove(species.genomes)
  548.         end
  549.     end
  550. end
  551.  
  552. function breedChild(species)
  553.     local child = {}
  554.     if math.random() <CrossoverChance> b.fitness)
  555.         end)
  556.        
  557.         if species.genomes[1].fitness > species.topFitness then
  558.             species.topFitness = species.genomes[1].fitness
  559.             species.staleness = 0
  560.         else
  561.             species.staleness = species.staleness + 1
  562.         end
  563.         if species.staleness <StaleSpecies>= pool.maxFitness then
  564.             table.insert(survived, species)
  565.         end
  566.     end
  567.  
  568.     pool.species = survived
  569. end
  570.  
  571. function removeWeakSpecies()
  572.     local survived = {}
  573.  
  574.     local sum = totalAverageFitness()
  575.     for s = 1,#pool.species do
  576.         local species = pool.species[s]
  577.         breed = math.floor(species.averageFitness / sum * Population)
  578.         if breed >= 1 then
  579.             table.insert(survived, species)
  580.         end
  581.     end
  582.  
  583.     pool.species = survived
  584. end
  585.  
  586.  
  587. function addToSpecies(child)
  588.     local foundSpecies = false
  589.     for s=1,#pool.species do
  590.         local species = pool.species[s]
  591.         if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  592.             table.insert(species.genomes, child)
  593.             foundSpecies = true
  594.         end
  595.     end
  596.    
  597.     if not foundSpecies then
  598.         local childSpecies = newSpecies()
  599.         table.insert(childSpecies.genomes, child)
  600.         table.insert(pool.species, childSpecies)
  601.     end
  602. end
  603.  
  604. function newGeneration()
  605.     cullSpecies(false) -- Cull the bottom half of each species
  606.     rankGlobally()
  607.     removeStaleSpecies()
  608.     rankGlobally()
  609.     for s = 1,#pool.species do
  610.         local species = pool.species[s]
  611.         calculateAverageFitness(species)
  612.     end
  613.     removeWeakSpecies()
  614.     local sum = totalAverageFitness()
  615.     local children = {}
  616.     for s = 1,#pool.species do
  617.         local species = pool.species[s]
  618.         breed = math.floor(species.averageFitness / sum * Population) - 1
  619.         for i=1,breed do
  620.             table.insert(children, breedChild(species))
  621.         end
  622.     end
  623.     cullSpecies(true) -- Cull all but the top member of each species
  624.     while #children + #pool.species <Population> #pool.species[pool.currentSpecies].genomes then
  625.         pool.currentGenome = 1
  626.         pool.currentSpecies = pool.currentSpecies+1
  627.         if pool.currentSpecies > #pool.species then
  628.             newGeneration()
  629.             pool.currentSpecies = 1
  630.         end
  631.     end
  632. end
  633.  
  634. function fitnessAlreadyMeasured()
  635.     local species = pool.species[pool.currentSpecies]
  636.     local genome = species.genomes[pool.currentGenome]
  637.    
  638.     return genome.fitness ~= 0
  639. end
  640.  
  641. function displayGenome(genome)
  642.     local network = genome.network
  643.     local cells = {}
  644.     local i = 1
  645.     local cell = {}
  646.     for dy=-BoxRadius,BoxRadius do
  647.         for dx=-BoxRadius,BoxRadius do
  648.             cell = {}
  649.             cell.x = 50+5*dx
  650.             cell.y = 70+5*dy
  651.             cell.value = network.neurons[i].value
  652.             cells[i] = cell
  653.             i = i + 1
  654.         end
  655.     end
  656.     local biasCell = {}
  657.     biasCell.x = 80
  658.     biasCell.y = 110
  659.     biasCell.value = network.neurons[Inputs].value
  660.     cells[Inputs] = biasCell
  661.    
  662.     for o = 1,Outputs do
  663.         cell = {}
  664.         cell.x = 220
  665.         cell.y = 30 + 8 * o
  666.         cell.value = network.neurons[MaxNodes + o].value
  667.         cells[MaxNodes+o] = cell
  668.         local color
  669.         if cell.value > 0 then
  670.             color = 0xFF0000FF
  671.         else
  672.             color = 0xFF000000
  673.         end
  674.         gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
  675.     end
  676.    
  677.     for n,neuron in pairs(network.neurons) do
  678.         cell = {}
  679.         if n > Inputs and n <MaxNodes> Inputs and gene.into <MaxNodes>= c2.x then
  680.                         c1.x = c1.x - 40
  681.                     end
  682.                     if c1.x <90> 220 then
  683.                         c1.x = 220
  684.                     end
  685.                     c1.y = 0.75*c1.y + 0.25*c2.y
  686.                    
  687.                 end
  688.                 if gene.out > Inputs and gene.out <MaxNodes>= c2.x then
  689.                         c2.x = c2.x + 40
  690.                     end
  691.                     if c2.x <90> 220 then
  692.                         c2.x = 220
  693.                     end
  694.                     c2.y = 0.25*c1.y + 0.75*c2.y
  695.                 end
  696.             end
  697.         end
  698.     end ]]--
  699.    
  700.     gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
  701.     for n,cell in pairs(cells) do
  702.         if n > Inputs or cell.value ~= 0 then
  703.             local color = math.floor((cell.value+1)/2*256)
  704.             if color > 255 then color = 255 end
  705.             if color <0> 0 then
  706.                     color = opacity + 0x8000 + 0x10000*color
  707.                 else
  708.                     color = opacity + 0x800000 + 0x100*color
  709.                 end
  710.                 gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
  711.             end
  712.         end
  713.     end
  714.    
  715.     gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)
  716.    
  717.     if forms.ischecked(showMutationRates) then
  718.         local pos = 100
  719.         for mutation,rate in pairs(genome.mutationRates) do
  720.             gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
  721.             pos = pos + 8
  722.         end
  723.     end
  724. end
  725.  
  726. function writeFile(filename)
  727.         local file = io.open(BasePath .. filename, "w")
  728.     file:write(pool.generation .. "\n")
  729.     file:write(pool.maxFitness .. "\n")
  730.     file:write(#pool.species .. "\n")
  731.         for n,species in pairs(pool.species) do
  732.         file:write(species.topFitness .. "\n")
  733.         file:write(species.staleness .. "\n")
  734.         file:write(#species.genomes .. "\n")
  735.         for m,genome in pairs(species.genomes) do
  736.             file:write(genome.fitness .. "\n")
  737.             file:write(genome.maxneuron .. "\n")
  738.             for mutation,rate in pairs(genome.mutationRates) do
  739.                 file:write(mutation .. "\n")
  740.                 file:write(rate .. "\n")
  741.             end
  742.             file:write("done\n")
  743.            
  744.             file:write(#genome.genes .. "\n")
  745.             for l,gene in pairs(genome.genes) do
  746.                 file:write(gene.into .. " ")
  747.                 file:write(gene.out .. " ")
  748.                 file:write(gene.weight .. " ")
  749.                 file:write(gene.innovation .. " ")
  750.                 if(gene.enabled) then
  751.                     file:write("1\n")
  752.                 else
  753.                     file:write("0\n")
  754.                 end
  755.             end
  756.         end
  757.         end
  758.         file:close()
  759. end
  760.  
  761. function savePool()
  762.     local filename = forms.gettext(saveLoadFile)
  763.     writeFile(filename)
  764. end
  765.  
  766. function loadFile(filename)
  767.         local file = io.open(BasePath .. filename, "r")
  768.     pool = newPool()
  769.     pool.generation = file:read("*number")
  770.     pool.maxFitness = file:read("*number")
  771.     forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  772.         local numSpecies = file:read("*number")
  773.         for s=1,numSpecies do
  774.         local species = newSpecies()
  775.         table.insert(pool.species, species)
  776.         species.topFitness = file:read("*number")
  777.         species.staleness = file:read("*number")
  778.         local numGenomes = file:read("*number")
  779.         for g=1,numGenomes do
  780.             local genome = newGenome()
  781.             table.insert(species.genomes, genome)
  782.             genome.fitness = file:read("*number")
  783.             genome.maxneuron = file:read("*number")
  784.             local line = file:read("*line")
  785.             while line ~= "done" do
  786.                 genome.mutationRates[line] = file:read("*number")
  787.                 line = file:read("*line")
  788.             end
  789.             local numGenes = file:read("*number")
  790.             for n=1,numGenes do
  791.                 local gene = newGene()
  792.                 table.insert(genome.genes, gene)
  793.                 local enabled
  794.                 gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
  795.                 if enabled == 0 then
  796.                     gene.enabled = false
  797.                 else
  798.                     gene.enabled = true
  799.                 end
  800.                
  801.             end
  802.         end
  803.     end
  804.         file:close()
  805.    
  806.     while fitnessAlreadyMeasured() do
  807.         nextGenome()
  808.     end
  809.     initializeRun()
  810.     pool.currentFrame = pool.currentFrame + 1
  811. end
  812.  
  813. function loadPool()
  814.     local filename = forms.gettext(saveLoadFile)
  815.     loadFile(filename)
  816. end
  817.  
  818. function playTop()
  819.     local maxfitness = 0
  820.     local maxs, maxg
  821.     for s,species in pairs(pool.species) do
  822.         for g,genome in pairs(species.genomes) do
  823.             if genome.fitness > maxfitness then
  824.                 maxfitness = genome.fitness
  825.                 maxs = s
  826.                 maxg = g
  827.             end
  828.         end
  829.     end
  830.    
  831.     pool.currentSpecies = maxs
  832.     pool.currentGenome = maxg
  833.     pool.maxFitness = maxfitness
  834.     forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  835.     initializeRun()
  836.     pool.currentFrame = pool.currentFrame + 1
  837.     return
  838. end
  839.  
  840. function onExit()
  841.     forms.destroy(form)
  842. end
  843.  
  844. writeFile("temp.pool")
  845.  
  846. event.onexit(onExit)
  847.  
  848. form = forms.newform(200, 260, "Fitness")
  849. maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  850. showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  851. showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  852. restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  853. saveButton = forms.button(form, "Save", savePool, 5, 102)
  854. loadButton = forms.button(form, "Load", loadPool, 80, 102)
  855. saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
  856. saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  857. playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  858. hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  859.  
  860.  
  861. while true do
  862.     local backgroundColor = 0xD0FFFFFF
  863.     if not forms.ischecked(hideBanner) then
  864.         gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
  865.     end
  866.  
  867.     local species = pool.species[pool.currentSpecies]
  868.     local genome = species.genomes[pool.currentGenome]
  869.    
  870.     if forms.ischecked(showNetwork) then
  871.         displayGenome(genome)
  872.     end
  873.    
  874.     if pool.currentFrame%5 == 0 then
  875.         evaluateCurrent()
  876.     end
  877.  
  878.     joypad.set(controller)
  879.  
  880.     getPositions()
  881.     if marioX > rightmost then
  882.         rightmost = marioX
  883.         timeout = TimeoutConstant
  884.     end
  885.    
  886.     timeout = timeout - 1
  887.    
  888.    
  889.     local timeoutBonus = pool.currentFrame / 4
  890.     if timeout + timeoutBonus <0> 4816 then
  891.             fitness = fitness + 1000
  892.         end
  893.         if fitness == 0 then
  894.             fitness = -1
  895.         end
  896.         genome.fitness = fitness
  897.        
  898.         if fitness > pool.maxFitness then
  899.             pool.maxFitness = fitness
  900.             forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  901.             writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
  902.         end
  903.        
  904.         console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
  905.         pool.currentSpecies = 1
  906.         pool.currentGenome = 1
  907.         while fitnessAlreadyMeasured() do
  908.             nextGenome()
  909.         end
  910.         initializeRun()
  911.     end
  912.  
  913.     local measured = 0
  914.     local total = 0
  915.     for _,species in pairs(pool.species) do
  916.         for _,genome in pairs(species.genomes) do
  917.             total = total + 1
  918.             if genome.fitness ~= 0 then
  919.                 measured = measured + 1
  920.             end
  921.         end
  922.     end
  923.     if not forms.ischecked(hideBanner) then
  924.         gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 11)
  925.         gui.drawText(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000, 11)
  926.         gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11)
  927.     end
  928.        
  929.     pool.currentFrame = pool.currentFrame + 1
  930.  
  931.     emu.frameadvance();
  932. end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement