Advertisement
Guest User

Untitled

a guest
Jun 7th, 2021
49
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. def scattering2d(x, pad, unpad, backend, J, L, phi, psi, max_order,
  2.         out_type='array'):
  3.     subsample_fourier = backend.subsample_fourier
  4.     modulus = backend.modulus
  5.     rfft = backend.rfft
  6.     ifft = backend.ifft
  7.     irfft = backend.irfft
  8.     cdgmm = backend.cdgmm
  9.     concatenate = backend.concatenate
  10.  
  11.     # Define lists for output.
  12.     out_S_0, out_S_1, out_S_2 = [], [], []
  13.  
  14.     U_r = pad(x)
  15.  
  16.     U_0_c = rfft(U_r)
  17.  
  18.     # First low pass filter
  19.     U_1_c = cdgmm(U_0_c, phi[0])
  20.     # U_1_c = subsample_fourier(U_1_c, k=2 ** J)
  21.  
  22.     S_0 = irfft(U_1_c)
  23.     S_0 = unpad(S_0)
  24.  
  25.     out_S_0.append({'coef': S_0,
  26.                     'j': (),
  27.                     'theta': ()})
  28.  
  29.     for n1 in range(len(psi)):
  30.         j1 = psi[n1]['j']
  31.         theta1 = psi[n1]['theta']
  32.  
  33.         U_1_c = cdgmm(U_0_c, psi[n1][0])
  34.         # if j1 > 0:
  35.         #     U_1_c = subsample_fourier(U_1_c, k=2 ** j1)
  36.         U_1_c = ifft(U_1_c)
  37.         U_1_c = modulus(U_1_c)
  38.         U_1_c = rfft(U_1_c)
  39.  
  40.         # Second low pass filter
  41.         # S_1_c = cdgmm(U_1_c, phi[j1])
  42.         S_1_c = U_1_c
  43.         # S_1_c = subsample_fourier(S_1_c, k=2 ** (J - j1))
  44.  
  45.         S_1_r = irfft(S_1_c)
  46.         S_1_r = unpad(S_1_r)
  47.  
  48.         out_S_1.append({'coef': S_1_r,
  49.                         'j': (j1,),
  50.                         'theta': (theta1,)})
  51.  
  52.         if max_order < 2:
  53.             continue
  54.         for n2 in range(len(psi)):
  55.             j2 = psi[n2]['j']
  56.             theta2 = psi[n2]['theta']
  57.  
  58.             if j2 <= j1:
  59.                 continue
  60.  
  61.             U_2_c = cdgmm(U_1_c, psi[n2][0])
  62.             # U_2_c = subsample_fourier(U_2_c, k=2 ** (j2 - j1))
  63.             U_2_c = ifft(U_2_c)
  64.             U_2_c = modulus(U_2_c)
  65.             U_2_c = rfft(U_2_c)
  66.  
  67.             # Third low pass filter
  68.             # S_2_c = cdgmm(U_2_c, phi[j2])
  69.             # S_2_c = subsample_fourier(S_2_c, k=2 ** (J - j2))
  70.             S_2_c = U_2_c
  71.  
  72.             S_2_r = irfft(S_2_c)
  73.             S_2_r = unpad(S_2_r)
  74.  
  75.             out_S_2.append({'coef': S_2_r,
  76.                             'j': (j1, j2),
  77.                             'theta': (theta1, theta2)})
  78.  
  79.     out_S = []
  80.     out_S.extend(out_S_0)
  81.     out_S.extend(out_S_1)
  82.     out_S.extend(out_S_2)
  83.  
  84.     if out_type == 'array':
  85.         out_S = concatenate([x['coef'] for x in out_S])
  86.  
  87.     return out_S
  88.  
  89.  
  90. __all__ = ['scattering2d']
  91.  
Advertisement
Advertisement
Advertisement
RAW Paste Data Copied
Advertisement