Advertisement
Guest User

Untitled

a guest
Apr 6th, 2020
171
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.26 KB | None | 0 0
  1. import torch
  2. from torch import nn
  3.  
  4. class CovolutionalModel(nn.Module):
  5.   def __init__(self, in_channels, conv1_width, ..., fc1_width, class_vount):
  6.     self.conv1 = nn.Conv2d(in_channels, conv1_width, kernel_size=5, stride=1, padding=2, bias=True)
  7.     # ostatak konvolucijskih slojeva i slojeva sažimanja
  8.     ...
  9.     self.fc1 = nn.Linear(..., fc1_width, bias=True)
  10.     self.fc_logits = nn.Linear(fc1_width, class_count, bias=True)
  11.    
  12.     # parametri suvveć inicijalizirani pozivima Conv2d i Linear
  13.     # ali mi radimo malo drugačiju inicijalizaciju
  14.     self.reset_parameters()
  15.  
  16.   def reset_parmeters(self)
  17.     for m in self.modules():
  18.       if isinstance(m, nn.Conv2d):
  19.         nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
  20.         nn.init.constant_(m.bias, 0)
  21.       elif isinstance(m, nn.Linear) and m is not self.fc_logits:
  22.         nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
  23.         nn.init.constant_(m.bias, 0)
  24.     self.fc_logits.reset_parameters()
  25.  
  26.   def forward(self, x):
  27.     h = self.conv1(x)
  28.     h = torch.relu(h)  # može i h.relu() ili nn.functional.relu(h)
  29.     ...
  30.     h = x.view(x.shape[0], -1)
  31.     h = self.fc1(x)
  32.     h = torch.relu(h)
  33.     logits = self.fc_logits(h)
  34.     return logits
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement