Fastrail08

K equal partition (BITMASK) (IMPORTANT)

Jul 3rd, 2025 (edited)
#include <bits/stdc++.h>

using namespace std;
/*

VERY IMPORTANT LINK -
https://chatgpt.com/canvas/shared/68631cabdd608191b64d447a63a1ea45
(DP + BITMASK INTUITION discussion PDF)

*/

// Fill groups (LEVELS) one at a time, so LEVELS = GROUPS and OPTIONS = all the unused elements in the mask/nums array.

// Just like filling a sudoku, where LEVELS = cells and OPTIONS = digits (1 - 9): we recurse only with VALID options and move to the next cell/level.

// When all cells are filled, the sudoku is complete; otherwise backtrack and correct the previous levels with other valid options.

// Fill groups one by one so that each group sums to targetSum. Only once a group reaches targetSum do we MOVE to the next group/level and fill it the same way; when all k groups are done, the partitioning is valid, so count it.

// But if, after exploring all the options, we cannot fill a group up to targetSum, backtrack and correct the previously filled groups with other combinations of elements.

// Suppose we have to create k = 4 partitions and we are filling the 3rd group. We don't need to re-check that the 2 previously filled groups equal targetSum: we only reached the 3rd group because we filled the first 2 groups SUCCESSFULLY, each with sum equal to targetSum; otherwise we would have backtracked from those groups.

// So if we reach the base case, we DON'T NEED TO VALIDATE THAT ALL THE GROUPS EQUAL TARGETSUM: we only get there after filling each group with sum = targetSum and forming k groups that use every element of nums (all bits in mask are set).

long long generateKGroupsWithTargetSum(int mask, int currSum,
    int groupsFilled, int & targetSum,
    int & k, vector < int > & nums,
    vector < long long > & memo) {
    // base case, ALL elements used (number of set bits in mask = nums.size())
    if (__builtin_popcount(mask) == (int) nums.size()) {
        // All k groups are formed with EQUAL SUM (we could only form k groups because we filled each one up to targetSum) and currSum = 0, i.e. the k groups use every element and no extra group was started (currSum = 0 means the [k + 1]th group has sum 0)
        //“If groupsFilled == k and currSum == 0, it implies SUM(mask) == totalSum, and the sum has been fully distributed in k × targetSum blocks.”
        if (currSum == 0 && groupsFilled == k) {
            // A valid grouping
            return 1;
        }
        // Invalid case: not all k groups were formed, or some sum remained after filling the k groups (currSum != 0)
        return 0;
    }


    // memo check: if this mask was already computed, return memo[mask]
    if (memo[mask] != -1) {
        return memo[mask];
    }

    // LEVELS = groups
    // OPTIONS = unused elements in the mask (ITERATE OVER THE MASK to find them, just as we iterate over 1-9 in sudoku)

    long long count = 0;
    // “Because all elements are tried in each recursive frame, order-based repetitions are inherently counted; this makes it count permutations, not combinations.”
    // The loop starts from 0, so the same mask is reached along several paths.
    // With i < j: i picked at the current level and j picked at some LATER level,
    // versus j picked at the current level and i picked at that LATER level,
    // both form the same mask by different paths; only the order of picking differs (a permutation), and both are counted.
    for (int i = 0; i < (int) nums.size(); i++) {

        // ith element of the array
        int elementBitIdx = (1 << i);
        int element = nums[i];

        // Check whether the ith element is already used; only unused elements may go into the current group.
        // This tests whether the ith bit is set:
        // if the ith bit is ON, '&' returns a non-zero value
        // if the ith bit is OFF, '&' returns 0
        if ((mask & elementBitIdx) == 0) {
            // We generate ONLY VALID GROUPS, which is why the base case doesn't need to check every group; this removes the need to maintain an explicit array of group sums, as an earlier version did.
            // To generate only VALID GROUPS, 2 CONDITIONS MUST BE MET:
            // USE ONLY AN ELEMENT THAT HAS NOT BEEN USED BEFORE (OFF in mask) & INCLUDING IT MUST NOT MAKE THE CURRENT GROUP EXCEED targetSum
            if (currSum + element <= targetSum) {
                // The element meets both criteria, so we can use it:
                // mark it SET in the mask before recursing
                int newMask = mask | elementBitIdx; // turn on its bit in the new mask, signifying the element is used

                // If including the current element completes the current group (currSum == targetSum), move to a new group (groupsFilled + 1) and reset currSum to 0, since currSum is the sum of the group being filled and the new group starts from 0
                bool isCurrGroupFilled = (currSum + element == targetSum);
                // count all the valid groupings reachable from this choice
                count += generateKGroupsWithTargetSum(newMask,
                        isCurrGroupFilled ? 0 : currSum + element,
                        isCurrGroupFilled ? groupsFilled + 1 : groupsFilled,
                        targetSum, k, nums, memo);
            }
        }
    }
    return memo[mask] = count;
}

int main() {
    // your code goes here
    int n, k;
    cin >> n;
    vector<int> nums(n);
    for(int i = 0; i < n; i++){
        cin >> nums[i];
    }
    cin >> k;

    // calculate sum
    int totalSum = 0;
    for(int i : nums){
        totalSum += i;
    }

    // Impossible to partition into k groups: the sum of all elements is not divisible by k
    if(totalSum % k != 0){
        cout << 0 << '\n'; // zero valid groupings
        return 0;
    }

    // Only if SUM(all elements) is divisible by k can we partition the elements into k groups of equal sum
    // e.g. SUM(elements) = 32, k = 4, so targetSum = SUM(elements) / k
    // targetSum = 8 ... each group will have sum = targetSum = 8

    // So there might be a way to partition the n elements into k groups, each with targetSum = totalSum / k
    int targetSum = totalSum / k;


    // At first glance the memo key seems to be memo(mask, currSum, k), since all three change and influence the decision at the base case or in the conditionals, but it can be reduced to memo(mask).
    // WHY? Because the other 2 parameters can be DERIVED from the mask itself.

    // mask = a VALID SUBSET CONFIGURATION, currSum = how full the current group is (the sum of the group being filled), k = number of VALID GROUPS FILLED (groupsFilled in the function)

    // k can be derived from mask as       => k = SUM(mask) / targetSum
    // currSum can be derived from mask as => currSum = SUM(mask) % targetSum

    // SUM(mask) = the CUMULATIVE sum of all the elements already used to form the groups, i.e. all the elements whose bit in mask is SET/ON.

    // Consider k = 4 and a nums array whose elements sum to 32,
    // so each group should have targetSum = 32 / 4 = 8

    // Suppose at some level/state, the CONFIGURATION represented by mask uses elements summing to 27, i.e. SUM(mask) = 27

    // k = SUM(mask) / targetSum gives the number of VALID groups (group sum = targetSum) completed so far, i.e. how many full targetSum blocks of SUM(mask) went into completing groups SUCCESSFULLY, which is exactly what k is
    // SUM(mask) / targetSum => 27 / 8 = 3, so we have filled 3 groups SUCCESSFULLY (each with sum = 8)
    // portion of SUM(mask) used in filling 3 groups of targetSum = 8 => 3 * 8 = 24

    // currSum = SUM(mask) % targetSum gives the remaining portion of SUM(mask), which belongs to the group currently being filled
    // SUM(mask) % targetSum = 27 % 8 = 3: this 3 was not used by the previously filled groups and is the sum of the current group so far, which is currSum

    // So both currSum and k are DERIVABLE from mask

    // IMPORTANT ************* “Since both currSum and groupsFilled are invariant functions of the mask (given fixed input array and targetSum), memoizing solely on mask captures the entire subproblem state.”


    // As targetSum is a constant, the other 2 parameters are derivable from mask, i.e. they are properties of the mask
    // Essentially this means: NO MATTER WHICH PATH WE TAKE TO REACH A STATE, IF TWO STATES HAVE THE SAME MASK, then currSum & k are the same, as they are derived from mask.
    // So 2 states with the same mask also have the same currSum & k, and there is no need to include them in the memo key.

    // so memo(mask, currSum, k) => memo(mask, SUM(mask) % targetSum, SUM(mask) / targetSum) => memo(mask)
    // memo(mask, currSum, k): S.C. would be O(2^N * targetSum * k) and T.C. = O(2^N * N), looping up to N in each state
    // NOW REDUCED TO
    // memo(mask): S.C. is O(2^N) and T.C. is still O(2^N * N). Only the space is reduced, but greatly.

    // At each of the 2^N masks, we loop over up to N elements.
    // Regardless of whether currSum and k are explicit or derived, the same number of recursive calls are made.
    // The difference is how many subproblem slots we store, not how many we compute.

    // Memoizing memo(mask, currSum, k) would need a 3D table or a serialized key like mask|currSum|k, allocating many more slots.
    // Since each mask can only ever appear with one (currSum, k), the memo hits are the same; the extra slots are simply wasted.
    // On the other hand, memo(mask) is indexed by a single integer, which shrinks the table to 2^N entries.

    // Even though currSum and k are passed in recursion,
    // for any given mask, their values are uniquely determined as:
    //     currSum = sum(mask) % targetSum
    //     k       = sum(mask) / targetSum
    //
    // So memo(mask, currSum, k) allocates redundant entries:
    // for each mask, all but one (currSum, k) slot can never be reached.
    // Instead, we can just use memo(mask) to compress all those slots into one,
    // achieving the same reuse with far less space.

    // memo[mask][currSum][k] → or flatten to something like memo[serialize(mask, currSum, k)]
    // IT OVERREPRESENTS the number of unique subproblems.
    // It does not gain any extra memo hits either, since every path to a mask carries the same (currSum, k)

    /*
    AS
        groupsFilled = sum(mask) / targetSum;
        currSum      = sum(mask) % targetSum;


        While using memo(mask, currSum, k),

        two paths that reach the same state look like:

        (mask = 101011, currSum = 3, k = 2)

        (mask = 101011, currSum = 3, k = 2) ← this exact state

        ✅ These are the same, so they share one memo entry. But the key also...

        ...reserves room for states that differ only in (currSum, k):

        (mask = 101011, currSum = 3, k = 2)

        (mask = 101011, currSum = 3, k = 3)

        ...but this case cannot happen, because:

        ❗ currSum and k cannot differ for the same mask.
    */


    // IMPORTANT *************
    // This reduces the memo table size from O(2^N × targetSum × k) to O(2^N)
    // Because (currSum, k) are derivable from mask, we only memoize once per mask
    // This drastically reduces space without changing which subproblems are computed or reused
    // Earlier, each mask reserved a slot for every (currSum, k) combination, though only one could ever be filled
    // Now each mask maps to a single memo[mask] entry



    // What is the size of the memo table => the range of mask => what does mask represent? A subset/CONFIGURATION from the INDIVIDUAL ELEMENT CHOICE DIAGRAM (the recursion tree formed by the include/exclude choice for each element)

    // IMPORTANT ************* RANGE = mask ∈ [0, 2^n)
    // Min value of mask = 0, no elements selected
    // Max value of mask = FULL MASK (1 << n) - 1: all elements from index 0 to index (n - 1) of nums are selected, so all bits from the 0th to the (n - 1)th are set, which gives the max value of mask
    // e.g. for n = 4, the full mask is (1 << 4) - 1 = 15, so mask takes 16 values, from 0 to 15.


    // So range = 0 to FULL MASK => 0 to (1 << n) - 1 => representing all 2^N choices
    // so for n = 4
    // Number of choices = 2^4 = 16 different CONFIGURATIONS/subsets (include or exclude), so the min and max values of mask are
    // Min value = 0  [0 0 0 0 0]
    // Max value = 15 [0 1 1 1 1]

    // Total 16 different values
    // int shift = 4;
    // // 1 left-shifted by four positions: initially [0 0 0 0 1] -> [1 0 0 0 0]

    // // IMPORTANT: ALWAYS USE bitset<n + 1> TO DEBUG (the width must be a compile-time constant); a bitset prints the INTEGER IN BIT FORMAT
    // int mask = (1 << shift);
    // bitset<5> b1(mask);
    // bitset<5> b2(mask - 1);
    // cout << b1 << '\n' << b2 << '\n';

    // Memo table of size (1 << n), the number of choices for N elements, so it covers every value of mask from 0 to (1 << n) - 1
    vector<long long> memo((1 << n), -1);
    cout << generateKGroupsWithTargetSum(0, 0, 0, targetSum, k, nums, memo) << '\n';


    /*
    OPTIMIZATIONS AND OTHER IMPORTANT NOTES

    // SORTING NUMS IN DESCENDING ORDER CAN GIVE BETTER PRUNING, as the larger elements are then used to fill the first groups, which reach (or overshoot) their targetSum sooner.

    // The code counts the total number of permutations.
    // Essentially every permutation is counted: permutations of the groups formed, permutations of the elements within a group, and identical-valued groups built from distinct indices are counted separately.

    // Combinations can be generated by maintaining lastElementIncludedIndex and starting the next pick just after it, which prevents permutations. This is the same as how we prevent permutations when placing r items in n boxes, where n >= r.

    // Suppose we have to place items 1 and 2 in 5 boxes,
    // considering the levels to be the items
    // and the options to be the empty boxes where they can be placed. So at the level of item 1, there are 5 empty boxes to place it in:
    // 10000, 01000, 00100, 00010, 00001

    // Then at the level of item 2, 4 boxes remain on each path, so considering path 10000
    // 12000(*), 10200, 10020, 10002

    // But backtrack to the path where 1 was placed in 01000; continuing on this path to the level of item 2, the options are
    // 21000(*), 01200, 01020, 01002

    // BOTH STAR-marked placements are permutations of each other.

    // If we wanted only combinations, i.e. the order doesn't matter because the items are identical/not distinct,
    // both would be the same combination, ii000, replacing 1 & 2 with i since both are considered identical
    // so to get combinations we need to find what creates the permutation.

    // It comes from allowing 2 to be placed in a box BEFORE the box where 1 was placed.

    // So to prevent permutations, keep track of where the last item was placed (or, in the k partition with equal sum, which item was taken last), and at the next state/level start from the index right after it.

    // With the mask, permutations arise like this: say the current mask is 0100 (the element at index 2 is used).
    // We loop over all the OFF-BIT items. At the CURRENT LEVEL we first take the index 0 element, updating the mask to 0101, and recurse to some NEWLEVEL, where we take the index 3 item, updating the mask to 1101.

    // When we backtrack to the CURRENT LEVEL, the mask reverts to 0100 and we take the index 3 element instead of index 0, updating the mask to 1100 and recursing further. At that same NEWLEVEL, taking the index 0 item updates the mask from 1100 to 1101, the same mask as before.
    // That answer is already memoized, since this mask was visited before, so we return the total number of groupings down this branch again. We reached the same mask with a different order of picking, and the caller counts both orders: essentially counting permutations.

    // To avoid this, start the loop from lastElementIncludedIndex + 1. When index 0 is picked at the CURRENT LEVEL, at the NEWLEVEL we can still pick index 3 (it comes after lastElementIncludedIndex = 0), enumerating the combination. But when index 3 is picked at the CURRENT LEVEL, lastElementIncludedIndex = 3, and at the NEWLEVEL we cannot pick index 0, since only elements AFTER index 3 are allowed; this removes the permutation.
    // Caveats: apply this restriction within the current group and restart from index 0 when a new group begins (otherwise valid partitions are skipped), and since lastElementIncludedIndex is now part of the state, the memo key must include it as well as mask.

    */
}
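The memo comments in main() derive groupsFilled and currSum from SUM(mask) using the 27 / 8 example. A small sketch of that arithmetic follows. The array {8, 8, 8, 3, 5} (total 32, k = 4, targetSum = 8) and the helper names sumOfMask and deriveState are hypothetical, chosen so that using the first four elements gives SUM(mask) = 27.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper (not in the paste): SUM(mask), the sum of elements whose bit is set.
int sumOfMask(const std::vector<int>& nums, int mask) {
    int sum = 0;
    for (int i = 0; i < (int)nums.size(); i++) {
        if (mask & (1 << i)) sum += nums[i];  // ith element is used
    }
    return sum;
}

// Hypothetical helper: returns {groupsFilled, currSum} for a mask reached by the recursion.
// Assumes every completed group sums to exactly targetSum, which the recursion guarantees.
std::pair<int, int> deriveState(const std::vector<int>& nums, int mask, int targetSum) {
    int s = sumOfMask(nums, mask);
    return {s / targetSum, s % targetSum};  // {SUM / targetSum, SUM % targetSum}
}
```

For mask 0b01111 (elements 8, 8, 8, 3), SUM(mask) = 27, so the pair is (3, 3): three full groups and a current group holding 3. The full mask 0b11111 gives (4, 0).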
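The same argument suggests a version of the counting recursion where the mask is the only varying parameter. The sketch below assumes positive elements and uses hypothetical names (OrderedCounter, countOrdered). Its precomputed maskSum table builds SUM(mask) from the mask with its lowest set bit cleared, using the GCC/Clang builtin __builtin_ctz, in the same spirit as the listing's __builtin_popcount. It makes the same choices as generateKGroupsWithTargetSum, so it counts the same orderings.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Sketch: permutation-counting recursion with mask as the only varying parameter.
// currSum is not passed in; it is derived as SUM(mask) % targetSum.
// Assumes all nums are positive and n is small enough for 2^n table entries.
struct OrderedCounter {
    const std::vector<int>& nums;
    int n, targetSum;
    std::vector<int> maskSum;     // maskSum[mask] = SUM(mask)
    std::vector<long long> memo;  // memo[mask], -1 = not computed yet

    OrderedCounter(const std::vector<int>& v, int target)
        : nums(v), n((int)v.size()), targetSum(target), maskSum(1 << v.size(), 0),
          memo(1 << v.size(), -1) {
        for (int mask = 1; mask < (1 << n); mask++)  // SUM(mask) = SUM(mask without lowest bit) + that element
            maskSum[mask] = maskSum[mask & (mask - 1)] + nums[__builtin_ctz(mask)];
    }

    long long count(int mask) {
        if (mask == (1 << n) - 1) return 1;  // all used; SUM = k * targetSum, so all k groups are complete
        if (memo[mask] != -1) return memo[mask];
        int currSum = maskSum[mask] % targetSum;  // derived from the mask
        long long ways = 0;
        for (int i = 0; i < n; i++) {
            if (!(mask & (1 << i)) && currSum + nums[i] <= targetSum)
                ways += count(mask | (1 << i));  // same choices as the original loop
        }
        return memo[mask] = ways;
    }
};

long long countOrdered(const std::vector<int>& nums, int k) {
    int total = std::accumulate(nums.begin(), nums.end(), 0);
    if (k <= 0 || total % k != 0) return 0;
    OrderedCounter counter(nums, total / k);
    return counter.count(0);
}
```

For {1, 2, 3, 4} with k = 2, the only partition is {1, 4} | {2, 3}. It is counted 8 times: 2 group orders times 2 × 2 element orders inside the groups.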
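The mask-range comments can be checked with std::bitset. The catch is that bitset<n + 1> needs a compile-time width, so a runtime n does not fit directly. This sketch uses the hypothetical helpers toBits and maskCount. It prints a 32-bit bitset and keeps the last n characters, assuming 0 < n <= 32.

```cpp
#include <bitset>
#include <cassert>
#include <string>

// Hypothetical debug helper: the lowest n bits of mask, most significant bit first.
// std::bitset needs a compile-time width, so print 32 bits and keep the last n characters.
std::string toBits(unsigned mask, int n) {
    std::string all = std::bitset<32>(mask).to_string();
    return all.substr(32 - n);
}

// Number of distinct masks (subsets) for n elements: the range [0, (1 << n) - 1].
int maskCount(int n) {
    return 1 << n;
}
```

With a width of 5 as in the commented debug code, the full mask for n = 4, (1 << 4) - 1 = 15, prints as 01111, and 1 << 4 = 16 prints as 10000.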
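The boxes analogy in the notes can be checked by counting both ways. This sketch uses the hypothetical names place and countPlacements and places distinct items in boxes, one item per box. When every empty box is an option, 12000 and 21000 are both counted, so 2 items in 5 boxes give 5 × 4 = 20 placements. When the next item may only go after the last used box, each pair of boxes is counted once, giving C(5, 2) = 10.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of the boxes analogy: place `items` distinct items, one per box.
// onlyAfterLast = false: every empty box is an option (ordered placements, i.e. permutations).
// onlyAfterLast = true: the next item must go in a box after the last one used (combinations).
long long place(int item, int items, int lastBox, bool onlyAfterLast, std::vector<bool>& used) {
    if (item == items) return 1;                  // all items placed
    long long ways = 0;
    int first = onlyAfterLast ? lastBox + 1 : 0;  // where the loop over boxes starts
    for (int b = first; b < (int)used.size(); b++) {
        if (used[b]) continue;                    // box already occupied
        used[b] = true;
        ways += place(item + 1, items, b, onlyAfterLast, used);
        used[b] = false;                          // backtrack
    }
    return ways;
}

long long countPlacements(int items, int boxes, bool onlyAfterLast) {
    std::vector<bool> used(boxes, false);
    return place(0, items, -1, onlyAfterLast, used);
}
```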
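The combination idea at the end of the notes can be sketched for the partition problem under a few stated assumptions. Simply starting every loop at lastElementIncludedIndex + 1 across all groups would skip valid partitions, so in this sketch the index restriction applies only inside the current group. Each new group is also forced to begin with the lowest-indexed unused element, so the same groups in a different order are not recounted. Because the allowed indices depend on the last pick, the memo key becomes (mask, last) rather than mask alone. UnorderedCounter and countUnordered are hypothetical names, positive elements are assumed, and equal values at different indices still count as distinct, as in the listing.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical sketch: counts UNORDERED partitions of nums into k groups of equal sum.
// Elements at different indices are distinct, even if their values are equal.
// Assumes all nums are positive and n is small (memo has 2^n * (n + 1) entries).
struct UnorderedCounter {
    const std::vector<int>& nums;
    int n, targetSum;
    std::vector<int> maskSum;                  // maskSum[mask] = SUM(mask)
    std::vector<std::vector<long long>> memo;  // memo[mask][last + 1], -1 = not computed yet

    UnorderedCounter(const std::vector<int>& v, int target)
        : nums(v), n((int)v.size()), targetSum(target), maskSum(1 << v.size(), 0),
          memo(1 << v.size(), std::vector<long long>(v.size() + 1, -1)) {
        for (int mask = 1; mask < (1 << n); mask++)  // SUM(mask) = SUM(mask without lowest bit) + that element
            maskSum[mask] = maskSum[mask & (mask - 1)] + nums[__builtin_ctz(mask)];
    }

    // last = index of the last element placed in the current group, -1 if the group is still empty
    long long count(int mask, int last) {
        if (mask == (1 << n) - 1) return 1;  // every element is in a completed group
        long long& res = memo[mask][last + 1];
        if (res != -1) return res;
        res = 0;
        int currSum = maskSum[mask] % targetSum;
        // New group: it must start with the lowest-indexed unused element (fixes the order of groups).
        // Same group: only indices after `last` may be added (fixes the order inside a group).
        int start = (currSum == 0) ? __builtin_ctz(~mask) : last + 1;
        int end = (currSum == 0) ? start + 1 : n;
        for (int i = start; i < end; i++) {
            if ((mask & (1 << i)) || currSum + nums[i] > targetSum) continue;  // used, or overshoots
            bool closesGroup = (currSum + nums[i] == targetSum);
            res += count(mask | (1 << i), closesGroup ? -1 : i);
        }
        return res;
    }
};

long long countUnordered(const std::vector<int>& nums, int k) {
    int total = std::accumulate(nums.begin(), nums.end(), 0);
    if (k <= 0 || total % k != 0) return 0;
    UnorderedCounter counter(nums, total / k);
    return counter.count(0, -1);
}
```

For {2, 2, 2, 2} with k = 2, this gives 3: the index splits {0,1}|{2,3}, {0,2}|{1,3} and {0,3}|{1,2}. The permutation count gives 4! = 24 for the same input.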