Advertisement
AyushP123

Untitled

Jan 18th, 2022
729
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #include <iostream>
  2. #include <vector>
  3. using namespace std;
  4.  
  5. vector<vector<int>> conv2d_naive(vector<vector<int>> img, vector<vector<int>> kernel, int stride)
  6. {
  7.     int img_h = img.size();
  8.     int img_w = img[0].size();
  9.     int k_h = kernel.size();
  10.     int k_w = kernel[0].size();
  11.     int output_h = ((img_h - k_h) / stride) + 1;
  12.     int output_w = ((img_w - k_w) / stride) + 1;
  13.     vector<vector<int>> output;
  14.     for(int i = k_h; i < img_h - k_h; i+=stride)
  15.     {
  16.         vector<int> temp;
  17.         for(int j = k_w; j < img_w - k_w; j+=stride)
  18.         {
  19.             int sum = 0;
  20.             for(int i1 = 0; i1 < k_h; i1++)
  21.             {
  22.                 for(int j1 = 0; j1 < k_w; j1++)
  23.                 {
  24.                     sum += img[i - k_h + i1][j - k_w + j1] * kernel[i1][j1];
  25.                 }
  26.             }
  27.             temp.append(sum);
  28.         }
  29.         output.append(temp);
  30.     }
  31.  
  32.     return output;
  33. }
  34.  
  35. vector<vector<int>> conv2d_optimized(vector<int> img, vector<int> kernel, int img_h, int img_w,
  36.                                     int k_h, int k_w, int stride)
  37. {
  38.     int output_h = ((img_h - k_h) / stride) + 1;
  39.     int output_w = ((img_w - k_w) / stride) + 1;
  40.  
  41.     vector<vector<int>> output;
  42.     vector<int> temp;
  43.     int count = 0;
  44.     for(int i = 0; i < img_h * img_w; i+=stride)
  45.     {
  46.         int sum = 0;
  47.         int start = i;
  48.         for(int j = 0; j < k_h * k_w; j++)
  49.         {
  50.             if(j != 0 && j % k_w == 0)
  51.             {
  52.                 start += img_w - k_w;
  53.             }
  54.  
  55.             sum += img[start] * kernel[j];
  56.         }
  57.  
  58.         temp.append(sum)
  59.         if(temp.size() == output_w)
  60.         {
  61.             output.append(temp);
  62.             temp.clear();
  63.         }
  64.     }
  65.    
  66.     return output;
  67. }
Advertisement
Advertisement
Advertisement
RAW Paste Data Copied
Advertisement