Guest User

Untitled

a guest
Aug 17th, 2017
52
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. -------------------------------
  2. -- main
  3. -------------------------------
  4.  
  -- mouse state, refreshed every frame by _update():
  -- mx,my = cursor position, mb = current button state (stat(34)),
  -- lmb = button state from the previous frame, so a new press
  -- can be detected as lmb==0 and mb~=0
  mx,my=0,0
  mb,lmb=0,0
  9.  
  -- one-time setup: enable mouse input, arm the training counter
  -- consumed by _update(), then build and lay out the network
  function _init()
    -- devkit mode, so the mouse can be read through stat(32..34)
    poke(0x5f2d,1)
    cls()
    num_trains=3000
    -- layer sizes {2,3,1}, lr=50, t=0
    -- NOTE(review): t is used as a divisor in cell:activate
    -- (-1*s/t); confirm a zero temperature is intended.
    net=network:new({2,3,1},50,0)
    net:init_pos()
  end
  20.  
  -- re-run a forward pass of the global net, feeding it the
  -- current signals of the first two input-layer cells
  function update_net()
    local inputs=net[1].cells
    local a,b=inputs[1].signal,inputs[2].signal
    net:activate({a,b})
  end
  27.  
  -- per-frame update: sample the mouse, then either run one batch
  -- of training on the global net or, once num_trains has run
  -- out, set act so cells respond to clicks (see cell:render)
  function _update()
    mx=stat(32)
    my=stat(33)
    -- keep last frame's buttons so a new press can be detected
    lmb=mb
    mb=stat(34)

    if num_trains>0 then
      -- one batch: the loop runs exactly as many passes over the
      -- sample set as are deducted from num_trains
      local batch=100
      num_trains=num_trains-batch
      -- {input, target} pairs; bp only reads these tables, so
      -- they can be shared by every pass
      -- NOTE(review): every target is {0}; if a logic gate such as
      -- xor was intended, the targets need updating. Confirm.
      local samples={
        {{0,0},{0}},
        {{0,1},{0}},
        {{1,0},{0}},
        {{1,1},{0}},
      }
      for _=1,batch do
        for k=1,#samples do
          net:bp(samples[k][1],samples[k][2])
        end
      end
    else
      act=true
    end
  end
  45.  
  -- per-frame draw: clear the screen, show the remaining training
  -- count while it is positive, draw the network, then draw
  -- sprite 0 at the mouse position
  function _draw()
    cls()
    local remaining=num_trains
    if remaining>0 then
      print(remaining,2,2,6)
    end
    net:render()
    spr(0,mx,my)
  end
  54.  
  55. -------------------------------
  56. -- math
  57. -------------------------------
  58.  
  -- square of a number
  function sqr(a)
    local squared=a*a
    return squared
  end
  60.  
  -- pow(x,a): x raised to the power a, allowing negative and
  -- non-integer exponents.
  function pow(x,a)
  if (a==0) return 1
  -- negative exponent: invert the base and use |a|
  if (a<0) x,a=1/x,-a
  -- split a into its integer part a0 and fractional part a
  local ret,a0,xn=1,flr(a),x
  a-=a0
  -- integer part by square-and-multiply: multiply xn into ret when
  -- the low integer bit of a0 is set, then square xn and shift a0
  -- right one bit. on pico-8, shr of an odd value leaves a .5
  -- fraction; a0%2>=1 still tests only the integer low bit.
  while a0>=1 do
  if (a0%2>=1) ret*=xn
  xn,a0=xn*xn,shr(a0,1)
  end
  -- fractional part, one binary digit at a time: replacing x with
  -- sqrt(x) while doubling a leaves x^a unchanged; once a>=1 one
  -- factor of the current x is moved into ret.
  -- NOTE(review): termination relies on a having finitely many
  -- fraction bits (pico-8 fixed-point numbers) — confirm if ported
  while a>0 do
  while a<1 do x,a=sqrt(x),a+a end
  ret,a=ret*x,a-1
  end
  return ret
  end
  76.  
  -- e^x computed with pow(); arguments above 20 are also written
  -- to the host console with printh (apparently a debugging aid)
  function exp(x)
    local e=2.718281
    if x>20 then printh(x) end
    return pow(e,x)
  end
  83.  
  84. -------------------------------
  85. -- cells
  86. -------------------------------
  87.  
  -- cell: one sigmoid unit holding its incoming weights, its last
  -- output (signal), its backprop error term (delta) and, once the
  -- layer lays it out, a screen position (x,y)
  cell={}

  -- create a cell with ni incoming weights, each set to rnd(0.1)
  function cell:new(ni)
    local n={
      delta=0,
      weights={},
      signal=0
    }
    for i=1,ni do
      n.weights[i]=rnd(0.1)
    end

    setmetatable(n,self)
    self.__index=self
    return n
  end

  -- set self.signal to sigmoid((bias + sum(w[i]*inp[i])) / t).
  -- inp: input signals, one per weight; bias: the layer's bias;
  -- t: temperature, defaults to 1 when nil. t==0 is treated as the
  -- limit of the sigmoid: 1 for a positive sum, 0 for a negative
  -- sum, 0.5 at exactly 0, instead of dividing by zero.
  function cell:activate(
    inp,bias,t
  )
    t=t or 1
    local s=bias
    local w=self.weights
    for i=1,#w do
      s=s+w[i]*inp[i]
    end
    if t==0 then
      self.signal=(s>0 and 1) or (s<0 and 0) or 0.5
      return
    end
    local z=s/t
    -- only ever hand exp() a non-positive argument: e^z for large
    -- positive z overflows pico-8 numbers (max ~32767, reached near
    -- z~10.4), while e^z for large negative z just tends to 0
    if z>=0 then
      self.signal=1/(1+exp(-z))
    else
      local e=exp(z)
      self.signal=e/(1+e)
    end
  end

  -- draw the cell as a radius-4 disc in colour 11 when
  -- signal>0.5, else colour 8. when act is set, a new mouse press
  -- (lmb==0, mb~=0) within distance 4 of the centre flips the
  -- signal between 0 and 1 and re-runs the net via update_net().
  function cell:render()
    if act and lmb==0 and mb~=0 then
      local dx=mx-self.x
      local dy=my-self.y
      -- bounding-box rejection first keeps the squares small (big
      -- squared distances can overflow pico-8 numbers); inside the
      -- box, dx*dx+dy*dy<16 is the same test as distance<4
      if dx>-4 and dx<4 and dy>-4 and dy<4
        and dx*dx+dy*dy<16 then
        if self.signal>0.5 then
          self.signal=0
        else
          self.signal=1
        end
        update_net()
      end
    end

    local cl=self.signal>0.5 and 11 or 8
    circfill(self.x,self.y,4,cl)
  end
  134.  
  135. -------------------------------
  136. -- layers
  137. -------------------------------
  138.  
  -- layer: an array of cells that share a single bias value
  layer={}

  -- create a layer of nc cells (default 1), each with ni incoming
  -- weights (default 1); the shared bias starts at rnd()
  function layer:new(nc,ni)
    nc=nc or 1
    ni=ni or 1
    local cells={}
    for i=1,nc do
      cells[i]=cell:new(ni)
    end
    local l={
      cells=cells,
      bias=rnd()
    }

    setmetatable(l,self)
    self.__index=self
    return l
  end

  -- position this layer's cells as screen column i of nl layers.
  -- nl is optional and defaults to #net (the global network), so
  -- existing init_pos(i) calls behave as before.
  -- columns and rows are spread evenly over 10..109; a lone column
  -- or row sits at 60; every coordinate is then offset by +4.
  -- the spacing is only computed for 2+ items, avoiding the 99/0
  -- division for a single layer or a single cell.
  function layer:init_pos(i,nl)
    nl=nl or #net
    local x
    if nl==1 then
      x=60
    else
      x=10+(i-1)*(99/(nl-1))
    end
    local cells=self.cells
    local nc=#cells
    for j=1,nc do
      local y
      if nc==1 then
        y=60
      else
        y=10+(j-1)*(99/(nc-1))
      end
      cells[j].x=x+4
      cells[j].y=y+4
    end
  end

  -- draw every cell of this layer
  function layer:render()
    local cells=self.cells
    for k=1,#cells do
      cells[k]:render()
    end
  end

  -- draw a colour-5 line from each cell of this layer to each
  -- cell of layer l
  function layer:render_connections(l)
    for a=1,#self.cells do
      local c=self.cells[a]
      for b=1,#l.cells do
        local c2=l.cells[b]
        line(c.x,c.y,c2.x,c2.y,5)
      end
    end
  end
  183.  
  184. -------------------------------
  185. -- network
  186. -------------------------------
  187.  
  -- network: array of layers (self[1] is the input layer,
  -- self[#self] the output layer) plus learning rate lr and
  -- temperature t (passed to each cell's activate)
  network={}

  -- build a network from layer sizes l (e.g. {2,3,1}). layer i>1
  -- gets l[i-1] weights per cell; the input layer is given l[1]
  -- weights per cell, which activate and bp never read.
  function network:new(
    l,lr,t
  )
    local n={
      lr=lr,
      t=t
    }
    n[1]=layer:new(l[1],l[1])
    for i=2,#l do
      n[i]=layer:new(l[i],l[i-1])
    end

    setmetatable(n,self)
    self.__index=self
    return n
  end

  -- lay out each layer as one screen column, passing this
  -- network's own layer count
  function network:init_pos()
    local nl=#self
    for i=1,nl do
      self[i]:init_pos(i,nl)
    end
  end

  -- draw all connections first so the cells are drawn on top
  function network:render()
    for i=1,#self-1 do
      self[i]:render_connections(self[i+1])
    end
    for i=1,#self do
      self[i]:render()
    end
  end

  -- forward pass: copy inp into the input layer's signals, then
  -- activate each later layer from the previous layer's signals
  function network:activate(inp)
    local t=self.t
    local first=self[1].cells
    for i=1,#inp do
      first[i].signal=inp[i]
    end
    for i=2,#self do
      local prev=self[i-1].cells
      local pi={}
      for m=1,#prev do
        pi[m]=prev[m].signal
      end
      local bias=self[i].bias
      local cells=self[i].cells
      for j=1,#cells do
        cells[j]:activate(pi,bias,t)
      end
    end
  end

  -- one online backpropagation step on a single example: forward
  -- pass on inp, compute every non-input cell's delta against the
  -- target out, then update weights and layer biases by lr
  function network:bp(inp,out)
    self:activate(inp)
    local ns=#self
    local lr=self.lr
    -- deltas, output layer first. hidden deltas read the next
    -- layer's weights before any weight is changed below.
    -- s*(1-s) is the sigmoid slope; the extra 1/t factor the
    -- temperature adds to the true derivative is not applied.
    for i=ns,2,-1 do
      local cells=self[i].cells
      for j=1,#cells do
        local s=cells[j].signal
        if i==ns then
          cells[j].delta=(out[j]-s)*s*(1-s)
        else
          local wd=0
          local nxt=self[i+1].cells
          for k=1,#nxt do
            wd=wd+nxt[k].weights[j]*nxt[k].delta
          end
          cells[j].delta=s*(1-s)*wd
        end
      end
    end
    -- gradient step. a layer has one bias shared by all of its
    -- cells, so its gradient is the sum of their deltas; the bias
    -- is accumulated like the weights rather than overwritten.
    for i=2,ns do
      local cells=self[i].cells
      local prev=self[i-1].cells
      local dsum=0
      for j=1,#cells do
        local c=cells[j]
        local step=c.delta*lr
        dsum=dsum+c.delta
        local w=c.weights
        for k=1,#w do
          w[k]=w[k]+step*prev[k].signal
        end
      end
      self[i].bias=self[i].bias+dsum*lr
    end
  end
RAW Paste Data