Advertisement
Guest User

Neural Network

a guest
Nov 28th, 2021
129
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
MatLab 3.38 KB | None | 0 0
  1. clear
  2. % initial
  3. lr=0.0000000034; % learning rate
  4. %lr=0.00000115; % learning rate
  5.  
  6. mom=1; % momentum
  7. iDim=1; % input dimension
  8.  
  9. %y=100+sin(2.*pi/12.*x);
  10.  
  11. % activation functions
  12. sigmoid=@(x) 1./(1+exp(x));
  13. dsigmoid=@(x) sigmoid(x).*(1-sigmoid(x));
  14. relu=@(x) max(x,0);
  15. drelu=@(x) x>0;
  16. activF={sigmoid dsigmoid;relu drelu}; % cell of poential activation func
  17.  
  18.  
  19. epochmax=20000;
  20. N=25;
  21. for L=1:5 % Neuron number
  22.     TN=L; % test number
  23.     bestEpoch=1;
  24.     %dimV=[iDim N 1 ]; %dimension of each layer, from input to output
  25.     dimV=cat(2,iDim,ones(1,L)*N,1);
  26.     dimPerLay=repelem(dimV,2) ;dimPerLay([1,end])=[];
  27.     dimPerLay=reshape(dimPerLay,2,[])';
  28.     xDiam=50; %span of x coordinates
  29.     activFidx=ones(1,size(dimV,2))*2;activFidx(end)=1; % last layer sigm, rest relu
  30.    
  31.     numSample=1000;
  32.     x=(rand(numSample,1)-0.5).*xDiam;
  33.     %x=linspace(-0.5*xDiam,0.5*xDiam,numSample)';
  34.     for i=1:size(dimPerLay,1);
  35.         w{i}=(rand(dimPerLay(i,:))-0.5)*1;
  36.         dw{i}=0;
  37.         b{i}=(rand(1,dimV(i+1))-0.5)*xDiam;
  38.         %b{i}=linspace(0,50,dimV(i+1));
  39.         db{i}=0;
  40.     end
  41.     yp{1}=x;
  42.     y=100+cos(pi./12.*x)+sin(pi./5.*x);
  43.    
  44.     for epoch=1:epochmax
  45.         for i=1:size(dimPerLay,1) % forward propagate
  46.             v=yp{i}*w{i}+b{i};
  47.             act=activF{activFidx,1};
  48.             dact=activF{activFidx,2};
  49.             %yp{i+1}=relu(v); %yp{i+1}=I;
  50.             yp{i+1}=act(v); %yp{i+1}=I;
  51.             %phi{i}=drelu(v);
  52.             phi{i}=dact(v);
  53.         end
  54.         e=y-yp{end};
  55.         bSize=size(e,1);
  56.         E(epoch)=sum(0.5*(e).^2)/bSize; % square error
  57.         % BP
  58.         %delta{size(dimPerLay,1)}=sum(E)/bSize;
  59.         delta{size(dimPerLay,1)}=e;
  60.         for i=size(dimPerLay,1)-1:-1:1
  61.             delta{i}=phi{i}.*(delta{i+1}*w{i+1}');
  62.         end
  63.         % weight update
  64.         for i=size(dimPerLay,1):-1:1
  65.             %d=mom*dw{i}+lr.*yp{i}'*delta{i};
  66.             gd=yp{i}'*delta{i};
  67.             dp=(dw{i}.*gd)>=0;% dot product
  68.             %dp(dp==0)=-0.5;
  69.             %         if epoch>2 & E(epoch)>E(epoch-1)
  70.             %             dp=-ones(size(dp));
  71.             %         end
  72.             d=mom*dw{i}.*(dp)+lr.*gd;
  73.             w{i}=w{i}+d;
  74.             dw{i}=d;
  75.             if E(epoch)<=E(bestEpoch)
  76.                 bestw=w; bestyp=yp;bestEpoch=epoch;
  77.             end
  78.             %gdb=sum(delta{i});
  79.             %dp=(db{i}.*gdb)>=0;% dot product
  80.             %dp(dp==0)=0.2;
  81.             %d=mom*db{i}.*dp+sum(lr.*delta{i});
  82.             d=mom*db{i}+sum(lr.*delta{i});
  83.             b{i}=b{i}+d;
  84.             db{i}=d;
  85.         end
  86.         if mod(epoch,500) ==0
  87.             figure(2)
  88.             subplot(1,2,1)
  89.             plot(x',y','o',x',yp{end}','*',x',bestyp{end},'.')
  90.             subplot(1,2,2)
  91.             plot(epoch-500+1:epoch,E(epoch-500+1:epoch))
  92.            
  93.             %lr=lr*0.95;
  94.         end
  95.         %ydelta(epoch)=sum(delta{end});
  96.         %delta{3}
  97.     end
  98.     figure(1)
  99.     SPdist=mod(TN-1,5)+2*5*floor((TN-1)/5)+1;
  100.     SPerror=SPdist+5;
  101.     subplot(2,5,SPdist)
  102.     plot(x',y','o',x',yp{end}','*',x',bestyp{end},'.')
  103.     title(cat(2,'Distribution L=',num2str(L)))
  104.    
  105.     subplot(2,5,SPerror)
  106.     plot(3000:epoch,E(3000:end))
  107.     title(cat(2,'Error L=',num2str(L)))
  108.     [min(E) E(end)]
  109.     E(end)/min(E)
  110. end
  111. % figure (3) ; plot(sort(w{1}))
  112. % figure (4) ; plot(sort(w{2}))
  113. % figure (5) ; plot(sort(w{3}))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement