Advertisement
nasarouf

kattis:moretriangles

Mar 7th, 2017
57
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 1.91 KB | None | 0 0
  1. //AC Just a few more triangles! https://open.kattis.com/problems/moretriangles
  2. #include <iostream>
  3. #include <cmath>
  4. #include <vector>
  5. #include <ccomplex>
  6. #include <complex>
  7. #include <algorithm>
  8. #include <map>
  9. using namespace std;
  10.  
  // Scalar type used by the FFT.
  // NOTE(review): despite the name, this is plain double, not long double.
  using ld = double;
  // pi, evaluated in long double precision and then narrowed to ld.
  const ld PI = std::acos(-1.0L);
  // Complex number type used throughout the FFT routines.
  using Complex = std::complex<ld>;
  14.  
  15. //precomputing these saves 50% time
  16. map<pair<int, int>, Complex> wns;
  17. void FFTinit(int N){
  18.     for (int i = 1; i <= N; i *= 2)
  19.         for (int dir = -1; dir <= 1; dir += 2)
  20.             wns[make_pair(i, dir)] = polar((ld)1, dir * 2 * PI / i);
  21. }
  22.  
  23. // *** Fast Fourier Transform (Recursive), ubc codearchive, heavily modified... ***
  24. //array pointers instead of vectors, calculations done in-place
  25. void rfft(int n, Complex* a, Complex* y, int dir = 1, int stride = 1) {
  26.     if (n == 1) { y[0] = a[0]; return; }
  27.     Complex wn = wns[make_pair(n, dir)], w = 1;
  28. //  Complex wn = polar((ld)1, dir * 2 * PI / n), w = 1; //precomputed instead
  29.     rfft(n / 2, a, y, dir, stride * 2);
  30.     rfft(n / 2, a + stride, y + n / 2, dir, stride * 2);
  31.     for (int k = 0; k < n / 2; k++, w *= wn) {
  32.         Complex y1 = y[k] + w*y[n / 2 + k];
  33.         Complex y2 = y[k] - w*y[n / 2 + k];
  34.         y[k + n / 2] = y2;
  35.         y[k] = y1;
  36.     }
  37. }
  38. //end code archive
  39. void ifft(int n, Complex* a, Complex* y) {
  40.     for (int i = 0; i < n; i++) a[i] /= ld(n);
  41.     rfft(n, a, y, -1);
  42. }
  43.  
  // 64-bit signed integer alias used for n, residues and counts in main().
  using ll = long long;
  45.  
  46. int main(){
  47.     ll n, res=0;
  48.     cin >> n; ll sz = n + n - 1;
  49.     int N = 1 << int(ceil(log2(sz)));
  50.     FFTinit(N);
  51.     vector<Complex> v(N, 0), c(N, 0), tmp(N, 0);
  52.     for (ll i = 1; i < n; i++) v[(i*i) % n] = v[(i*i) % n] + (ld)1;
  53.     //convolution
  54.     for (ll i = 0; i < n; i++) tmp[i] = v[i];
  55.     rfft(N, &tmp[0], &c[0]);
  56.     for (ll i = 0; i < N; i++) tmp[i] = c[i] * c[i];
  57.     ifft(N, &tmp[0], &c[0]);
  58.     //remove double count
  59.     for (ll i = 0; i < N; i++) res += (v[i%n]).real()*(int)(.5 + c[i].real());
  60.     for (ll i = 1; i < n; i++) res += v[(2 * i*i) % n].real();
  61.     cout << res/2LL << endl;
  62.     return 0;
  63. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement