View difference between Paste ID: MNC9KBSm and ZZmSNaHX
SHOW: | | - or go back to the newest paste.
1
-- MarI/O by SethBling
2
-- Feel free to use this code, but please do not redistribute it.
3
-- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.
4
-- For SMW, make sure you have a save state named "DP1.state" at the beginning of a level,
5
-- and put a copy in both the Lua folder and the root directory of BizHawk.
6
7
if gameinfo.getromname() == "Super Mario World (USA)" then
8
	Filename = "DP1.state"
9
	ButtonNames = {
10
		"A",
11
		"B",
12
		"X",
13
		"Y",
14
		"Up",
15
		"Down",
16
		"Left",
17
		"Right",
18
	}
19
elseif gameinfo.getromname() == "Super Mario Bros." then
20
	Filename = "SMB1-1.state"
21
	ButtonNames = {
22
		"A",
23
		"B",
24
		"Up",
25
		"Down",
26
		"Left",
27
		"Right",
28
	}
29
end
30
31
BoxRadius = 6
32
InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)
33
34
Inputs = InputSize+1
35
Outputs = #ButtonNames
36
37
Population = 300
38
DeltaDisjoint = 2.0
39
DeltaWeights = 0.4
40
DeltaThreshold = 1.0
41
42
StaleSpecies = 15
43
44
MutateConnectionsChance = 0.25
45
PerturbChance = 0.90
46
CrossoverChance = 0.75
47
LinkMutationChance = 2.0
48
NodeMutationChance = 0.50
49
BiasMutationChance = 0.40
50
StepSize = 0.1
51
DisableMutationChance = 0.4
52
EnableMutationChance = 0.2
53
54
TimeoutConstant = 20
55
56
MaxNodes = 1000000
57
58
function getPositions()
59
	if gameinfo.getromname() == "Super Mario World (USA)" then
60
		marioX = memory.read_s16_le(0x94)
61
		marioY = memory.read_s16_le(0x96)
62
		
63
		local layer1x = memory.read_s16_le(0x1A);
64
		local layer1y = memory.read_s16_le(0x1C);
65
		
66
		screenX = marioX-layer1x
67
		screenY = marioY-layer1y
68
	elseif gameinfo.getromname() == "Super Mario Bros." then
69
		marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
70
		marioY = memory.readbyte(0x03B8)+16
71
	
72
		screenX = memory.readbyte(0x03AD)
73
		screenY = memory.readbyte(0x03B8)
74
	end
75
end
76
77
function getTile(dx, dy)
78
	if gameinfo.getromname() == "Super Mario World (USA)" then
79
		x = math.floor((marioX+dx+8)/16)
80
		y = math.floor((marioY+dy)/16)
81
		
82
		return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
83
	elseif gameinfo.getromname() == "Super Mario Bros." then
84
		local x = marioX + dx + 8
85
		local y = marioY + dy - 16
86
		local page = math.floor(x/256)%2
87
88
		local subx = math.floor((x%256)/16)
89
		local suby = math.floor((y - 32)/16)
90
		local addr = 0x500 + page*13*16+suby*16+subx
91
		
92
		if suby >= 13 or suby < 0 then
93
			return 0
94
		end
95
		
96
		if memory.readbyte(addr) ~= 0 then
97
			return 1
98
		else
99
			return 0
100
		end
101
	end
102
end
103
104
function getSprites()
105
	if gameinfo.getromname() == "Super Mario World (USA)" then
106
		local sprites = {}
107
		for slot=0,11 do
108
			local status = memory.readbyte(0x14C8+slot)
109
			if status ~= 0 then
110
				spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
111
				spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
112
				sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
113
			end
114
		end		
115
		
116
		return sprites
117
	elseif gameinfo.getromname() == "Super Mario Bros." then
118
		local sprites = {}
119
		for slot=0,4 do
120
			local enemy = memory.readbyte(0xF+slot)
121
			if enemy ~= 0 then
122
				local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
123
				local ey = memory.readbyte(0xCF + slot)+24
124
				sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
125
			end
126
		end
127
		
128
		return sprites
129
	end
130
end
131
132
function getExtendedSprites()
133
	if gameinfo.getromname() == "Super Mario World (USA)" then
134
		local extended = {}
135
		for slot=0,11 do
136
			local number = memory.readbyte(0x170B+slot)
137
			if number ~= 0 then
138
				spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
139
				spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
140
				extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
141
			end
142
		end		
143
		
144
		return extended
145
	elseif gameinfo.getromname() == "Super Mario Bros." then
146
		return {}
147
	end
148
end
149
150
function getInputs()
151
	getPositions()
152
	
153
	sprites = getSprites()
154
	extended = getExtendedSprites()
155
	
156
	local inputs = {}
157
	
158
	for dy=-BoxRadius*16,BoxRadius*16,16 do
159
		for dx=-BoxRadius*16,BoxRadius*16,16 do
160
			inputs[#inputs+1] = 0
161
			
162
			tile = getTile(dx, dy)
163
			if tile == 1 and marioY+dy < 0x1B0 then
164
				inputs[#inputs] = 1
165
			end
166
			
167
			for i = 1,#sprites do
168
				distx = math.abs(sprites[i]["x"] - (marioX+dx))
169
				disty = math.abs(sprites[i]["y"] - (marioY+dy))
170
				if distx <= 8 and disty <= 8 then
171
					inputs[#inputs] = -1
172
				end
173
			end
174
175
			for i = 1,#extended do
176
				distx = math.abs(extended[i]["x"] - (marioX+dx))
177
				disty = math.abs(extended[i]["y"] - (marioY+dy))
178
				if distx < 8 and disty < 8 then
179
					inputs[#inputs] = -1
180
				end
181
			end
182
		end
183
	end
184
	
185
	--mariovx = memory.read_s8(0x7B)
186
	--mariovy = memory.read_s8(0x7D)
187
	
188
	return inputs
189
end
190
191
function sigmoid(x)
192
	return 2/(1+math.exp(-4.9*x))-1
193
end
194
195
function newInnovation()
196
	pool.innovation = pool.innovation + 1
197
	return pool.innovation
198
end
199
200
function newPool()
201
	local pool = {}
202
	pool.species = {}
203
	pool.generation = 0
204
	pool.innovation = Outputs
205
	pool.currentSpecies = 1
206
	pool.currentGenome = 1
207
	pool.currentFrame = 0
208
	pool.maxFitness = 0
209
	
210
	return pool
211
end
212
213
function newSpecies()
214
	local species = {}
215
	species.topFitness = 0
216
	species.staleness = 0
217
	species.genomes = {}
218
	species.averageFitness = 0
219
	
220
	return species
221
end
222
223
function newGenome()
224
	local genome = {}
225
	genome.genes = {}
226
	genome.fitness = 0
227
	genome.adjustedFitness = 0
228
	genome.network = {}
229
	genome.maxneuron = 0
230
	genome.globalRank = 0
231
	genome.mutationRates = {}
232
	genome.mutationRates["connections"] = MutateConnectionsChance
233
	genome.mutationRates["link"] = LinkMutationChance
234
	genome.mutationRates["bias"] = BiasMutationChance
235
	genome.mutationRates["node"] = NodeMutationChance
236
	genome.mutationRates["enable"] = EnableMutationChance
237
	genome.mutationRates["disable"] = DisableMutationChance
238
	genome.mutationRates["step"] = StepSize
239
	
240
	return genome
241
end
242
243
function copyGenome(genome)
244
	local genome2 = newGenome()
245
	for g=1,#genome.genes do
246
		table.insert(genome2.genes, copyGene(genome.genes[g]))
247
	end
248
	genome2.maxneuron = genome.maxneuron
249
	genome2.mutationRates["connections"] = genome.mutationRates["connections"]
250
	genome2.mutationRates["link"] = genome.mutationRates["link"]
251
	genome2.mutationRates["bias"] = genome.mutationRates["bias"]
252
	genome2.mutationRates["node"] = genome.mutationRates["node"]
253
	genome2.mutationRates["enable"] = genome.mutationRates["enable"]
254
	genome2.mutationRates["disable"] = genome.mutationRates["disable"]
255
	
256
	return genome2
257
end
258
259
function basicGenome()
260
	local genome = newGenome()
261
	local innovation = 1
262
263
	genome.maxneuron = Inputs
264
	mutate(genome)
265
	
266
	return genome
267
end
268
269
function newGene()
270
	local gene = {}
271
	gene.into = 0
272
	gene.out = 0
273
	gene.weight = 0.0
274
	gene.enabled = true
275
	gene.innovation = 0
276
	
277
	return gene
278
end
279
280
function copyGene(gene)
281
	local gene2 = newGene()
282
	gene2.into = gene.into
283
	gene2.out = gene.out
284
	gene2.weight = gene.weight
285
	gene2.enabled = gene.enabled
286
	gene2.innovation = gene.innovation
287
	
288
	return gene2
289
end
290
291
function newNeuron()
292
	local neuron = {}
293
	neuron.incoming = {}
294
	neuron.value = 0.0
295
	
296
	return neuron
297
end
298
299
function generateNetwork(genome)
300
	local network = {}
301
	network.neurons = {}
302
	
303
	for i=1,Inputs do
304
		network.neurons[i] = newNeuron()
305
	end
306
	
307
	for o=1,Outputs do
308
		network.neurons[MaxNodes+o] = newNeuron()
309
	end
310
	
311
	table.sort(genome.genes, function (a,b)
312
		return (a.out < b.out)
313
	end)
314
	for i=1,#genome.genes do
315
		local gene = genome.genes[i]
316
		if gene.enabled then
317
			if network.neurons[gene.out] == nil then
318
				network.neurons[gene.out] = newNeuron()
319
			end
320
			local neuron = network.neurons[gene.out]
321
			table.insert(neuron.incoming, gene)
322
			if network.neurons[gene.into] == nil then
323
				network.neurons[gene.into] = newNeuron()
324
			end
325
		end
326
	end
327
	
328
	genome.network = network
329
end
330
331
function evaluateNetwork(network, inputs)
332
	table.insert(inputs, 1)
333
	if #inputs ~= Inputs then
334
		console.writeline("Incorrect number of neural network inputs.")
335
		return {}
336
	end
337
	
338
	for i=1,Inputs do
339
		network.neurons[i].value = inputs[i]
340
	end
341
	
342
	for _,neuron in pairs(network.neurons) do
343
		local sum = 0
344
		for j = 1,#neuron.incoming do
345
			local incoming = neuron.incoming[j]
346
			local other = network.neurons[incoming.into]
347
			sum = sum + incoming.weight * other.value
348
		end
349
		
350
		if #neuron.incoming > 0 then
351
			neuron.value = sigmoid(sum)
352
		end
353
	end
354
	
355
	local outputs = {}
356
	for o=1,Outputs do
357
		local button = "P1 " .. ButtonNames[o]
358
		if network.neurons[MaxNodes+o].value > 0 then
359
			outputs[button] = true
360
		else
361
			outputs[button] = false
362
		end
363
	end
364
	
365
	return outputs
366
end
367
368
function crossover(g1, g2)
369
	-- Make sure g1 is the higher fitness genome
370
	if g2.fitness > g1.fitness then
371
		tempg = g1
372
		g1 = g2
373
		g2 = tempg
374
	end
375
376
	local child = newGenome()
377
	
378
	local innovations2 = {}
379
	for i=1,#g2.genes do
380
		local gene = g2.genes[i]
381
		innovations2[gene.innovation] = gene
382
	end
383
	
384
	for i=1,#g1.genes do
385
		local gene1 = g1.genes[i]
386
		local gene2 = innovations2[gene1.innovation]
387
		if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
388
			table.insert(child.genes, copyGene(gene2))
389
		else
390
			table.insert(child.genes, copyGene(gene1))
391
		end
392
	end
393
	
394
	child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)
395
	
396
	for mutation,rate in pairs(g1.mutationRates) do
397
		child.mutationRates[mutation] = rate
398
	end
399
	
400
	return child
401
end
402
403
function randomNeuron(genes, nonInput)
404
	local neurons = {}
405
	if not nonInput then
406
		for i=1,Inputs do
407
			neurons[i] = true
408
		end
409
	end
410
	for o=1,Outputs do
411
		neurons[MaxNodes+o] = true
412
	end
413
	for i=1,#genes do
414
		if (not nonInput) or genes[i].into > Inputs then
415
			neurons[genes[i].into] = true
416
		end
417
		if (not nonInput) or genes[i].out > Inputs then
418
			neurons[genes[i].out] = true
419
		end
420
	end
421
422
	local count = 0
423
	for _,_ in pairs(neurons) do
424
		count = count + 1
425
	end
426
	local n = math.random(1, count)
427
	
428
	for k,v in pairs(neurons) do
429
		n = n-1
430
		if n == 0 then
431
			return k
432
		end
433
	end
434
	
435
	return 0
436
end
437
438
function containsLink(genes, link)
439
	for i=1,#genes do
440
		local gene = genes[i]
441
		if gene.into == link.into and gene.out == link.out then
442
			return true
443
		end
444
	end
445
end
446
447
function pointMutate(genome)
448
	local step = genome.mutationRates["step"]
449
	
450
	for i=1,#genome.genes do
451
		local gene = genome.genes[i]
452
		if math.random() < PerturbChance then
453
			gene.weight = gene.weight + math.random() * step*2 - step
454
		else
455
			gene.weight = math.random()*4-2
456
		end
457
	end
458
end
459
460
function linkMutate(genome, forceBias)
461
	local neuron1 = randomNeuron(genome.genes, false)
462
	local neuron2 = randomNeuron(genome.genes, true)
463
	 
464
	local newLink = newGene()
465
	if neuron1 <= Inputs and neuron2 <= Inputs then
466
		--Both input nodes
467
		return
468
	end
469
	if neuron2 <= Inputs then
470
		-- Swap output and input
471
		local temp = neuron1
472
		neuron1 = neuron2
473
		neuron2 = temp
474
	end
475
476
	newLink.into = neuron1
477
	newLink.out = neuron2
478
	if forceBias then
479
		newLink.into = Inputs
480
	end
481
	
482
	if containsLink(genome.genes, newLink) then
483
		return
484
	end
485
	newLink.innovation = newInnovation()
486
	newLink.weight = math.random()*4-2
487
	
488
	table.insert(genome.genes, newLink)
489
end
490
491
function nodeMutate(genome)
492
	if #genome.genes == 0 then
493
		return
494
	end
495
496
	genome.maxneuron = genome.maxneuron + 1
497
498
	local gene = genome.genes[math.random(1,#genome.genes)]
499
	if not gene.enabled then
500
		return
501
	end
502
	gene.enabled = false
503
	
504
	local gene1 = copyGene(gene)
505
	gene1.out = genome.maxneuron
506
	gene1.weight = 1.0
507
	gene1.innovation = newInnovation()
508
	gene1.enabled = true
509
	table.insert(genome.genes, gene1)
510
	
511
	local gene2 = copyGene(gene)
512
	gene2.into = genome.maxneuron
513
	gene2.innovation = newInnovation()
514
	gene2.enabled = true
515
	table.insert(genome.genes, gene2)
516
end
517
518
function enableDisableMutate(genome, enable)
519
	local candidates = {}
520
	for _,gene in pairs(genome.genes) do
521
		if gene.enabled == not enable then
522
			table.insert(candidates, gene)
523
		end
524
	end
525
	
526
	if #candidates == 0 then
527
		return
528
	end
529
	
530
	local gene = candidates[math.random(1,#candidates)]
531
	gene.enabled = not gene.enabled
532
end
533
534
function mutate(genome)
535
	for mutation,rate in pairs(genome.mutationRates) do
536
		if math.random(1,2) == 1 then
537
			genome.mutationRates[mutation] = 0.95*rate
538
		else
539
			genome.mutationRates[mutation] = 1.05263*rate
540
		end
541
	end
542
543
	if math.random() < genome.mutationRates["connections"] then
544
		pointMutate(genome)
545
	end
546
	
547
	local p = genome.mutationRates["link"]
548
	while p > 0 do
549
		if math.random() < p then
550
			linkMutate(genome, false)
551
		end
552
		p = p - 1
553
	end
554
555
	p = genome.mutationRates["bias"]
556
	while p > 0 do
557
		if math.random() < p then
558
			linkMutate(genome, true)
559
		end
560
		p = p - 1
561
	end
562
	
563
	p = genome.mutationRates["node"]
564
	while p > 0 do
565
		if math.random() < p then
566
			nodeMutate(genome)
567
		end
568
		p = p - 1
569
	end
570
	
571
	p = genome.mutationRates["enable"]
572
	while p > 0 do
573
		if math.random() < p then
574
			enableDisableMutate(genome, true)
575
		end
576
		p = p - 1
577
	end
578
579
	p = genome.mutationRates["disable"]
580
	while p > 0 do
581
		if math.random() < p then
582
			enableDisableMutate(genome, false)
583
		end
584
		p = p - 1
585
	end
586
end
587
588
function disjoint(genes1, genes2)
589
	local i1 = {}
590
	for i = 1,#genes1 do
591
		local gene = genes1[i]
592
		i1[gene.innovation] = true
593
	end
594
595
	local i2 = {}
596
	for i = 1,#genes2 do
597
		local gene = genes2[i]
598
		i2[gene.innovation] = true
599
	end
600
	
601
	local disjointGenes = 0
602
	for i = 1,#genes1 do
603
		local gene = genes1[i]
604
		if not i2[gene.innovation] then
605
			disjointGenes = disjointGenes+1
606
		end
607
	end
608
	
609
	for i = 1,#genes2 do
610
		local gene = genes2[i]
611
		if not i1[gene.innovation] then
612
			disjointGenes = disjointGenes+1
613
		end
614
	end
615
	
616
	local n = math.max(#genes1, #genes2)
617
	
618
	return disjointGenes / n
619
end
620
621
function weights(genes1, genes2)
622
	local i2 = {}
623
	for i = 1,#genes2 do
624
		local gene = genes2[i]
625
		i2[gene.innovation] = gene
626
	end
627
628
	local sum = 0
629
	local coincident = 0
630
	for i = 1,#genes1 do
631
		local gene = genes1[i]
632
		if i2[gene.innovation] ~= nil then
633
			local gene2 = i2[gene.innovation]
634
			sum = sum + math.abs(gene.weight - gene2.weight)
635
			coincident = coincident + 1
636
		end
637
	end
638
	
639
	return sum / coincident
640
end
641
	
642
function sameSpecies(genome1, genome2)
643
	local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
644
	local dw = DeltaWeights*weights(genome1.genes, genome2.genes) 
645
	return dd + dw < DeltaThreshold
646
end
647
648
function rankGlobally()
649
	local global = {}
650
	for s = 1,#pool.species do
651
		local species = pool.species[s]
652
		for g = 1,#species.genomes do
653
			table.insert(global, species.genomes[g])
654
		end
655
	end
656
	table.sort(global, function (a,b)
657
		return (a.fitness < b.fitness)
658
	end)
659
	
660
	for g=1,#global do
661
		global[g].globalRank = g
662
	end
663
end
664
665
function calculateAverageFitness(species)
666
	local total = 0
667
	
668
	for g=1,#species.genomes do
669
		local genome = species.genomes[g]
670
		total = total + genome.globalRank
671
	end
672
	
673
	species.averageFitness = total / #species.genomes
674
end
675
676
function totalAverageFitness()
677
	local total = 0
678
	for s = 1,#pool.species do
679
		local species = pool.species[s]
680
		total = total + species.averageFitness
681
	end
682
683
	return total
684
end
685
686
function cullSpecies(cutToOne)
687
	for s = 1,#pool.species do
688
		local species = pool.species[s]
689
		
690
		table.sort(species.genomes, function (a,b)
691
			return (a.fitness > b.fitness)
692
		end)
693
		
694
		local remaining = math.ceil(#species.genomes/2)
695
		if cutToOne then
696
			remaining = 1
697
		end
698
		while #species.genomes > remaining do
699
			table.remove(species.genomes)
700
		end
701
	end
702
end
703
704
function breedChild(species)
705
	local child = {}
706
	if math.random() < CrossoverChance then
707
		g1 = species.genomes[math.random(1, #species.genomes)]
708
		g2 = species.genomes[math.random(1, #species.genomes)]
709
		child = crossover(g1, g2)
710
	else
711
		g = species.genomes[math.random(1, #species.genomes)]
712
		child = copyGenome(g)
713
	end
714
	
715
	mutate(child)
716
	
717
	return child
718
end
719
720
function removeStaleSpecies()
721
	local survived = {}
722
723
	for s = 1,#pool.species do
724
		local species = pool.species[s]
725
		
726
		table.sort(species.genomes, function (a,b)
727
			return (a.fitness > b.fitness)
728
		end)
729
		
730
		if species.genomes[1].fitness > species.topFitness then
731
			species.topFitness = species.genomes[1].fitness
732
			species.staleness = 0
733
		else
734
			species.staleness = species.staleness + 1
735
		end
736
		if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness then
737
			table.insert(survived, species)
738
		end
739
	end
740
741
	pool.species = survived
742
end
743
744
function removeWeakSpecies()
745
	local survived = {}
746
747
	local sum = totalAverageFitness()
748
	for s = 1,#pool.species do
749
		local species = pool.species[s]
750
		breed = math.floor(species.averageFitness / sum * Population)
751
		if breed >= 1 then
752
			table.insert(survived, species)
753
		end
754
	end
755
756
	pool.species = survived
757
end
758
759
760
function addToSpecies(child)
761
	local foundSpecies = false
762
	for s=1,#pool.species do
763
		local species = pool.species[s]
764
		if not foundSpecies and sameSpecies(child, species.genomes[1]) then
765
			table.insert(species.genomes, child)
766
			foundSpecies = true
767
		end
768
	end
769
	
770
	if not foundSpecies then
771
		local childSpecies = newSpecies()
772
		table.insert(childSpecies.genomes, child)
773
		table.insert(pool.species, childSpecies)
774
	end
775
end
776
777
function newGeneration()
778
	cullSpecies(false) -- Cull the bottom half of each species
779
	rankGlobally()
780
	removeStaleSpecies()
781
	rankGlobally()
782
	for s = 1,#pool.species do
783
		local species = pool.species[s]
784
		calculateAverageFitness(species)
785
	end
786
	removeWeakSpecies()
787
	local sum = totalAverageFitness()
788
	local children = {}
789
	for s = 1,#pool.species do
790
		local species = pool.species[s]
791
		breed = math.floor(species.averageFitness / sum * Population) - 1
792
		for i=1,breed do
793
			table.insert(children, breedChild(species))
794
		end
795
	end
796
	cullSpecies(true) -- Cull all but the top member of each species
797
	while #children + #pool.species < Population do
798
		local species = pool.species[math.random(1, #pool.species)]
799
		table.insert(children, breedChild(species))
800
	end
801
	for c=1,#children do
802
		local child = children[c]
803
		addToSpecies(child)
804
	end
805
	
806
	pool.generation = pool.generation + 1
807
	
808
	writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
809
end
810
	
811
function initializePool()
812
	pool = newPool()
813
814
	for i=1,Population do
815
		basic = basicGenome()
816
		addToSpecies(basic)
817
	end
818
819
	initializeRun()
820
end
821
822
function clearJoypad()
823
	controller = {}
824
	for b = 1,#ButtonNames do
825
		controller["P1 " .. ButtonNames[b]] = false
826
	end
827
	joypad.set(controller)
828
end
829
830
function initializeRun()
831
	savestate.load(Filename);
832
	rightmost = 0
833
	pool.currentFrame = 0
834
	timeout = TimeoutConstant
835
	clearJoypad()
836
	
837
	local species = pool.species[pool.currentSpecies]
838
	local genome = species.genomes[pool.currentGenome]
839
	generateNetwork(genome)
840
	evaluateCurrent()
841
end
842
843
function evaluateCurrent()
844
	local species = pool.species[pool.currentSpecies]
845
	local genome = species.genomes[pool.currentGenome]
846
847
	inputs = getInputs()
848
	controller = evaluateNetwork(genome.network, inputs)
849
	
850
	if controller["P1 Left"] and controller["P1 Right"] then
851
		controller["P1 Left"] = false
852
		controller["P1 Right"] = false
853
	end
854
	if controller["P1 Up"] and controller["P1 Down"] then
855
		controller["P1 Up"] = false
856
		controller["P1 Down"] = false
857
	end
858
859
	joypad.set(controller)
860
end
861
862
if pool == nil then
863
	initializePool()
864
end
865
866
867
function nextGenome()
868
	pool.currentGenome = pool.currentGenome + 1
869
	if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
870
		pool.currentGenome = 1
871
		pool.currentSpecies = pool.currentSpecies+1
872
		if pool.currentSpecies > #pool.species then
873
			newGeneration()
874
			pool.currentSpecies = 1
875
		end
876
	end
877
end
878
879
function fitnessAlreadyMeasured()
880
	local species = pool.species[pool.currentSpecies]
881
	local genome = species.genomes[pool.currentGenome]
882
	
883
	return genome.fitness ~= 0
884
end
885
886
function displayGenome(genome)
887
	local network = genome.network
888
	local cells = {}
889
	local i = 1
890
	local cell = {}
891
	for dy=-BoxRadius,BoxRadius do
892
		for dx=-BoxRadius,BoxRadius do
893
			cell = {}
894
			cell.x = 50+5*dx
895
			cell.y = 70+5*dy
896
			cell.value = network.neurons[i].value
897
			cells[i] = cell
898
			i = i + 1
899
		end
900
	end
901
	local biasCell = {}
902
	biasCell.x = 80
903
	biasCell.y = 110
904
	biasCell.value = network.neurons[Inputs].value
905
	cells[Inputs] = biasCell
906
	
907
	for o = 1,Outputs do
908
		cell = {}
909
		cell.x = 220
910
		cell.y = 30 + 8 * o
911
		cell.value = network.neurons[MaxNodes + o].value
912
		cells[MaxNodes+o] = cell
913
		local color
914
		if cell.value > 0 then
915
			color = 0xFF0000FF
916
		else
917
			color = 0xFF000000
918
		end
919
		gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
920
	end
921
	
922
	for n,neuron in pairs(network.neurons) do
923
		cell = {}
924
		if n > Inputs and n <= MaxNodes then
925
			cell.x = 140
926
			cell.y = 40
927
			cell.value = neuron.value
928
			cells[n] = cell
929
		end
930
	end
931
	
932
	for n=1,4 do
933
		for _,gene in pairs(genome.genes) do
934
			if gene.enabled then
935
				local c1 = cells[gene.into]
936
				local c2 = cells[gene.out]
937
				if gene.into > Inputs and gene.into <= MaxNodes then
938
					c1.x = 0.75*c1.x + 0.25*c2.x
939
					if c1.x >= c2.x then
940
						c1.x = c1.x - 40
941
					end
942
					if c1.x < 90 then
943
						c1.x = 90
944
					end
945
					
946
					if c1.x > 220 then
947
						c1.x = 220
948
					end
949
					c1.y = 0.75*c1.y + 0.25*c2.y
950
					
951
				end
952
				if gene.out > Inputs and gene.out <= MaxNodes then
953
					c2.x = 0.25*c1.x + 0.75*c2.x
954
					if c1.x >= c2.x then
955
						c2.x = c2.x + 40
956
					end
957
					if c2.x < 90 then
958
						c2.x = 90
959
					end
960
					if c2.x > 220 then
961
						c2.x = 220
962
					end
963
					c2.y = 0.25*c1.y + 0.75*c2.y
964
				end
965
			end
966
		end
967
	end
968
	
969
	gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
970
	for n,cell in pairs(cells) do
971
		if n > Inputs or cell.value ~= 0 then
972
			local color = math.floor((cell.value+1)/2*256)
973
			if color > 255 then color = 255 end
974
			if color < 0 then color = 0 end
975
			local opacity = 0xFF000000
976
			if cell.value == 0 then
977
				opacity = 0x50000000
978
			end
979
			color = opacity + color*0x10000 + color*0x100 + color
980
			gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color)
981
		end
982
	end
983
	for _,gene in pairs(genome.genes) do
984
		if gene.enabled then
985
			local c1 = cells[gene.into]
986
			local c2 = cells[gene.out]
987
			local opacity = 0xA0000000
988
			if c1.value == 0 then
989
				opacity = 0x20000000
990
			end
991
			
992
			local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
993
			if gene.weight > 0 then 
994
				color = opacity + 0x8000 + 0x10000*color
995
			else
996
				color = opacity + 0x800000 + 0x100*color
997
			end
998
			gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
999
		end
1000
	end
1001
	
1002
	gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)
1003
	
1004
	if forms.ischecked(showMutationRates) then
1005
		local pos = 100
1006
		for mutation,rate in pairs(genome.mutationRates) do
1007
			gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
1008
			pos = pos + 8
1009
		end
1010
	end
1011
end
1012
1013
function writeFile(filename)
1014
        local file = io.open(filename, "w")
1015
	file:write(pool.generation .. "\n")
1016
	file:write(pool.maxFitness .. "\n")
1017
	file:write(#pool.species .. "\n")
1018
        for n,species in pairs(pool.species) do
1019
		file:write(species.topFitness .. "\n")
1020
		file:write(species.staleness .. "\n")
1021
		file:write(#species.genomes .. "\n")
1022
		for m,genome in pairs(species.genomes) do
1023
			file:write(genome.fitness .. "\n")
1024
			file:write(genome.maxneuron .. "\n")
1025
			for mutation,rate in pairs(genome.mutationRates) do
1026
				file:write(mutation .. "\n")
1027
				file:write(rate .. "\n")
1028
			end
1029
			file:write("done\n")
1030
			
1031
			file:write(#genome.genes .. "\n")
1032
			for l,gene in pairs(genome.genes) do
1033
				file:write(gene.into .. " ")
1034
				file:write(gene.out .. " ")
1035
				file:write(gene.weight .. " ")
1036
				file:write(gene.innovation .. " ")
1037
				if(gene.enabled) then
1038
					file:write("1\n")
1039
				else
1040
					file:write("0\n")
1041
				end
1042
			end
1043
		end
1044
        end
1045
        file:close()
1046
end
1047
1048
function savePool()
1049
	local filename = forms.gettext(saveLoadFile)
1050
	writeFile(filename)
1051
end
1052
1053
function loadFile(filename)
1054
        local file = io.open(filename, "r")
1055
	pool = newPool()
1056
	pool.generation = file:read("*number")
1057
	pool.maxFitness = file:read("*number")
1058
	forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
1059
        local numSpecies = file:read("*number")
1060
        for s=1,numSpecies do
1061
		local species = newSpecies()
1062
		table.insert(pool.species, species)
1063
		species.topFitness = file:read("*number")
1064
		species.staleness = file:read("*number")
1065
		local numGenomes = file:read("*number")
1066
		for g=1,numGenomes do
1067
			local genome = newGenome()
1068
			table.insert(species.genomes, genome)
1069
			genome.fitness = file:read("*number")
1070
			genome.maxneuron = file:read("*number")
1071
			local line = file:read("*line")
1072
			while line ~= "done" do
1073
				genome.mutationRates[line] = file:read("*number")
1074
				line = file:read("*line")
1075
			end
1076
			local numGenes = file:read("*number")
1077
			for n=1,numGenes do
1078
				local gene = newGene()
1079
				table.insert(genome.genes, gene)
1080
				local enabled
1081
				gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
1082
				if enabled == 0 then
1083
					gene.enabled = false
1084
				else
1085
					gene.enabled = true
1086
				end
1087
				
1088
			end
1089
		end
1090
	end
1091
        file:close()
1092
	
1093
	while fitnessAlreadyMeasured() do
1094
		nextGenome()
1095
	end
1096
	initializeRun()
1097
	pool.currentFrame = pool.currentFrame + 1
1098
end
1099
 
1100
function loadPool()
1101
	local filename = forms.gettext(saveLoadFile)
1102
	loadFile(filename)
1103
end
1104
1105
function playTop()
1106
	local maxfitness = 0
1107
	local maxs, maxg
1108
	for s,species in pairs(pool.species) do
1109
		for g,genome in pairs(species.genomes) do
1110
			if genome.fitness > maxfitness then
1111
				maxfitness = genome.fitness
1112
				maxs = s
1113
				maxg = g
1114
			end
1115
		end
1116
	end
1117
	
1118
	pool.currentSpecies = maxs
1119
	pool.currentGenome = maxg
1120
	pool.maxFitness = maxfitness
1121
	forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
1122
	initializeRun()
1123
	pool.currentFrame = pool.currentFrame + 1
1124
	return
1125
end
1126
1127
function onExit()
1128
	forms.destroy(form)
1129
end
1130
1131
writeFile("temp.pool")
1132
1133
event.onexit(onExit)
1134
1135
form = forms.newform(200, 260, "Fitness")
1136
maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
1137
showNetwork = forms.checkbox(form, "Show Map", 5, 30)
1138
showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
1139
restartButton = forms.button(form, "Restart", initializePool, 5, 77)
1140
saveButton = forms.button(form, "Save", savePool, 5, 102)
1141
loadButton = forms.button(form, "Load", loadPool, 80, 102)
1142
saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
1143
saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
1144
playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
1145
hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)
1146
1147
1148
while true do
1149
	local backgroundColor = 0xD0FFFFFF
1150
	if not forms.ischecked(hideBanner) then
1151
		gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
1152
	end
1153
1154
	local species = pool.species[pool.currentSpecies]
1155
	local genome = species.genomes[pool.currentGenome]
1156
	
1157
	if forms.ischecked(showNetwork) then
1158
		displayGenome(genome)
1159
	end
1160
	
1161
	if pool.currentFrame%5 == 0 then
1162
		evaluateCurrent()
1163
	end
1164
1165
	joypad.set(controller)
1166
1167
	getPositions()
1168
	if marioX > rightmost then
1169
		rightmost = marioX
1170
		timeout = TimeoutConstant
1171
	end
1172
	
1173
	timeout = timeout - 1
1174
	
1175
	
1176
	local timeoutBonus = pool.currentFrame / 4
1177
	if timeout + timeoutBonus <= 0 then
1178
		local fitness = rightmost - pool.currentFrame / 2
1179
		if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 then
1180
			fitness = fitness + 1000
1181
		end
1182
		if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
1183
			fitness = fitness + 1000
1184
		end
1185
		if fitness == 0 then
1186
			fitness = -1
1187
		end
1188
		genome.fitness = fitness
1189
		
1190
		if fitness > pool.maxFitness then
1191
			pool.maxFitness = fitness
1192
			forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
1193
			writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
1194
		end
1195
		
1196
		console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
1197
		pool.currentSpecies = 1
1198
		pool.currentGenome = 1
1199
		while fitnessAlreadyMeasured() do
1200
			nextGenome()
1201
		end
1202
		initializeRun()
1203
	end
1204
1205
	local measured = 0
1206
	local total = 0
1207
	for _,species in pairs(pool.species) do
1208
		for _,genome in pairs(species.genomes) do
1209
			total = total + 1
1210
			if genome.fitness ~= 0 then
1211
				measured = measured + 1
1212
			end
1213
		end
1214
	end
1215
	if not forms.ischecked(hideBanner) then
1216
		gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 11)
1217
		gui.drawText(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000, 11)
1218
		gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11)
1219
	end
1220
		
1221
	pool.currentFrame = pool.currentFrame + 1
1222
1223
	emu.frameadvance();
1224
end