Advertisement
Not a member of Pastebin yet?
Sign up: it unlocks many cool features!
-------------------------------
-- main
-------------------------------
-- mouse state shared by _update(), _draw() and cell:render().
-- mx,my: cursor position, sampled from stat(32) and stat(33)
-- mb:    button state this frame, sampled from stat(34)
-- lmb:   button state on the previous frame, so a click can be
--        detected as the 0 -> non-zero edge
mx,my=0,0
mb,lmb=0,0
-- one-time setup: poke 0x5f2d=1 (pico-8 devkit mode, which the
-- stat(32..34) mouse reads in _update rely on), clear the screen,
-- build a 2-3-1 network with learning rate 50 and temperature 1,
-- lay it out on screen, and set the training budget that _update
-- counts down.
function _init()
 poke(0x5f2d,1)
 cls()
 local layout={2,3,1}
 local rate,temp=50,1
 net=network:new(layout,rate,temp)
 net:init_pos()
 num_trains=1000
end
-- re-run the forward pass using the first two input cells' current
-- signals; cell:render calls this after a click toggles a cell.
function update_net()
 local inputs=net[1].cells
 net:activate({inputs[1].signal,inputs[2].signal})
end
-- per-frame logic. samples the mouse into mx,my,mb (keeping the
-- previous buttons in lmb), then either trains or enables clicking:
-- while num_trains>0 it runs 10 epochs over the four xor samples and
-- subtracts 10 from the budget; once the budget is spent it sets the
-- global act so cell:render starts reacting to clicks.
function _update()
 mx=stat(32)
 my=stat(33)
 lmb=mb
 mb=stat(34)
 if num_trains>0 then
  num_trains=num_trains-10
  -- exactly 10 epochs to match the decrement above
  -- (the old 0..10 loop ran 11)
  for _=1,10 do
   net:bp({0,0},{0})
   net:bp({0,1},{1})
   net:bp({1,0},{1})
   net:bp({1,1},{0})
  end
 else
  act=true
 end
end
-- per-frame drawing: clear the screen, show the remaining training
-- budget while training is in progress, draw the network, then draw
-- sprite 0 at the mouse position on top of everything.
function _draw()
 cls()
 local training=num_trains>0
 if training then print(num_trains,2,2,6) end
 net:render()
 spr(0,mx,my)
end
-------------------------------
-- math
-------------------------------
-- square of a number (a times itself).
function sqr(a)
 local sq=a*a
 return sq
end
-- x raised to the power a, for any real a.
-- a negative exponent is turned into a positive one on 1/x.
-- the whole part of a is handled by square-and-multiply; the
-- fractional part by repeated square roots, one per binary digit.
function pow(x,a)
 if a==0 then return 1 end
 if a<0 then
  x=1/x
  a=-a
 end
 local whole=flr(a)
 local frac=a-whole
 local result=1
 -- whole part: multiply in x^(2^k) for each set bit k
 local square=x
 while whole>=1 do
  if whole%2>=1 then result=result*square end
  square=square*square
  whole=shr(whole,1)
 end
 -- fractional part: each doubling of frac pairs with a sqrt of the
 -- base; whenever frac reaches 1, multiply that root in
 local root=x
 while frac>0 do
  while frac<1 do
   root=sqrt(root)
   frac=frac+frac
  end
  result=result*root
  frac=frac-1
 end
 return result
end
-- e^x, via pow() with e ~= 2.718281.
-- pico-8 numbers are 16.16 fixed point (max ~32767.99), and e^x
-- passes that for x > ~10.397. pow() would then silently wrap to a
-- meaningless (often negative) value, so saturate at the top of the
-- range instead. this replaces the old debug printh, which only
-- fired for x>20, long after the overflow had already happened.
function exp(x)
 if x>10.39 then return 32767.99 end
 return pow(2.718281,x)
end
-------------------------------
-- cells
-------------------------------
-- a single neuron: one weight per input, plus the delta written by
-- training and the output signal written by the forward pass.
cell={}
-- create a cell with ni incoming weights, each starting at rnd(0.1).
function cell:new(ni)
 local w={}
 for k=1,ni do
  w[k]=rnd(0.1)
 end
 self.__index=self
 return setmetatable({delta=0,weights=w,signal=0},self)
end
-- compute this cell's output signal: sigmoid((bias + w.inp) / t).
-- inp:  input signals, at least #self.weights of them
-- bias: bias shared by this cell's layer
-- t:    divisor of the sigmoid input; larger values flatten the curve
-- exp() is only ever evaluated on a non-positive argument. the naive
-- 1/(1+exp(-z)) needs exp of a large positive number when z is very
-- negative, which overflows pico-8's fixed-point range. for z<0 the
-- identity 1/(1+e^-z) == e^z/(1+e^z) gives the same value safely.
function cell:activate(inp,bias,t)
 local s=bias
 local w=self.weights
 for i=1,#w do
  s=s+w[i]*inp[i]
 end
 local z=s/t
 if z>=0 then
  self.signal=1/(1+exp(-z))
 else
  local e=exp(z)
  self.signal=e/(1+e)
 end
end
-- draw the cell as a radius-4 filled circle: colour 11 when its signal
-- is above 0.5, otherwise colour 8. once act is set, a fresh click
-- (lmb==0, mb~=0) within 4 pixels of the centre flips the signal
-- between 0 and 1 and re-runs the forward pass via update_net().
function cell:render()
 local clicked=act and lmb==0 and mb~=0
 if clicked then
  local dx=mx-self.x
  local dy=my-self.y
  if sqrt(sqr(dx)+sqr(dy))<4 then
   self.signal=(self.signal>0.5) and 0 or 1
   update_net()
  end
 end
 circfill(self.x,self.y,4,self.signal>0.5 and 11 or 8)
end
-------------------------------
-- layers
-------------------------------
-- a column of cells that share a single bias value.
layer={}
-- create a layer of nc cells (default 1), each with ni incoming
-- weights (default 1). the shared bias starts at rnd(). cells are
-- created first, so their rnd() draws come before the bias draw.
function layer:new(nc,ni)
 local count=nc or 1
 local inputs=ni or 1
 local members={}
 for k=1,count do
  members[k]=cell:new(inputs)
 end
 self.__index=self
 return setmetatable({cells=members,bias=rnd()},self)
end
-- place this layer's cells on screen. layers are spread evenly across
-- x=10..109 and the cells of a layer evenly across y=10..109; a lone
-- layer or a lone cell sits at 60. every coordinate is then shifted
-- by +4.
-- i:  this layer's 1-based index in its network
-- nl: number of layers in that network (optional; defaults to #net,
--     the global network, so existing callers keep working)
-- the single-layer and single-cell cases are handled explicitly
-- rather than dividing 99 by zero.
function layer:init_pos(i,nl)
 nl=nl or #net
 local x=60
 if nl>1 then
  x=10+(i-1)*(99/(nl-1))
 end
 local cells=self.cells
 local nc=#cells
 local y0,ystep=60,0
 if nc>1 then
  y0,ystep=10,99/(nc-1)
 end
 for j=1,nc do
  cells[j].x=x+4
  cells[j].y=y0+(j-1)*ystep+4
 end
end
-- draw every cell of this layer, in order.
function layer:render()
 local cs=self.cells
 for k=1,#cs do
  cs[k]:render()
 end
end
-- draw a line in colour 5 from every cell of this layer to every cell
-- of layer l.
function layer:render_connections(l)
 local mine,theirs=self.cells,l.cells
 for a=1,#mine do
  local src=mine[a]
  for b=1,#theirs do
   local dst=theirs[b]
   line(src.x,src.y,dst.x,dst.y,5)
  end
 end
end
-------------------------------
-- network
-------------------------------
-- an array of layers (index 1 is the input layer), plus the learning
-- rate lr and the sigmoid divisor t that are passed on to the cells.
network={}
-- build a network from a list of layer sizes such as {2,3,1}.
-- layer k>1 gets l[k-1] weights per cell; the input layer, which has
-- no predecessor, gets l[1].
function network:new(l,lr,t)
 local prev=l[1]
 local nw={lr=lr,t=t,layer:new(prev,prev)}
 for k=2,#l do
  nw[k]=layer:new(l[k],prev)
  prev=l[k]
 end
 self.__index=self
 return setmetatable(nw,self)
end
-- ask each layer to place its cells, passing its 1-based index.
function network:init_pos()
 local count=#self
 for idx=1,count do
  local lyr=self[idx]
  lyr:init_pos(idx)
 end
end
-- draw the network: all connection lines between adjacent layers
-- first, then every layer's cells, so the cells are painted over the
-- lines.
function network:render()
 for k=2,#self do
  self[k-1]:render_connections(self[k])
 end
 for k=1,#self do
  self[k]:render()
 end
end
-- forward pass. copies inp into the input layer's signals, then, layer
-- by layer, gives every cell the previous layer's signals, its layer's
-- shared bias and the divisor self.t.
function network:activate(inp)
 local temp=self.t
 for k=1,#inp do
  self[1].cells[k].signal=inp[k]
 end
 for li=2,#self do
  local prev=self[li-1].cells
  local sig={}
  for m=1,#prev do
   sig[m]=prev[m].signal
  end
  local lyr=self[li]
  local b=lyr.bias
  local cs=lyr.cells
  for j=1,#cs do
   cs[j]:activate(sig,b,temp)
  end
 end
end
-- train on one sample with a single backpropagation step.
-- inp: input values, one per input cell
-- out: target values, one per output cell
-- runs a forward pass, then computes every non-input cell's delta
-- (error times the slope of its sigmoid), output layer first. it then
-- moves each weight by lr*delta*input and each layer's shared bias by
-- lr*(sum of that layer's deltas). all deltas are computed before any
-- weight changes, so hidden deltas use the pre-update weights.
function network:bp(inp,out)
 self:activate(inp)
 local ns=#self
 local lr=self.lr
 local t=self.t
 -- backward pass
 for i=ns,2,-1 do
  local cells=self[i].cells
  local nxt=i<ns and self[i+1].cells
  for j=1,#cells do
   local s=cells[j].signal
   local err
   if nxt then
    -- hidden cell: error flows back through the next layer's weights
    err=0
    for k=1,#nxt do
     err=err+nxt[k].weights[j]*nxt[k].delta
    end
   else
    err=out[j]-s
   end
   -- the cell computes sigmoid(z/t), whose slope w.r.t. z is s*(1-s)/t
   cells[j].delta=err*s*(1-s)/t
  end
 end
 -- update pass
 for i=2,ns do
  local lyr=self[i]
  local prev=self[i-1].cells
  local dsum=0
  for j=1,#lyr.cells do
   local c=lyr.cells[j]
   local w=c.weights
   local step=c.delta*lr
   dsum=dsum+c.delta
   for k=1,#w do
    w[k]=w[k]+step*prev[k].signal
   end
  end
  -- the bias is shared by every cell in the layer, so its gradient is
  -- the sum of their deltas; accumulate it instead of overwriting it
  -- with the last cell's delta
  lyr.bias=lyr.bias+dsum*lr
 end
end
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement