Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <bits/stdc++.h>
- #define MAX_K 2001
- using namespace std;
- typedef long long ll;
/**
 * Computes, modulo a caller-supplied modulus, the quantity defined by
 *
 *     dp(n, 0) = 1
 *     dp(n, k) = 0                                  for n < k
 *     dp(n, k) = n * dp(n-1, k-1) + dp(n-1, k)      otherwise
 *
 * which is the k-th elementary symmetric polynomial of 1, 2, ..., n.
 *
 * For small n the recurrence is evaluated directly. For large n the value is
 * obtained by Lagrange interpolation in n through 2k+2 samples of dp(., k).
 *
 * Preconditions (not checked): the modulus is >= 1, and, when interpolation
 * is used, every non-zero difference between interpolation nodes is
 * invertible modulo it (true e.g. for a prime modulus larger than 2k+1).
 */
class CandyDrawing{
public:
    typedef long long ll;

    /** Modulus used by dp() and lagrange(); set by findProbability(). */
    ll mod;

    /**
     * Memo table for dp(), indexed as memo[n][k]; -1 marks "not computed".
     * Rows and columns are grown on demand, so the object itself stays small
     * and can safely live on the stack.
     */
    std::vector<std::vector<ll> > memo;

    /**
     * Modular inverse of `a` modulo `b` via the extended Euclidean algorithm.
     *
     * `a` may be any integer (negative values are reduced into [0, b) first).
     * Returns a value in [0, b). Returns 1 when b == 1, and 0 when `a` has no
     * inverse modulo `b` (gcd(a, b) != 1).
     */
    ll mul_inv(ll a, ll b){
        if (b == 1) return 1;

        const ll modulus = b;
        // Reduce first: the loop below only makes progress for a > 1, so a
        // negative argument would otherwise fall straight through.
        a %= modulus;
        if (a < 0) a += modulus;
        if (a == 0) return 0;

        ll x0 = 0;
        ll x1 = 1;
        while (a > 1) {
            // b hits zero before a reaches 1 exactly when gcd > 1.
            if (b == 0) return 0;
            const ll q = a / b;
            ll t = b;
            b = a % b;
            a = t;
            t = x0;
            x0 = x1 - q * x0;
            x1 = t;
        }
        if (x1 < 0) x1 += modulus;
        return x1;
    }

    /**
     * Evaluates at `p`, modulo `mod`, the Lagrange polynomial through the
     * points (x[i], y[i]).
     *
     * `x` and `y` must have the same length and the x[i] must be pairwise
     * distinct modulo `mod`. Returns a value in [0, mod); an empty node set
     * yields 0.
     */
    ll lagrange(ll p, std::vector<ll> &x, std::vector<ll> &y){
        const std::size_t count = x.size();
        ll ans = 0;
        for (std::size_t i = 0; i < count; i++) {
            ll num = 1;
            ll den = 1;
            for (std::size_t j = 0; j < count; j++) {
                if (j == i) continue;
                // Keep both factors in [0, mod): C++ '%' preserves the sign
                // of the dividend, and the inverse needs a reduced argument.
                num = (num * reduce(p - x[j])) % mod;
                den = (den * reduce(x[i] - x[j])) % mod;
            }
            const ll basis = (num * mul_inv(den, mod)) % mod;
            ans = (basis * reduce(y[i]) + ans) % mod;
        }
        return ans;
    }

    /**
     * Memoised evaluation of the recurrence described on the class, modulo
     * `mod`. Returns 0 for n < k or k < 0.
     *
     * Recursion depth and table size grow with n, so this is meant for the
     * small arguments findProbability() passes. `mod` must already be set,
     * and `memo` must not hold values computed under a different modulus.
     */
    ll dp(int n, int k){
        if (k < 0 || n < k) return 0;
        if (k == 0) return 1 % mod;

        // From here on n >= k >= 1, so both indices are valid once grown.
        const std::size_t row = static_cast<std::size_t>(n);
        const std::size_t col = static_cast<std::size_t>(k);
        if (memo.size() <= row) memo.resize(row + 1);
        if (memo[row].size() <= col) memo[row].resize(col + 1, -1);
        if (memo[row][col] != -1) return memo[row][col];

        const ll take = dp(n - 1, k - 1);
        const ll skip = dp(n - 1, k);
        const ll value = ((static_cast<ll>(n) * take) % mod + skip) % mod;
        // Index afresh: the recursive calls may have resized other rows.
        memo[row][col] = value;
        return value;
    }

    /**
     * Returns dp(n, k) modulo `mod_`.
     *
     * For n <= 3k+1 the recurrence is evaluated directly. Otherwise dp(m, k)
     * is sampled at m = k, ..., 3k+1 and interpolated at n; this relies on
     * dp(., k) being a polynomial in its first argument of degree at most 2k.
     * Negative k yields 0.
     */
    int findProbability(int n, int k, int mod_){
        mod = mod_;
        memo.clear();   // cached values depend on the modulus
        if (k < 0) return 0;

        const ll top = 3LL * k + 1;
        if (n <= top) return static_cast<int>(dp(n, k));

        const std::size_t count = 2 * static_cast<std::size_t>(k) + 2;
        std::vector<ll> x(count, 0);
        std::vector<ll> y(count, 0);
        for (std::size_t i = 0; i < count; i++) {
            const int node = static_cast<int>(i) + k;
            x[i] = node;
            // Go through dp() rather than reading memo directly, so base
            // cases that never touch the table (k == 0) are still correct.
            y[i] = dp(node, k);
        }
        return static_cast<int>(lagrange(n, x, y));
    }

private:
    /** Reduces `v` into the canonical range [0, mod). */
    ll reduce(ll v) const {
        v %= mod;
        return v < 0 ? v + mod : v;
    }
};
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement