Advertisement
Amaraticando

MarI/O partially adapted for the Snes9x rerecording emulator

Jun 15th, 2015
399
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 28.16 KB | None | 0 0
  1. -- MarI/O by SethBling
  2. -- partially translated to Snes9x 1.43 v18 API by Amaraticando
  3. -- Feel free to use this code, but please do not redistribute it.
  4. -- Intended for use with the Snes9x-rr emulator and Super Mario World ROM.
  5. -- Start this script a bit prior to the desired level, an internal savestate will be created then...
  6.  
  7. -- USER OPTIONS
  8. local showNetwork = true
  9. local showMutationRates = false
  10. local hideBanner = false
  11.  
  12. -- Compatibility
  13. function gui.drawBox(x1, y1, x2, y2, line, fill) gui.box(x1, y1, x2, y2, fill, line) end
  14. gui.drawText = gui.text
  15. gui.drawLine = gui.line
  16. function memory.read_s8(address) return memory.readbytesigned(0x7e0000 + address) end
  17. function memory.read_s16_le(address) return memory.readwordsigned(0x7e0000 + address) end
  18. function joypad.set2(controler) joypad.table = controller end
  19.  
  20. Savestate_object = savestate.create() -- savestate is created on the fly, not a previous "DP1.state"
  21. savestate.save(Savestate_object)
  22.  
  23. ButtonNames = {
  24.     "A",
  25.     "B",
  26.     "X",
  27.     "Y",
  28.     "up",
  29.     "down",
  30.     "left",
  31.     "right",
  32. }
  33.  
  34. BoxRadius = 6
  35. InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
  36.  
  37. Inputs = InputSize+1
  38. Outputs = #ButtonNames
  39.  
  40. Population = 300
  41. DeltaDisjoint = 2.0
  42. DeltaWeights = 0.4
  43. DeltaThreshold = 1.0
  44.  
  45. StaleSpecies = 15
  46.  
  47. MutateConnectionsChance = 0.25
  48. PerturbChance = 0.90
  49. CrossoverChance = 0.75
  50. LinkMutationChance = 2.0
  51. NodeMutationChance = 0.50
  52. BiasMutationChance = 0.40
  53. StepSize = 0.1
  54. DisableMutationChance = 0.4
  55. EnableMutationChance = 0.2
  56.  
  57. TimeoutConstant = 20
  58.  
  59. MaxNodes = 1000000
  60.  
  61. function getPositions()
  62.     marioX = memory.read_s16_le(0x94)
  63.     marioY = memory.read_s16_le(0x96)
  64.    
  65.     local layer1x = memory.read_s16_le(0x1A);
  66.     local layer1y = memory.read_s16_le(0x1C);
  67.    
  68.     screenX = marioX-layer1x
  69.     screenY = marioY-layer1y
  70. end
  71.  
  72. function getTile(dx, dy)
  73.     x = math.floor((marioX+dx+8)/16)
  74.     y = math.floor((marioY+dy)/16)
  75.    
  76.     return memory.readbyte(0x7fC800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  77. end
  78.  
  79. function getSprites()
  80.     local sprites = {}
  81.     for slot=0,11 do
  82.         local status = memory.readbyte(0x7e14C8+slot)
  83.         if status ~= 0 then
  84.             spritex = memory.readbyte(0x7e00E4+slot) + memory.readbyte(0x7e14E0+slot)*256
  85.             spritey = memory.readbyte(0x7e00D8+slot) + memory.readbyte(0x7e14D4+slot)*256
  86.             sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
  87.         end
  88.     end    
  89.    
  90.     return sprites
  91. end
  92.  
  93. function getExtendedSprites()
  94.     local extended = {}
  95.     for slot=0,11 do
  96.         local number = memory.readbyte(0x7e170B+slot)
  97.         if number ~= 0 then
  98.             spritex = memory.readbyte(0x7e171F+slot) + memory.readbyte(0x7e1733+slot)*256
  99.             spritey = memory.readbyte(0x7e1715+slot) + memory.readbyte(0x7e1729+slot)*256
  100.             extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
  101.         end
  102.     end    
  103.    
  104.     return extended
  105. end
  106.  
  107. function getInputs()  -- not ready
  108.     getPositions()
  109.    
  110.     sprites = getSprites()
  111.     extended = getExtendedSprites()
  112.    
  113.     local inputs = {}
  114.    
  115.     for dy=-BoxRadius*16,BoxRadius*16,16 do
  116.         for dx=-BoxRadius*16,BoxRadius*16,16 do
  117.             inputs[#inputs+1] = 0
  118.            
  119.             tile = getTile(dx, dy)
  120.             if tile == 1 and marioY+dy < 0x1B0 then
  121.                 inputs[#inputs] = 1
  122.             end
  123.            
  124.             for i = 1,#sprites do
  125.                 distx = math.abs(sprites[i]["x"] - (marioX+dx))
  126.                 disty = math.abs(sprites[i]["y"] - (marioY+dy))
  127.                 if distx <= 8 and disty <= 8 then
  128.                     inputs[#inputs] = -1
  129.                 end
  130.             end
  131.  
  132.             for i = 1,#extended do
  133.                 distx = math.abs(extended[i]["x"] - (marioX+dx))
  134.                 disty = math.abs(extended[i]["y"] - (marioY+dy))
  135.                 if distx < 8 and disty < 8 then
  136.                     inputs[#inputs] = -1
  137.                 end
  138.             end
  139.         end
  140.     end
  141.    
  142.     --mariovx = memory.read_s8(0x7B)
  143.     --mariovy = memory.read_s8(0x7D)
  144.    
  145.     return inputs
  146. end
  147.  
  148. function sigmoid(x)
  149.     return 2/(1+math.exp(-4.9*x))-1
  150. end
  151.  
  152. function newInnovation()
  153.     pool.innovation = pool.innovation + 1
  154.     return pool.innovation
  155. end
  156.  
  157. function newPool()
  158.     local pool = {}
  159.     pool.species = {}
  160.     pool.generation = 0
  161.     pool.innovation = Outputs
  162.     pool.currentSpecies = 1
  163.     pool.currentGenome = 1
  164.     pool.currentFrame = 0
  165.     pool.maxFitness = 0
  166.    
  167.     return pool
  168. end
  169.  
  170. function newSpecies()
  171.     local species = {}
  172.     species.topFitness = 0
  173.     species.staleness = 0
  174.     species.genomes = {}
  175.     species.averageFitness = 0
  176.    
  177.     return species
  178. end
  179.  
  180. function newGenome()
  181.     local genome = {}
  182.     genome.genes = {}
  183.     genome.fitness = 0
  184.     genome.adjustedFitness = 0
  185.     genome.network = {}
  186.     genome.maxneuron = 0
  187.     genome.globalRank = 0
  188.     genome.mutationRates = {}
  189.     genome.mutationRates["connections"] = MutateConnectionsChance
  190.     genome.mutationRates["link"] = LinkMutationChance
  191.     genome.mutationRates["bias"] = BiasMutationChance
  192.     genome.mutationRates["node"] = NodeMutationChance
  193.     genome.mutationRates["enable"] = EnableMutationChance
  194.     genome.mutationRates["disable"] = DisableMutationChance
  195.     genome.mutationRates["step"] = StepSize
  196.    
  197.     return genome
  198. end
  199.  
  200. function copyGenome(genome)
  201.     local genome2 = newGenome()
  202.     for g=1,#genome.genes do
  203.         table.insert(genome2.genes, copyGene(genome.genes[g]))
  204.     end
  205.     genome2.maxneuron = genome.maxneuron
  206.     genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  207.     genome2.mutationRates["link"] = genome.mutationRates["link"]
  208.     genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  209.     genome2.mutationRates["node"] = genome.mutationRates["node"]
  210.     genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  211.     genome2.mutationRates["disable"] = genome.mutationRates["disable"]
  212.    
  213.     return genome2
  214. end
  215.  
  216. function basicGenome()
  217.     local genome = newGenome()
  218.     local innovation = 1
  219.  
  220.     genome.maxneuron = Inputs
  221.     mutate(genome)
  222.    
  223.     return genome
  224. end
  225.  
  226. function newGene()
  227.     local gene = {}
  228.     gene.into = 0
  229.     gene.out = 0
  230.     gene.weight = 0.0
  231.     gene.enabled = true
  232.     gene.innovation = 0
  233.    
  234.     return gene
  235. end
  236.  
  237. function copyGene(gene)
  238.     local gene2 = newGene()
  239.     gene2.into = gene.into
  240.     gene2.out = gene.out
  241.     gene2.weight = gene.weight
  242.     gene2.enabled = gene.enabled
  243.     gene2.innovation = gene.innovation
  244.    
  245.     return gene2
  246. end
  247.  
  248. function newNeuron()
  249.     local neuron = {}
  250.     neuron.incoming = {}
  251.     neuron.value = 0.0
  252.    
  253.     return neuron
  254. end
  255.  
  256. function generateNetwork(genome)
  257.     local network = {}
  258.     network.neurons = {}
  259.    
  260.     for i=1,Inputs do
  261.         network.neurons[i] = newNeuron()
  262.     end
  263.    
  264.     for o=1,Outputs do
  265.         network.neurons[MaxNodes+o] = newNeuron()
  266.     end
  267.    
  268.     table.sort(genome.genes, function (a,b)
  269.         return (a.out < b.out)
  270.     end)
  271.     for i=1,#genome.genes do
  272.         local gene = genome.genes[i]
  273.         if gene.enabled then
  274.             if network.neurons[gene.out] == nil then
  275.                 network.neurons[gene.out] = newNeuron()
  276.             end
  277.             local neuron = network.neurons[gene.out]
  278.             table.insert(neuron.incoming, gene)
  279.             if network.neurons[gene.into] == nil then
  280.                 network.neurons[gene.into] = newNeuron()
  281.             end
  282.         end
  283.     end
  284.    
  285.     genome.network = network
  286. end
  287.  
  288. function evaluateNetwork(network, inputs)
  289.     table.insert(inputs, 1)
  290.     if #inputs ~= Inputs then
  291.         print("Incorrect number of neural network inputs.")
  292.         return {}
  293.     end
  294.    
  295.     for i=1,Inputs do
  296.         network.neurons[i].value = inputs[i]
  297.     end
  298.    
  299.     for _,neuron in pairs(network.neurons) do
  300.         local sum = 0
  301.         for j = 1,#neuron.incoming do
  302.             local incoming = neuron.incoming[j]
  303.             local other = network.neurons[incoming.into]
  304.             sum = sum + incoming.weight * other.value
  305.         end
  306.        
  307.         if #neuron.incoming > 0 then
  308.             neuron.value = sigmoid(sum)
  309.         end
  310.     end
  311.    
  312.     local outputs = {}
  313.     for o=1,Outputs do
  314.         local button = "P1 " .. ButtonNames[o]
  315.         if network.neurons[MaxNodes+o].value > 0 then
  316.             outputs[button] = true
  317.         else
  318.             outputs[button] = false
  319.         end
  320.     end
  321.    
  322.     return outputs
  323. end
  324.  
  325. function crossover(g1, g2)
  326.     -- Make sure g1 is the higher fitness genome
  327.     if g2.fitness > g1.fitness then
  328.         tempg = g1
  329.         g1 = g2
  330.         g2 = tempg
  331.     end
  332.  
  333.     local child = newGenome()
  334.    
  335.     local innovations2 = {}
  336.     for i=1,#g2.genes do
  337.         local gene = g2.genes[i]
  338.         innovations2[gene.innovation] = gene
  339.     end
  340.    
  341.     for i=1,#g1.genes do
  342.         local gene1 = g1.genes[i]
  343.         local gene2 = innovations2[gene1.innovation]
  344.         if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
  345.             table.insert(child.genes, copyGene(gene2))
  346.         else
  347.             table.insert(child.genes, copyGene(gene1))
  348.         end
  349.     end
  350.    
  351.     child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
  352.    
  353.     for mutation,rate in pairs(g1.mutationRates) do
  354.         child.mutationRates[mutation] = rate
  355.     end
  356.    
  357.     return child
  358. end
  359.  
  360. function randomNeuron(genes, nonInput)
  361.     local neurons = {}
  362.     if not nonInput then
  363.         for i=1,Inputs do
  364.             neurons[i] = true
  365.         end
  366.     end
  367.     for o=1,Outputs do
  368.         neurons[MaxNodes+o] = true
  369.     end
  370.     for i=1,#genes do
  371.         if (not nonInput) or genes[i].into > Inputs then
  372.             neurons[genes[i].into] = true
  373.         end
  374.         if (not nonInput) or genes[i].out > Inputs then
  375.             neurons[genes[i].out] = true
  376.         end
  377.     end
  378.  
  379.     local count = 0
  380.     for _,_ in pairs(neurons) do
  381.         count = count + 1
  382.     end
  383.     local n = math.random(1, count)
  384.    
  385.     for k,v in pairs(neurons) do
  386.         n = n-1
  387.         if n == 0 then
  388.             return k
  389.         end
  390.     end
  391.    
  392.     return 0
  393. end
  394.  
  395. function containsLink(genes, link)
  396.     for i=1,#genes do
  397.         local gene = genes[i]
  398.         if gene.into == link.into and gene.out == link.out then
  399.             return true
  400.         end
  401.     end
  402. end
  403.  
  404. function pointMutate(genome)
  405.     local step = genome.mutationRates["step"]
  406.    
  407.     for i=1,#genome.genes do
  408.         local gene = genome.genes[i]
  409.         if math.random() < PerturbChance then
  410.             gene.weight = gene.weight + math.random() * step*2 - step
  411.         else
  412.             gene.weight = math.random()*4-2
  413.         end
  414.     end
  415. end
  416.  
  417. function linkMutate(genome, forceBias)
  418.     local neuron1 = randomNeuron(genome.genes, false)
  419.     local neuron2 = randomNeuron(genome.genes, true)
  420.      
  421.     local newLink = newGene()
  422.     if neuron1 <= Inputs and neuron2 <= Inputs then
  423.         --Both input nodes
  424.         return
  425.     end
  426.     if neuron2 <= Inputs then
  427.         -- Swap output and input
  428.         local temp = neuron1
  429.         neuron1 = neuron2
  430.         neuron2 = temp
  431.     end
  432.  
  433.     newLink.into = neuron1
  434.     newLink.out = neuron2
  435.     if forceBias then
  436.         newLink.into = Inputs
  437.     end
  438.    
  439.     if containsLink(genome.genes, newLink) then
  440.         return
  441.     end
  442.     newLink.innovation = newInnovation()
  443.     newLink.weight = math.random()*4-2
  444.    
  445.     table.insert(genome.genes, newLink)
  446. end
  447.  
  448. function nodeMutate(genome)
  449.     if #genome.genes == 0 then
  450.         return
  451.     end
  452.  
  453.     genome.maxneuron = genome.maxneuron + 1
  454.  
  455.     local gene = genome.genes[math.random(1,#genome.genes)]
  456.     if not gene.enabled then
  457.         return
  458.     end
  459.     gene.enabled = false
  460.    
  461.     local gene1 = copyGene(gene)
  462.     gene1.out = genome.maxneuron
  463.     gene1.weight = 1.0
  464.     gene1.innovation = newInnovation()
  465.     gene1.enabled = true
  466.     table.insert(genome.genes, gene1)
  467.    
  468.     local gene2 = copyGene(gene)
  469.     gene2.into = genome.maxneuron
  470.     gene2.innovation = newInnovation()
  471.     gene2.enabled = true
  472.     table.insert(genome.genes, gene2)
  473. end
  474.  
  475. function enableDisableMutate(genome, enable)
  476.     local candidates = {}
  477.     for _,gene in pairs(genome.genes) do
  478.         if gene.enabled == not enable then
  479.             table.insert(candidates, gene)
  480.         end
  481.     end
  482.    
  483.     if #candidates == 0 then
  484.         return
  485.     end
  486.    
  487.     local gene = candidates[math.random(1,#candidates)]
  488.     gene.enabled = not gene.enabled
  489. end
  490.  
  491. function mutate(genome)
  492.     for mutation,rate in pairs(genome.mutationRates) do
  493.         if math.random(1,2) == 1 then
  494.             genome.mutationRates[mutation] = 0.95*rate
  495.         else
  496.             genome.mutationRates[mutation] = 1.05263*rate
  497.         end
  498.     end
  499.  
  500.     if math.random() < genome.mutationRates["connections"] then
  501.         pointMutate(genome)
  502.     end
  503.    
  504.     local p = genome.mutationRates["link"]
  505.     while p > 0 do
  506.         if math.random() < p then
  507.             linkMutate(genome, false)
  508.         end
  509.         p = p - 1
  510.     end
  511.  
  512.     p = genome.mutationRates["bias"]
  513.     while p > 0 do
  514.         if math.random() < p then
  515.             linkMutate(genome, true)
  516.         end
  517.         p = p - 1
  518.     end
  519.    
  520.     p = genome.mutationRates["node"]
  521.     while p > 0 do
  522.         if math.random() < p then
  523.             nodeMutate(genome)
  524.         end
  525.         p = p - 1
  526.     end
  527.    
  528.     p = genome.mutationRates["enable"]
  529.     while p > 0 do
  530.         if math.random() < p then
  531.             enableDisableMutate(genome, true)
  532.         end
  533.         p = p - 1
  534.     end
  535.  
  536.     p = genome.mutationRates["disable"]
  537.     while p > 0 do
  538.         if math.random() < p then
  539.             enableDisableMutate(genome, false)
  540.         end
  541.         p = p - 1
  542.     end
  543. end
  544.  
  545. function disjoint(genes1, genes2)
  546.     local i1 = {}
  547.     for i = 1,#genes1 do
  548.         local gene = genes1[i]
  549.         i1[gene.innovation] = true
  550.     end
  551.  
  552.     local i2 = {}
  553.     for i = 1,#genes2 do
  554.         local gene = genes2[i]
  555.         i2[gene.innovation] = true
  556.     end
  557.    
  558.     local disjointGenes = 0
  559.     for i = 1,#genes1 do
  560.         local gene = genes1[i]
  561.         if not i2[gene.innovation] then
  562.             disjointGenes = disjointGenes+1
  563.         end
  564.     end
  565.    
  566.     for i = 1,#genes2 do
  567.         local gene = genes2[i]
  568.         if not i1[gene.innovation] then
  569.             disjointGenes = disjointGenes+1
  570.         end
  571.     end
  572.    
  573.     local n = math.max(#genes1, #genes2)
  574.    
  575.     return disjointGenes / n
  576. end
  577.  
  578. function weights(genes1, genes2)
  579.     local i2 = {}
  580.     for i = 1,#genes2 do
  581.         local gene = genes2[i]
  582.         i2[gene.innovation] = gene
  583.     end
  584.  
  585.     local sum = 0
  586.     local coincident = 0
  587.     for i = 1,#genes1 do
  588.         local gene = genes1[i]
  589.         if i2[gene.innovation] ~= nil then
  590.             local gene2 = i2[gene.innovation]
  591.             sum = sum + math.abs(gene.weight - gene2.weight)
  592.             coincident = coincident + 1
  593.         end
  594.     end
  595.    
  596.     return sum / coincident
  597. end
  598.    
  599. function sameSpecies(genome1, genome2)
  600.     local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  601.     local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  602.     return dd + dw < DeltaThreshold
  603. end
  604.  
  605. function rankGlobally()
  606.     local global = {}
  607.     for s = 1,#pool.species do
  608.         local species = pool.species[s]
  609.         for g = 1,#species.genomes do
  610.             table.insert(global, species.genomes[g])
  611.         end
  612.     end
  613.     table.sort(global, function (a,b)
  614.         return (a.fitness < b.fitness)
  615.     end)
  616.    
  617.     for g=1,#global do
  618.         global[g].globalRank = g
  619.     end
  620. end
  621.  
  622. function calculateAverageFitness(species)
  623.     local total = 0
  624.    
  625.     for g=1,#species.genomes do
  626.         local genome = species.genomes[g]
  627.         total = total + genome.globalRank
  628.     end
  629.    
  630.     species.averageFitness = total / #species.genomes
  631. end
  632.  
  633. function totalAverageFitness()
  634.     local total = 0
  635.     for s = 1,#pool.species do
  636.         local species = pool.species[s]
  637.         total = total + species.averageFitness
  638.     end
  639.  
  640.     return total
  641. end
  642.  
  643. function cullSpecies(cutToOne)
  644.     for s = 1,#pool.species do
  645.         local species = pool.species[s]
  646.        
  647.         table.sort(species.genomes, function (a,b)
  648.             return (a.fitness > b.fitness)
  649.         end)
  650.        
  651.         local remaining = math.ceil(#species.genomes/2)
  652.         if cutToOne then
  653.             remaining = 1
  654.         end
  655.         while #species.genomes > remaining do
  656.             table.remove(species.genomes)
  657.         end
  658.     end
  659. end
  660.  
  661. function breedChild(species)
  662.     local child = {}
  663.     if math.random() < CrossoverChance then
  664.         g1 = species.genomes[math.random(1, #species.genomes)]
  665.         g2 = species.genomes[math.random(1, #species.genomes)]
  666.         child = crossover(g1, g2)
  667.     else
  668.         g = species.genomes[math.random(1, #species.genomes)]
  669.         child = copyGenome(g)
  670.     end
  671.    
  672.     mutate(child)
  673.    
  674.     return child
  675. end
  676.  
  677. function removeStaleSpecies()
  678.     local survived = {}
  679.  
  680.     for s = 1,#pool.species do
  681.         local species = pool.species[s]
  682.        
  683.         table.sort(species.genomes, function (a,b)
  684.             return (a.fitness > b.fitness)
  685.         end)
  686.        
  687.         if species.genomes[1].fitness > species.topFitness then
  688.             species.topFitness = species.genomes[1].fitness
  689.             species.staleness = 0
  690.         else
  691.             species.staleness = species.staleness + 1
  692.         end
  693.         if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness then
  694.             table.insert(survived, species)
  695.         end
  696.     end
  697.  
  698.     pool.species = survived
  699. end
  700.  
  701. function removeWeakSpecies()
  702.     local survived = {}
  703.  
  704.     local sum = totalAverageFitness()
  705.     for s = 1,#pool.species do
  706.         local species = pool.species[s]
  707.         breed = math.floor(species.averageFitness / sum * Population)
  708.         if breed >= 1 then
  709.             table.insert(survived, species)
  710.         end
  711.     end
  712.  
  713.     pool.species = survived
  714. end
  715.  
  716.  
  717. function addToSpecies(child)
  718.     local foundSpecies = false
  719.     for s=1,#pool.species do
  720.         local species = pool.species[s]
  721.         if not foundSpecies and sameSpecies(child, species.genomes[1]) then
  722.             table.insert(species.genomes, child)
  723.             foundSpecies = true
  724.         end
  725.     end
  726.    
  727.     if not foundSpecies then
  728.         local childSpecies = newSpecies()
  729.         table.insert(childSpecies.genomes, child)
  730.         table.insert(pool.species, childSpecies)
  731.     end
  732. end
  733.  
  734. function newGeneration()
  735.     cullSpecies(false) -- Cull the bottom half of each species
  736.     rankGlobally()
  737.     removeStaleSpecies()
  738.     rankGlobally()
  739.     for s = 1,#pool.species do
  740.         local species = pool.species[s]
  741.         calculateAverageFitness(species)
  742.     end
  743.     removeWeakSpecies()
  744.     local sum = totalAverageFitness()
  745.     local children = {}
  746.     for s = 1,#pool.species do
  747.         local species = pool.species[s]
  748.         breed = math.floor(species.averageFitness / sum * Population) - 1
  749.         for i=1,breed do
  750.             table.insert(children, breedChild(species))
  751.         end
  752.     end
  753.     cullSpecies(true) -- Cull all but the top member of each species
  754.     while #children + #pool.species < Population do
  755.         local species = pool.species[math.random(1, #pool.species)]
  756.         table.insert(children, breedChild(species))
  757.     end
  758.     for c=1,#children do
  759.         local child = children[c]
  760.         addToSpecies(child)
  761.     end
  762.    
  763.     pool.generation = pool.generation + 1
  764.    
  765.     writeFile("backup." .. pool.generation .. "." .. "level.state")
  766. end
  767.    
  768. function initializePool()
  769.     pool = newPool()
  770.  
  771.     for i=1,Population do
  772.         basic = basicGenome()
  773.         addToSpecies(basic)
  774.     end
  775.  
  776.     initializeRun()
  777. end
  778.  
  779. function clearJoypad()
  780.     controller = {}
  781.     for b = 1,#ButtonNames do
  782.         controller["P1 " .. ButtonNames[b]] = false
  783.     end
  784.     joypad.set2(controller)
  785. end
  786.  
  787. function initializeRun()
  788.     savestate.load(Savestate_object); -- Amarat
  789.     rightmost = 0
  790.     pool.currentFrame = 0
  791.     timeout = TimeoutConstant
  792.     clearJoypad()
  793.    
  794.     local species = pool.species[pool.currentSpecies]
  795.     local genome = species.genomes[pool.currentGenome]
  796.     generateNetwork(genome)
  797.     evaluateCurrent()
  798. end
  799.  
  800. function evaluateCurrent()
  801.     local species = pool.species[pool.currentSpecies]
  802.     local genome = species.genomes[pool.currentGenome]
  803.  
  804.     inputs = getInputs()
  805.     controller = evaluateNetwork(genome.network, inputs)
  806.    
  807.     if controller["P1 Left"] and controller["P1 Right"] then
  808.         controller["P1 Left"] = false
  809.         controller["P1 Right"] = false
  810.     end
  811.     if controller["P1 Up"] and controller["P1 Down"] then
  812.         controller["P1 Up"] = false
  813.         controller["P1 Down"] = false
  814.     end
  815.  
  816.     joypad.set2(controller) -- Amarat
  817. end
  818.  
  819. if pool == nil then
  820.     initializePool()
  821. end
  822.  
  823.  
  824. function nextGenome()
  825.     pool.currentGenome = pool.currentGenome + 1
  826.     if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
  827.         pool.currentGenome = 1
  828.         pool.currentSpecies = pool.currentSpecies+1
  829.         if pool.currentSpecies > #pool.species then
  830.             newGeneration()
  831.             pool.currentSpecies = 1
  832.         end
  833.     end
  834. end
  835.  
  836. function fitnessAlreadyMeasured()
  837.     local species = pool.species[pool.currentSpecies]
  838.     local genome = species.genomes[pool.currentGenome]
  839.    
  840.     return genome.fitness ~= 0
  841. end
  842.  
  843. function displayGenome(genome)
  844.     local network = genome.network
  845.     local cells = {}
  846.     local i = 1
  847.     local cell = {}
  848.     for dy=-BoxRadius,BoxRadius do
  849.         for dx=-BoxRadius,BoxRadius do
  850.             cell = {}
  851.             cell.x = 50+5*dx
  852.             cell.y = 70+5*dy
  853.             cell.value = network.neurons[i].value
  854.             cells[i] = cell
  855.             i = i + 1
  856.         end
  857.     end
  858.     local biasCell = {}
  859.     biasCell.x = 80
  860.     biasCell.y = 110
  861.     biasCell.value = network.neurons[Inputs].value
  862.     cells[Inputs] = biasCell
  863.    
  864.     for o = 1,Outputs do
  865.         cell = {}
  866.         cell.x = 220
  867.         cell.y = 30 + 8 * o
  868.         cell.value = network.neurons[MaxNodes + o].value
  869.         cells[MaxNodes+o] = cell
  870.         local color
  871.         if cell.value > 0 then
  872.             color = 0x0000FFFF
  873.         else
  874.             color = 0x000000FF
  875.         end
  876.         gui.drawText(223, 24+8*o, ButtonNames[o], color, 0xffffffff)
  877.     end
  878.    
  879.     for n,neuron in pairs(network.neurons) do
  880.         cell = {}
  881.         if n > Inputs and n <= MaxNodes then
  882.             cell.x = 140
  883.             cell.y = 40
  884.             cell.value = neuron.value
  885.             cells[n] = cell
  886.         end
  887.     end
  888.    
  889.     for n=1,4 do
  890.         for _,gene in pairs(genome.genes) do
  891.             if gene.enabled then
  892.                 local c1 = cells[gene.into]
  893.                 local c2 = cells[gene.out]
  894.                 if gene.into > Inputs and gene.into <= MaxNodes then
  895.                     c1.x = 0.75*c1.x + 0.25*c2.x
  896.                     if c1.x >= c2.x then
  897.                         c1.x = c1.x - 40
  898.                     end
  899.                     if c1.x < 90 then
  900.                         c1.x = 90
  901.                     end
  902.                    
  903.                     if c1.x > 220 then
  904.                         c1.x = 220
  905.                     end
  906.                     c1.y = 0.75*c1.y + 0.25*c2.y
  907.                    
  908.                 end
  909.                 if gene.out > Inputs and gene.out <= MaxNodes then
  910.                     c2.x = 0.25*c1.x + 0.75*c2.x
  911.                     if c1.x >= c2.x then
  912.                         c2.x = c2.x + 40
  913.                     end
  914.                     if c2.x < 90 then
  915.                         c2.x = 90
  916.                     end
  917.                     if c2.x > 220 then
  918.                         c2.x = 220
  919.                     end
  920.                     c2.y = 0.25*c1.y + 0.75*c2.y
  921.                 end
  922.             end
  923.         end
  924.     end
  925.    
  926.     gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0x000000FF, 0x80808080)
  927.     for n,cell in pairs(cells) do
  928.         if n > Inputs or cell.value ~= 0 then
  929.             local color = math.floor((cell.value+1)/2*256)
  930.             if color > 255 then color = 255 end
  931.             if color < 0 then color = 0 end
  932.             local opacity = 0xFF
  933.             if cell.value == 0 then
  934.                 opacity = 0x50
  935.             end
  936.             color = color*0x1000000 + color*0x10000 + color*0x100 + opacity
  937.             gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color) -- Amarat: colours not edited
  938.         end
  939.     end
  940.     for _,gene in pairs(genome.genes) do
  941.         if gene.enabled then
  942.             local c1 = cells[gene.into]
  943.             local c2 = cells[gene.out]
  944.             local opacity = 0xA0
  945.             if c1.value == 0 then
  946.                 opacity = 0x20
  947.             end
  948.            
  949.             local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)  -- Amarat: colours not edited
  950.             if gene.weight > 0 then
  951.                 color = opacity + 0x8000 + 0x10000*color
  952.             else
  953.                 color = opacity + 0x800000 + 0x100*color
  954.             end
  955.             gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
  956.         end
  957.     end
  958.    
  959.     gui.drawBox(49,71,51,78,0x00000000,0xFF000080)
  960.    
  961.     -- Without forms :(
  962.     if showMutationRates then
  963.         local pos = 100
  964.         for mutation,rate in pairs(genome.mutationRates) do
  965.             gui.drawText(100, pos, mutation .. ": " .. rate, 0x000000FF, 0xffffffff)
  966.             pos = pos + 8
  967.         end
  968.     end
  969. end
  970.  
  -- Serialise the global `pool` to a plain-text file in the format read by loadFile.
  -- Layout, one value per line unless noted:
  --   generation, maxFitness, number of species; then for each species:
  --   topFitness, staleness, number of genomes; then for each genome:
  --   fitness, maxneuron, mutation-rate name/value line pairs terminated by "done",
  --   number of genes, and one "into out weight innovation enabled" line per gene
  --   (enabled is written as 1 or 0).
  -- @param filename path of the file to create or overwrite
  -- @return true on success; false plus the io.open error message when the file
  --         cannot be opened (a warning is printed in that case)
  function writeFile(filename)
      local file, err = io.open(filename, "w")
      if not file then
          -- Report and return instead of crashing on `nil:write`: the main loop calls
          -- this for backups, so a failed write should not abort training.
          print("writeFile: could not open " .. tostring(filename) .. ": " .. tostring(err))
          return false, err
      end

      local function writeLine(value)
          file:write(value .. "\n")
      end

      writeLine(pool.generation)
      writeLine(pool.maxFitness)
      writeLine(#pool.species)
      -- ipairs walks the sequences in index order, matching the `#` counts above.
      for _, species in ipairs(pool.species) do
          writeLine(species.topFitness)
          writeLine(species.staleness)
          writeLine(#species.genomes)
          for _, genome in ipairs(species.genomes) do
              writeLine(genome.fitness)
              writeLine(genome.maxneuron)
              for mutation, rate in pairs(genome.mutationRates) do
                  writeLine(mutation)
                  writeLine(rate)
              end
              writeLine("done")

              writeLine(#genome.genes)
              for _, gene in ipairs(genome.genes) do
                  file:write(gene.into .. " " .. gene.out .. " " .. gene.weight .. " "
                      .. gene.innovation .. " " .. (gene.enabled and "1" or "0") .. "\n")
              end
          end
      end

      file:close()
      return true
  end
  1005.  
  -- Save the current pool through writeFile.
  -- @param filename optional path; defaults to "level.state", the file loadPool reads
  -- @return whatever writeFile returns
  function savePool(filename)
      return writeFile(filename or "level.state")
  end
  1010.  
  -- Load a pool previously saved by writeFile and make it the global `pool`.
  -- Then advance past genomes whose fitness is already measured and start a new run.
  -- @param filename path of the pool file to read
  -- @return true on success; false plus an error message when the file cannot be
  --         opened or parsed (a warning is printed and the previous pool is kept)
  function loadFile(filename)
      local file, openErr = io.open(filename, "r")
      if not file then
          print("loadFile: could not open " .. tostring(filename) .. ": " .. tostring(openErr))
          return false, openErr
      end

      -- Next non-blank line with surrounding whitespace (including a stray "\r")
      -- removed. Errors at end of file so a truncated pool fails with a clear message.
      local function readLabel()
          while true do
              local text = file:read("*line")
              if text == nil then
                  error("unexpected end of file while reading mutation rates", 0)
              end
              text = text:match("^%s*(.-)%s*$")
              if text ~= "" then
                  return text
              end
          end
      end

      -- Fill the (fresh) global pool from the file; any malformed field raises an error.
      local function parse()
          pool.generation = file:read("*number")
          pool.maxFitness = file:read("*number")
          --forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness)) -- forms are not possible on Snes9x
          local numSpecies = file:read("*number")
          for _ = 1, numSpecies do
              local species = newSpecies()
              table.insert(pool.species, species)
              species.topFitness = file:read("*number")
              species.staleness = file:read("*number")
              local numGenomes = file:read("*number")
              for _ = 1, numGenomes do
                  local genome = newGenome()
                  table.insert(species.genomes, genome)
                  genome.fitness = file:read("*number")
                  genome.maxneuron = file:read("*number")
                  -- Mutation rates: a name line followed by its value, until "done".
                  local name = readLabel()
                  while name ~= "done" do
                      genome.mutationRates[name] = file:read("*number")
                      name = readLabel()
                  end
                  local numGenes = file:read("*number")
                  for _ = 1, numGenes do
                      local gene = newGene()
                      table.insert(genome.genes, gene)
                      local enabled
                      gene.into, gene.out, gene.weight, gene.innovation, enabled =
                          file:read("*number", "*number", "*number", "*number", "*number")
                      -- Only an explicit 0 disables the gene.
                      gene.enabled = (enabled ~= 0)
                  end
              end
          end
      end

      -- Keep the running pool so a bad file does not leave a half-built one behind.
      local previousPool = pool
      pool = newPool()
      local ok, parseErr = pcall(parse)
      file:close()
      if not ok then
          pool = previousPool
          print("loadFile: could not parse " .. tostring(filename) .. ": " .. tostring(parseErr))
          return false, parseErr
      end

      while fitnessAlreadyMeasured() do
          nextGenome()
      end
      initializeRun()
      pool.currentFrame = pool.currentFrame + 1
      return true
  end
  1057.  
  -- Load a pool through loadFile.
  -- @param filename optional path; defaults to "level.state", the file savePool writes
  -- @return whatever loadFile returns
  function loadPool(filename)
      return loadFile(filename or "level.state")
  end
  1062.  
  -- Make the genome with the highest positive recorded fitness (across all species)
  -- the current one, set pool.maxFitness to that fitness and restart the run.
  -- @return true if a genome was selected; false if no genome has fitness > 0
  function playTop()
      local bestFitness, bestSpecies, bestGenome = 0, nil, nil
      for s, species in pairs(pool.species) do
          for g, genome in pairs(species.genomes) do
              if genome.fitness > bestFitness then
                  bestFitness, bestSpecies, bestGenome = genome.fitness, s, g
              end
          end
      end

      if bestSpecies == nil then
          -- Keep the current genome rather than storing nil indices, which would make
          -- the main loop's pool.species[pool.currentSpecies] lookup fail.
          print("playTop: no genome with positive fitness yet")
          return false
      end

      pool.currentSpecies = bestSpecies
      pool.currentGenome = bestGenome
      pool.maxFitness = bestFitness
      --forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
      initializeRun()
      pool.currentFrame = pool.currentFrame + 1
      return true
  end
  1084.  
  -- Destroy the settings form, if there is one.
  -- The Snes9x port has no `forms` library and its event.onexit registration is
  -- commented out, so this is a no-op unless both `forms` and `form` exist.
  function onExit()
      if forms ~= nil and form ~= nil then
          forms.destroy(form)
      end
  end
  1088.  
  -- Script start-up: write an initial snapshot of the pool to "temp.pool".
  writeFile("temp.pool")

  --event.onexit(onExit) -- Amarat: forms are not possible on Snes9x

  -- Original settings-form setup, kept for reference only; forms are not
  -- available on Snes9x.
  --[[
  form = forms.newform(200, 260, "Fitness")
  maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
  showNetwork = forms.checkbox(form, "Show Map", 5, 30)
  showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
  restartButton = forms.button(form, "Restart", initializePool, 5, 77)
  saveButton = forms.button(form, "Save", savePool, 5, 102)
  loadButton = forms.button(form, "Load", loadPool, 80, 102)
  saveLoadFile = forms.textbox(form, Savestate_object .. ".pool", 170, 25, nil, 5, 148)
  saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
  playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
  hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
  --]]

  -- Before each emulated frame, send the buttons last stored by joypad.set2 to
  -- controller 1. The first 3 characters of each stored key are dropped
  -- (presumably a player prefix such as "P1 "; verify against the code that builds
  -- `controller`), leaving the bare button name that joypad.set expects.
  emu.registerbefore(function()
      local pressed = {}
      -- joypad.table is only assigned by joypad.set2 and may still be nil.
      for name, state in pairs(joypad.table or {}) do
          pressed[string.sub(name, 4)] = state
      end
      joypad.set(1, pressed)
  end)
  1115.  
  1116.  
  -- Constant added to every recorded fitness. The banner's projected fitness adds
  -- it as well, so "Fitness" and "Max Fitness" are on the same scale.
  local FitnessOffset = 1000

  -- Fitness credited to a finished run: rightmost progress, minus half the number
  -- of frames used, plus FitnessOffset. A fitness of 0 marks "not measured yet",
  -- so an exact 0 is stored as -1 instead.
  local function finalFitness()
      local fitness = rightmost - pool.currentFrame / 2 + FitnessOffset
      if fitness == 0 then
          fitness = -1
      end
      return fitness
  end

  -- Returns (number of genomes with a non-zero fitness, total number of genomes).
  local function countMeasured()
      local measured, total = 0, 0
      for _, sp in pairs(pool.species) do
          for _, g in pairs(sp.genomes) do
              total = total + 1
              if g.fitness ~= 0 then
                  measured = measured + 1
              end
          end
      end
      return measured, total
  end

  -- Draw the generation/progress line, the projected fitness of the current run
  -- and the best fitness so far. timeoutBonus is recomputed from the current
  -- pool.currentFrame, so the projection also uses up-to-date values on a frame
  -- where a new run has just been started.
  local function drawBanner()
      local measured, total = countMeasured()
      local timeoutBonus = pool.currentFrame / 4
      local projected = rightmost - pool.currentFrame / 2
          - (timeout + timeoutBonus) * 2 / 3 + FitnessOffset
      gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0x000000FF, 0xffffffff)
      gui.drawText(0, 12, "Fitness: " .. math.floor(projected), 0x000000FF, 0xffffffff)
      gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0x000000FF, 0xffffffff)
  end

  -- Main loop, one iteration per emulated frame. Each iteration:
  --   * drives the current genome's controller output;
  --   * tracks Mario's rightmost position;
  --   * when the run times out, records the fitness and moves on to the next
  --     unmeasured genome.
  while true do
      local backgroundColor = 0xFFFFFFD0
      -- Without forms, the banner is toggled by the hideBanner user option.
      if not hideBanner then
          gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
      end

      local currentSpecies = pool.species[pool.currentSpecies]
      local currentGenome = currentSpecies.genomes[pool.currentGenome]

      if showNetwork then
          displayGenome(currentGenome)
      end

      -- The network is evaluated only on every 5th frame.
      if pool.currentFrame % 5 == 0 then
          evaluateCurrent()
      end

      joypad.set2(controller)

      -- New rightmost progress resets the idle timeout.
      getPositions()
      if marioX > rightmost then
          rightmost = marioX
          timeout = TimeoutConstant
      end
      timeout = timeout - 1

      -- Longer runs are allowed to idle a little longer before being stopped.
      local timeoutBonus = pool.currentFrame / 4
      if timeout + timeoutBonus <= 0 then
          local fitness = finalFitness()
          currentGenome.fitness = fitness

          if fitness > pool.maxFitness then
              pool.maxFitness = fitness
              --forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
              writeFile("backup." .. pool.generation .. "." .. "level.state")
          end

          print("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
          pool.currentSpecies = 1
          pool.currentGenome = 1
          while fitnessAlreadyMeasured() do
              nextGenome()
          end
          initializeRun()
      end

      if not hideBanner then
          drawBanner()
      end

      pool.currentFrame = pool.currentFrame + 1
      emu.frameadvance()
  end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement