Advertisement
Guest User

solution.cpp

a guest
Jan 17th, 2019
123
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 6.20 KB | None | 0 0
  1. #pragma GCC optimize("Ofast")
  2. #pragma GCC target("avx")
  3.  
  4. #include <iostream>
  5. #include <iomanip>
  6. #include <algorithm>
  7. #include <assert.h>
  8. #include <immintrin.h>
  9.  
  10. namespace Solution {
  11.    
  12.     const int NMAX = 100000;
  13.  
  14.     const int blockSize = 512;
  15.    
  16.     const int MaxNBlocks = (NMAX + blockSize - 1) / blockSize;
  17.  
  18.     float* cntTillBorder[MaxNBlocks];
  19.    
  20.     float* zeroBuffer;
  21.    
  22.     float* arr;
  23.    
  24.     int n, nBlocks;
  25.    
  26.     void alloc(int n_) {
  27.         n = n_;
  28.         nBlocks = (n + blockSize - 1) / blockSize;
  29.         arr = (float*)_mm_malloc(n * sizeof(float), 32);
  30.         zeroBuffer = (float*)_mm_malloc(n * sizeof(float), 32);
  31.         std::fill(zeroBuffer, zeroBuffer+n, 0);
  32.         for (int i = 0; i < nBlocks; ++i) {
  33.             cntTillBorder[i] = (float*)_mm_malloc(n * sizeof(float), 32);
  34.             std::fill(cntTillBorder[i], cntTillBorder[i]+n, 0);
  35.         }
  36.     }
  37.    
  38.     void free() {
  39.         _mm_free(arr);
  40.         for (int i = 0; i < nBlocks; ++i) {
  41.             _mm_free(cntTillBorder[i]);
  42.         }
  43.     }
  44.  
  45.     void precalc() {
  46.         for (int i = 0; i < n; ++i) {
  47.             int value = (int)arr[i];
  48.             for (int b = i / blockSize; b < nBlocks; ++b) {
  49.                 cntTillBorder[b][value]++;
  50.             }
  51.         }
  52.     }
  53.    
  54.     void naiveUpdateItems(int begin, int after, float x, float y) {
  55.         int cnt = 0;
  56.         for (int i = begin; i < after; ++i) {
  57.             cnt += (arr[i] == x);
  58.             arr[i] = (arr[i] == x) ? y : arr[i];
  59.         }
  60.         for (int b = begin / blockSize; b < nBlocks; ++b) {
  61.             cntTillBorder[b][int(x)] -= (float)cnt;
  62.             cntTillBorder[b][int(y)] += (float)cnt;
  63.         }
  64.     }
  65.    
  66.     void modify(int lt, int rt, float x, float y) {
  67.         int bl = lt / blockSize;
  68.         int br = rt / blockSize;
  69.         if (bl == br) {
  70.             naiveUpdateItems(lt, rt+1, x, y);
  71.             return;
  72.         }
  73.         naiveUpdateItems(lt, (bl+1) * blockSize, x, y);
  74.         naiveUpdateItems(br * blockSize, rt+1, x, y);
  75.         __m256 vx = _mm256_set1_ps(x);
  76.         __m256 vy = _mm256_set1_ps(y);
  77.         int changes = 0;
  78.         for (int b = bl+1; b < br; ++b) {
  79.             cntTillBorder[b][int(x)] -= (float)changes;
  80.             cntTillBorder[b][int(y)] += (float)changes;
  81.             float *blockBegin = arr + b * blockSize;
  82.             for (int i = 0; i < blockSize; i += 32) {
  83.                 uint32_t bitmask = 0;
  84.                 for (int j = 0; j < 32; j += 8) {
  85.                     __m256 va = _mm256_load_ps(blockBegin + i + j);
  86.                     __m256 rs = _mm256_cmp_ps(vx, va, _CMP_EQ_OQ);
  87.                     bitmask = (bitmask << 8) | _mm256_movemask_ps(rs);
  88.                     _mm256_maskstore_ps(blockBegin + i + j, _mm256_cvtps_epi32(rs), vy);
  89.                 }
  90.                 changes += __builtin_popcountll(bitmask);
  91.             }
  92.         }
  93.     }
  94.  
  95.     int nth_element(int lt, int rt, int k) {
  96.         int bl = lt / blockSize;
  97.         int br = rt / blockSize;
  98.         float* arrLT = (bl == 0) ? zeroBuffer : cntTillBorder[bl-1];
  99.         float* arrRT = cntTillBorder[br];
  100.         for (int i = rt+1; i < std::min((br+1) * blockSize, n); ++i) { arrRT[(int)arr[i]]--; }
  101.         for (int i = bl * blockSize; i < lt; ++i) { arrLT[(int)arr[i]]++; }
  102.         int last = -1;
  103.         for (int i = 0; i + 31 < n; i += 32) {
  104.             __m256 sum = _mm256_setzero_ps(), vr, vl;
  105.             for (int j = 0; j < 32; j += 8) {
  106.                 vr = _mm256_load_ps(arrRT+i+j);
  107.                 vl = _mm256_load_ps(arrLT+i+j);
  108.                 sum = _mm256_add_ps(sum, _mm256_sub_ps(vr,vl));
  109.             }
  110.             alignas(32) static float tmp[8];
  111.             _mm256_store_ps(tmp, sum);
  112.             int cnt = 0;
  113.             for (int j = 0; j < 8; ++j) { cnt += (int)tmp[j]; }
  114.             if (cnt >= k) {
  115.                 last = i-1;
  116.                 break;
  117.             }
  118.             k -= cnt;
  119.             last = i + 31;
  120.         }
  121.         int cnt = 0;
  122.         while (cnt < k) { last++; cnt += int(arrRT[last] - arrLT[last]); }
  123.         for (int i = bl * blockSize; i < lt; ++i) { arrLT[(int)arr[i]]--; }
  124.         for (int i = rt+1; i < std::min((br+1) * blockSize, n); ++i) { arrRT[(int)arr[i]]++; }
  125.         return last;
  126.     }
  127. }
  128.  
  129. char getChar() {
  130.     static const int SIZE = 1 << 16;
  131.     static char buffer[SIZE];
  132.     static int pos = 0;
  133.     static int size = 0;
  134.     if (pos == size) {
  135.         size = (int)fread(buffer, 1, SIZE, stdin),
  136.         pos = 0;
  137.     }
  138.     if (pos == size) {
  139.         return EOF;
  140.     }
  141.     return buffer[pos++];
  142. }
  143.  
  144. template<typename T>
  145. T getInt() {
  146.     char c = '?';
  147.     while (!(c == '-' || ('0' <= c && c <= '9'))) { c = getChar(); }
  148.     bool pos = true;
  149.     if (c == '-') { pos = false; c = getChar(); }
  150.     T ret = 0;
  151.     while ('0' <= c && c <= '9') { (ret *= 10) += (c - '0'); c = getChar(); }
  152.     return pos ? ret : -ret;
  153. }
  154.  
  155. void putChar(char c) {
  156.     static const int SIZE = 1 << 16;
  157.     static char buffer[SIZE];
  158.     static int size = 0;
  159.     if (size == SIZE || c == EOF) {
  160.         fwrite(buffer, 1, size, stdout),
  161.         size = 0;
  162.     }
  163.     if (c != EOF) { buffer[size++] = c; }
  164. }
  165.  
  166. template<typename T>
  167. void putInt(T value) {
  168.     bool pos = true;
  169.     if (value < 0) { pos = false; value = -value; }
  170.     static char buf[24];
  171.     int size = 0;
  172.     do { buf[size++] = char(value % 10 + '0'); value /= 10; } while (value > 0);
  173.     if (!pos) { buf[size++] = '-'; }
  174.     while (size--) { putChar(buf[size]); }
  175. }
  176.  
  177. int main() {
  178.     int n = getInt<int>(), q = getInt<int>();
  179.     Solution::alloc(n);
  180.     for (int i = 0; i < n; ++i) {
  181.         Solution::arr[i] = (float)(getInt<int>()-1);
  182.     }
  183.     Solution::precalc();
  184.    
  185.     while (q--) {
  186.         int t = getInt<int>();
  187.         if (t == 1) {
  188.             int lt = getInt<int>()-1, rt = getInt<int>()-1, x = getInt<int>()-1, y = getInt<int>()-1;
  189.             Solution::modify(lt, rt, (float)x, (float)y);
  190.         } else {
  191.             int lt = getInt<int>()-1, rt = getInt<int>()-1, k = getInt<int>();
  192.             putInt(Solution::nth_element(lt, rt, k)+1);
  193.             putChar('\n');
  194.             assert(t == 2);
  195.         }
  196.     }
  197.     putChar(EOF);
  198.     Solution::free();
  199.     return 0;
  200. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement