-- Untitled -- Pastebin paste by a guest, Aug 18th, 2017
-- PICO-8 (Lua) cart: a 2-3-1 neural network trained on XOR;
-- after training, click cells to toggle inputs and watch the output.
  -------------------------------
  -- main
  -------------------------------

  -- mouse state, refreshed every frame by _update():
  -- mx,my = cursor position, mb = current button state,
  -- lmb = button state from the previous frame (used to detect presses)
  mx,my=0,0
  mb,lmb=0,0

  -- pico-8 start-up hook: enable the mouse, build the network,
  -- lay it out on screen and arm the training counter.
  function _init()
    -- enable devkit mouse so stat(32..34) report the cursor
    poke(0x5f2d,1)
    cls()
    -- layer sizes 2 -> 3 -> 1, learning rate 50, sigmoid temperature 1
    local sizes={2,3,1}
    net=network:new(sizes,50,1)
    net:init_pos()
    -- training epochs still to run; _update() counts this down
    num_trains=1000
  end

  -- re-run the forward pass, using the current signals of every cell
  -- in the input layer (net[1]) as the network input. called after the
  -- user toggles a cell so the downstream layers are recomputed.
  -- (previously hard-coded to exactly two input cells.)
  function update_net()
    local cells=net[1].cells
    local inp={}
    for i=1,#cells do
      inp[i]=cells[i].signal
    end
    net:activate(inp)
  end

  -- pico-8 per-frame update hook.
  -- samples the mouse; while num_trains>0 runs a batch of training
  -- epochs over the xor truth table, otherwise sets act=true so the
  -- cells become clickable.
  function _update()
    mx=stat(32)
    my=stat(33)
    -- keep last frame's buttons so a press can be told from a hold
    lmb=mb
    mb=stat(34)
    if num_trains>0 then
      -- the loop count must equal the counter decrement; the old
      -- `for i=0,10` ran 11 epochs per 10 counted
      local batch=10
      num_trains=num_trains-batch
      for _=1,batch do
        net:bp({0,0},{0})
        net:bp({0,1},{1})
        net:bp({1,0},{1})
        net:bp({1,1},{0})
      end
    else
      act=true
    end
  end

  -- pico-8 per-frame draw hook: clear, show the remaining epoch count
  -- while training, draw the network, then the mouse cursor sprite.
  function _draw()
    cls()
    local training=num_trains>0
    if training then print(num_trains,2,2,6) end
    net:render()
    spr(0,mx,my)
  end

  -------------------------------
  -- math
  -------------------------------

  -- square of a number
  function sqr(a)
    local sq=a*a
    return sq
  end

  -- x raised to the power a; a may be negative or fractional.
  -- integer part: square-and-multiply.
  -- fractional part: walk the binary digits of the fraction, taking one
  -- more square root of x per digit and multiplying it in when the
  -- digit is 1.
  function pow(x,a)
    if a==0 then
      return 1
    end
    if a<0 then
      -- x^-a == (1/x)^a
      x=1/x
      a=-a
    end
    local whole=flr(a)
    local frac=a-whole
    local result=1
    local sq=x
    -- shr(whole,1) halves whole but keeps the shifted-out bit as a
    -- fraction, so bit 0 is tested with whole%2>=1 rather than ==1
    while whole>=1 do
      if whole%2>=1 then
        result=result*sq
      end
      sq=sq*sq
      whole=shr(whole,1)
    end
    local root=x
    while frac>0 do
      while frac<1 do
        root=sqrt(root)
        frac=frac+frac
      end
      result=result*root
      frac=frac-1
    end
    return result
  end

  -- e^x, computed with pow().
  -- pico-8 numbers are 16.16 fixed point and top out just below 32768
  -- (about e^10.397). larger results would silently wrap around to
  -- garbage, so saturate at the largest representable value instead.
  -- the 10.3 cut-off leaves a margin for rounding in pow()'s sqrt chain.
  -- (replaces an old debug printh that only fired above x=20.)
  function exp(x)
    if x>10.3 then
      return 0x7fff.ffff
    end
    return pow(2.718281,x)
  end

  -------------------------------
  -- cells
  -------------------------------

  -- cell: a single neuron. this table is the prototype/metatable
  -- shared by all cell instances.
  cell={}

  -- create a cell with ni incoming weights, each initialised to
  -- rnd(0.1). fields: delta (error term written by network:bp),
  -- weights, and signal (the cell's last output).
  function cell:new(ni)
    local w={}
    for k=1,ni do
      w[k]=rnd(0.1)
    end
    local obj={delta=0,weights=w,signal=0}
    self.__index=self
    return setmetatable(obj,self)
  end

  -- compute this cell's output:
  --   signal = 1/(1+exp(-(bias + sum w[i]*inp[i]) / t))
  -- inp:  signals of the previous layer, one per weight
  -- bias: the layer's shared bias
  -- t:    sigmoid temperature (divides the weighted sum)
  function cell:activate(inp,bias,t)
    local s=bias
    local w=self.weights
    for i=1,#w do
      s=s+w[i]*inp[i]
    end
    -- exp() overflows pico-8's fixed-point range for arguments above
    -- ~10.4 and would wrap to a garbage value; clamping keeps the
    -- sigmoid saturated near 0 (1/(1+e^10) ~= 0.00005) instead
    local z=-s/t
    if z>10 then
      z=10
    end
    self.signal=1/(1+exp(z))
  end

  -- draw this cell as a radius-4 filled circle: colour 11 when
  -- signal>0.5, colour 8 otherwise. once training has finished (act),
  -- a fresh mouse press (button down now, up last frame) within 4
  -- pixels of the cell flips its signal between 0 and 1 and re-runs
  -- the forward pass through update_net().
  function cell:render()
    local pressed=act and lmb==0 and mb~=0
    if pressed then
      local dx,dy=mx-self.x,my-self.y
      if sqrt(sqr(dx)+sqr(dy))<4 then
        self.signal=(self.signal>0.5) and 0 or 1
        update_net()
      end
    end
    local col
    if self.signal>0.5 then
      col=11
    else
      col=8
    end
    circfill(self.x,self.y,4,col)
  end

  -------------------------------
  -- layers
  -------------------------------

  -- layer: a column of cells that share one bias value. this table is
  -- the prototype/metatable for layer instances.
  layer={}

  -- create a layer of nc cells (default 1), each with ni incoming
  -- weights (default 1). the bias starts at rnd(), drawn after all the
  -- cells' weights.
  function layer:new(nc,ni)
    local count=nc or 1
    local inputs=ni or 1
    local members={}
    for k=1,count do
      members[#members+1]=cell:new(inputs)
    end
    local obj={cells=members}
    obj.bias=rnd()
    self.__index=self
    return setmetatable(obj,self)
  end

  -- place this layer's cells on screen.
  -- layers are spread evenly in x over 10..109 (x=60 when there is only
  -- one layer); the cells of a layer are spread the same way in y.
  -- every coordinate gets +4.
  -- i:  this layer's 1-based index
  -- nl: optional number of layers; defaults to #net (the global network)
  -- the single-layer / single-cell cases no longer divide by zero.
  function layer:init_pos(i,nl)
    nl=nl or #net
    local x=60
    if nl>1 then
      x=10+(i-1)*(99/(nl-1))
    end
    local cells=self.cells
    local nc=#cells
    for k=1,nc do
      local y=60
      if nc>1 then
        y=10+(k-1)*(99/(nc-1))
      end
      cells[k].x=x+4
      cells[k].y=y+4
    end
  end

  -- draw every cell of this layer, in order
  function layer:render()
    local cs=self.cells
    for k=1,#cs do
      cs[k]:render()
    end
  end

  -- draw a colour-5 line from every cell of this layer to every cell
  -- of layer l
  function layer:render_connections(l)
    local mine=self.cells
    for a=1,#mine do
      local p=mine[a]
      local others=l.cells
      for b=1,#others do
        local q=others[b]
        line(p.x,p.y,q.x,q.y,5)
      end
    end
  end

  -------------------------------
  -- network
  -------------------------------

  -- network: an array of layers (index 1 is the input layer) plus the
  -- training parameters lr (learning rate) and t (sigmoid temperature).
  -- this table is the prototype/metatable for network instances.
  network={}

  -- build a network from a list of layer sizes l, e.g. {2,3,1}.
  -- layer k>1 gets one weight per cell of layer k-1; the input layer is
  -- given l[1] weights per cell, which activate() and bp() never read.
  function network:new(l,lr,t)
    local obj={lr=lr,t=t}
    obj[1]=layer:new(l[1],l[1])
    local k=2
    while k<=#l do
      obj[k]=layer:new(l[k],l[k-1])
      k=k+1
    end
    self.__index=self
    return setmetatable(obj,self)
  end

  -- lay out every layer, passing each its 1-based index
  function network:init_pos()
    local count=#self
    local k=1
    while k<=count do
      self[k]:init_pos(k)
      k=k+1
    end
  end

  -- draw the network: all connection lines first, then every layer's
  -- cells, so the cells are painted over the lines
  function network:render()
    local count=#self
    for k=1,count-1 do
      local here,nxt=self[k],self[k+1]
      here:render_connections(nxt)
    end
    for k=1,count do
      self[k]:render()
    end
  end

  -- forward pass: copy inp into the input layer's signals, then for
  -- each later layer hand the previous layer's signals, the layer's
  -- bias and the network's temperature t to every cell's activate().
  function network:activate(inp)
    local temp=self.t
    for k=1,#inp do
      self[1].cells[k].signal=inp[k]
    end
    for li=2,#self do
      local prev=self[li-1].cells
      local sig={}
      for m=1,#prev do
        sig[m]=prev[m].signal
      end
      local lyr=self[li]
      local cs=lyr.cells
      for j=1,#cs do
        cs[j]:activate(sig,lyr.bias,temp)
      end
    end
  end

  -- one step of online backpropagation on a single sample.
  -- inp: values for the input layer; out: target value per output cell.
  -- runs a forward pass, computes every cell's delta from the output
  -- layer backwards, then moves each weight by lr*delta*input and each
  -- layer bias by lr*(sum of that layer's deltas).
  function network:bp(inp,out)
    self:activate(inp)
    local ns=#self
    local lr=self.lr
    local t=self.t
    -- pass 1: deltas, output layer first. every delta is computed
    -- before any weight changes, so hidden deltas use the old weights.
    for i=ns,2,-1 do
      local cells=self[i].cells
      local nxt=(i<ns) and self[i+1].cells or nil
      for j=1,#cells do
        local s=cells[j].signal
        local err
        if nxt then
          -- error propagated back through the next layer's weights
          err=0
          for k=1,#nxt do
            err=err+nxt[k].weights[j]*nxt[k].delta
          end
        else
          err=out[j]-s
        end
        -- d/dx of 1/(1+exp(-x/t)) is s*(1-s)/t
        cells[j].delta=err*s*(1-s)/t
      end
    end
    -- pass 2: gradient step. the bias is shared by every cell of a
    -- layer, so it accumulates the sum of the layer's deltas (it used to
    -- be overwritten with lr*delta of the last cell only).
    for i=2,ns do
      local cells=self[i].cells
      local prev=self[i-1].cells
      local dsum=0
      for j=1,#cells do
        local c=cells[j]
        local step=c.delta*lr
        dsum=dsum+c.delta
        local w=c.weights
        for k=1,#w do
          w[k]=w[k]+step*prev[k].signal
        end
      end
      self[i].bias=self[i].bias+dsum*lr
    end
  end
-- end of paste