daily pastebin goal
41%
SHARE
TWEET

Untitled

a guest Mar 23rd, 2019 55 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import torch.nn as nn
  3. from torch.jit import ScriptModule, script_method
  4. from typing import List
  5.  
  6.  
  7. class BatchNorm(ScriptModule):
  8.     __constants__ = ['mom', 'eps']
  9.  
  10.     def __init__(self, nf, mom=0.9, eps=1e-5):
  11.         super().__init__()
  12.         self.mom, self.eps = mom, eps
  13.         self.mults = nn.Parameter(torch.ones(nf, 1, 1))
  14.         self.adds = nn.Parameter(torch.zeros(nf, 1, 1))
  15.         # self.means = [torch.zeros(1, nf, 1, 1)]
  16.         # self.vars = [torch.ones(1, nf, 1, 1)]
  17.         self.means = torch.jit.Attribute([torch.zeros(1, nf, 1, 1)], List[torch.Tensor])
  18.         self.vars = torch.jit.Attribute([torch.ones(1, nf, 1, 1)], List[torch.Tensor])
  19.         # self.register_buffer('vars',  torch.ones (1,nf,1,1))
  20.         # self.register_buffer('means', torch.zeros(1,nf,1,1))
  21.  
  22.     @script_method
  23.     def update_stats(self, x):
  24.         m = x.mean((0, 2, 3), keepdim=True)
  25.         v = x.var((0, 2, 3), keepdim=True)
  26.         # self.means.detach_()
  27.         # self.vars.detach_()
  28.         self.means[0] = self.means[0] * self.mom + m*(1-self.mom)
  29.         self.vars[0] = self.vars[0] * self.mom + v*(1-self.mom)
  30.  
  31.     @script_method
  32.     def forward(self, x):
  33.         if self.training:
  34.             self.update_stats(x)
  35.         x = (x-self.means[0]) / ((self.vars[0] + self.eps).sqrt())
  36.         return x*self.mults + self.adds
  37.  
  38.  
  39. m = BatchNorm(100)
  40.  
  41. graph = m.graph_for(torch.randn(20, 100))
  42. print(graph)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top