Advertisement
MathQ_

Untitled

Mar 28th, 2021
145
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.03 KB | None | 0 0
  1. def pooling(X):    
  2.         pool = nn.MaxPool2d(2, 2, return_indices=True)
  3.         unpool = nn.MaxUnpool2d(2, 2)
  4.         all_indices = []
  5.         all_outputs = []
  6.        
  7.         # Encoder
  8.         self.all_outputs.append(X.shape)
  9.  
  10.         X, indices = pool(X)
  11.         self.all_indices.append(indices)
  12.         self.all_outputs.append(X.shape)
  13.  
  14.         X, indices = pool(X)
  15.         self.all_indices.append(indices)
  16.         self.all_outputs.append(X.shape)
  17.  
  18.         X, indices = pool(X)
  19.         self.all_indices.append(indices)
  20.         self.all_outputs.append(X.shape)
  21.  
  22.         X, indices = pool(X)
  23.         self.all_indices.append(indices)
  24.         #self.all_outputs.append(X.shape)
  25.  
  26.         # Decoder
  27.         X = self.unpool(X, self.all_indices[-1], output_size=self.all_outputs[-1])
  28.         X = self.unpool(X, self.all_indices[-2], output_size=self.all_outputs[-2])
  29.         X = self.unpool(X, self.all_indices[-3], output_size=self.all_outputs[-3])
  30.         X = self.unpool(X, self.all_indices[-4], output_size=self.all_outputs[-4])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement