- #include <bits/stdc++.h>
- using namespace std;
- /*
- VERY IMPORTANT LINK -
- https://chatgpt.com/canvas/shared/68631cabdd608191b64d447a63a1ea45
- (DP + BITMASK INTUITION discussion PDF)
- */
- // Fill the groups (LEVELS) one at a time: LEVELS = groups, OPTIONS = the elements still unused in mask/nums.
- // Just like filling a sudoku, where LEVEL = cell and OPTIONS = the numbers 1-9: recurse only with VALID options and move on to the next cell/level.
- // When all cells are filled, the sudoku is complete; otherwise backtrack and fix the earlier levels with other valid options.
- // Fill the groups one by one so that each sums to targetSum. ONLY when a group reaches targetSum do we MOVE to the next group/level and fill it the same way. When all k groups are done, the partitioning is valid, so count it.
- // If, after exploring every option, a group cannot be filled to exactly targetSum, backtrack and refill the earlier groups with other combinations of elements.
- // Suppose k = 4 and we are filling the 3rd group. We don't need to check whether the 2 earlier groups sum to targetSum: we only reached the 3rd group because both were filled SUCCESSFULLY, otherwise we would have backtracked out of them.
- // So when the base case is reached, we DON'T NEED TO VALIDATE THAT EVERY GROUP EQUALS TARGETSUM: we only get there after filling each group to exactly targetSum, with every element of nums used (all bits of mask set).
- long long generateKGroupsWithTargetSum(int mask, int currSum,
- int groupsFilled, int & targetSum,
- int & k, vector < int > & nums,
- vector < long long > & memo) {
- // base case, ALL elements used (number of set bits in mask = nums.size())
- if (__builtin_popcount(mask) == nums.size()) {
- // all k groups are formed with EQUAL SUM (we could only form k groups because each was filled to exactly targetSum) and currSum = 0, i.e. every element is used and there is no extra, partially filled (k + 1)th group (its sum would be currSum, which is 0)
- //“If groupsFilled == k and currSum == 0, it implies SUM(mask) == totalSum, and the sum has been fully distributed in k × targetSum blocks.”
- if (currSum == 0 && groupsFilled == k) {
- // A valid grouping
- return 1;
- }
- // Invalid case: either fewer than k groups were formed, or elements remained after the last complete group (currSum != 0)
- return 0;
- }
- // memo check, if mask is already calculated, return memo[mask];
- if (memo[mask] != -1) {
- return memo[mask];
- }
- // LEVELS = groups
- // OPTIONS = unused elements in the mask (ITERATE OVER MASK to find them, just as we iterate over 1-9 in sudoku)
- long long count = 0;
- //“Because all elements are tried in each recursive frame, order-based repetitions are inherently counted — this makes it count permutations, not combinations.”
- // here the loop starts from 0, so the same mask is reached along different paths:
- // e.g. with i < j, pick i at the current level and j at some LATER level,
- // or pick j at the current level and i at that LATER level.
- // Both paths form the same mask; only the order of picking differs (a permutation), and both are counted.
- for (int i = 0; i < nums.size(); i++) {
- // ith element of the array
- int elementBitIdx = (1 << i);
- int element = nums[i];
- // check if ith element in the array is used or not, only use the element for filling the current group that has not been used before
- // checks if the ith bit is set or not
- // if ith bit is ON, '&' would return some non zero value
- // if ith bit is OFF, '&' would return 0
- if ((mask & elementBitIdx) == 0) {
- // Because we generate ONLY VALID GROUPS, the base case never has to check that every group is valid, which removes the need to explicitly maintain an array of group sums (as an earlier version did).
- // But to generate only VALID GROUPS, 2 CONDITIONS NEED TO BE MET
- // USE ONLY AN ELEMENT THAT HAS NOT BEEN USED BEFORE (its bit is off in mask) AND WHOSE INCLUSION DOES NOT PUSH THE CURRENT GROUP PAST TARGETSUM
- if (currSum + nums[i] <= targetSum) {
- // we can use the element, as it meets both the criteria,
- // mark it SET in the mask, before moving forward in recursion
- int newMask = mask | elementBitIdx; // turn on its bit in the new mask, marking the element as used
- // if including the current element completes the current group (currSum + element == targetSum), move to the next group (groupsFilled + 1) and reset currSum to 0, since currSum is the sum of the group being filled and the new group starts empty
- bool isCurrGroupFilled = false;
- if (currSum + element == targetSum) {
- isCurrGroupFilled = true;
- }
- // count all the valid groupings possible at each level
- count += generateKGroupsWithTargetSum(newMask,
- isCurrGroupFilled ? 0 : currSum + element,
- isCurrGroupFilled ? groupsFilled + 1 : groupsFilled,
- targetSum, k, nums, memo);
- }
- }
- }
- return memo[mask] = count;
- }
- int main() {
- // your code goes here
- int n, k;
- cin >> n;
- vector<int> nums(n);
- for(int i = 0; i < n; i++){
- cin >> nums[i];
- }
- cin >> k;
- // calculate sum
- int totalSum = 0;
- for(int i : nums){
- totalSum += i;
- }
- // impossible to partition into k groups: the sum of all elements is not divisible by k, so the count is 0
- if(totalSum % k != 0){
- cout << 0 << '\n';
- return 0;
- }
- // Only if SUM(ALL elements) is divisible by k can we try to partition all elements into k groups of equal sum
- // SUM(elements) = 32, k = 4, so targetSum = SUM(elements) / k
- // targetSum = 8 ... each group will have a sum = targetSum = 8
- // So it may be possible to partition the n elements into k groups, each with targetSum = totalSum / k
- int targetSum = totalSum / k;
- // At first glance the memo key seems to be memo(mask, currSum, k), since all three change and influence the base case or the conditionals, but it can be reduced to memo(mask)
- // WHY? Because the other 2 parameters can be DERIVED from mask itself.
- // mask = a VALID SUBSET CONFIGURATION, currSum = how full the current group is (the sum of the group being filled), k = groupsFilled, the number of VALID GROUPS FILLED so far (not the input k)
- //k can be derived from mask as => k = SUM(mask) / targetSum;
- //currSum can be derived from mask => currSum = SUM(mask) % targetSum
- //SUM(mask) = The CUMULATIVE sum of all the elements that have already been used to form the groups, i.e. all the elements whose corresponding bit in mask is SET/ON.
- //Consider k = 4, sum of all elements in the nums array = 32
- //so targetSum = 32/4 = each group should have targetSum = 8
- //Suppose at some level/state, a CONFIGURATION represented by mask, elements used sum up to 27, i.e. SUM(mask) = 27
- // k = SUM(mask) / targetSum gives the number of VALID groups (group sum = targetSum) filled so far, i.e. how many complete targetSum blocks SUM(mask) contains, which is exactly what k (groupsFilled) counts
- //SUM(mask)/targetSum => 27 / 8 = 3, it means we have filled 3 groups SUCCESSFULLY(each with sum = 8)
- // portion of SUM(mask) used in filling 3 groups of targetSum = 8 => 3 * 8 = 24
- //currSum = SUM(mask) % targetSum gives the remaining portion of SUM(mask) that is forming a new group and is currently filled upto currSum
- // SUM(mask) % targetSum = 27 % 8 = 3: the part of SUM(mask) not used by the completed groups, which is the sum of the group currently being filled, i.e. currSum
- //So we can see currSum and k both are DERIVABLE from mask
- //IMPORTANT ************* “Since both currSum and groupsFilled are invariant functions of the mask (given fixed input array and targetSum), memoizing solely on mask captures the entire subproblem state.”
- //As targetSum is a constant, the other 2 parameters are derivable from mask OR a property of mask
- // Essentially, NO MATTER WHAT PATH WE TAKE TO REACH A STATE, IF TWO STATES HAVE THE SAME MASK, their currSum & k are the same, as both are derived from mask.
- // So there is no need to include currSum & k in the memo key.
- //so memo(mask, currSum, k) => memo(mask, SUM(mask) % targetSum, SUM(mask) / targetSum) => memo(mask)
- //memo(mask, currSum, k) S.C. would be O(2^N * targetSum * k) and T.C. = O(2^N * N), looping till N in each state
- // NOW REDUCED TO
- // memo(mask): S.C. is O(2^N) and T.C. stays O(2^N * N), so the space is greatly reduced.
- //At each of the 2^N masks, we loop over up to N elements.
- //Regardless of whether currSum and k are explicit or derived, the same number of recursive calls are made.
- //The difference is how many unique subproblems we store, not how many we compute.
- // BUT memo(mask, currSum, k) would need a 3D table or a serialized key such as mask|currSum|k, which costs extra space and indexing/hashing work without producing any additional memo hits.
- // On the other hand, memo(mask) is keyed by a single integer, which keeps the table small.
- // Even though currSum and k are passed in recursion,
- // for any given mask, their values are uniquely determined as:
- // currSum = sum(mask) % targetSum
- // k = sum(mask) / targetSum
- //
- // So memo(mask, currSum, k) reserves room for (currSum, k) pairs that can never occur together with a given mask:
- // each reachable mask still appears with exactly one (currSum, k), so the extra cells are never used.
- // Using memo(mask) alone collapses the key to mask,
- // saving that space while getting exactly the same reuse.
- // memo[mask][currSum][k] (or a flattened memo[serialize(mask, currSum, k)])
- // OVERREPRESENTS the number of unique subproblems;
- // it only wastes space, because it cannot get more memo hits than memo[mask].
- /*
- AS
- groupsFilled = sum(mask) / targetSum;
- currSum = sum(mask) % targetSum;
- While using memo(mask, currSum, k),
- the table keeps separate slots for states like:
- (mask = 101011, currSum = 3, k = 2)
- (mask = 101011, currSum = 3, k = 3)
- as if both could occur...
- ...but only one of them can ever be reached, because:
- ❗ currSum and k cannot differ for the same mask.
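The following brute-force check is a sketch of this claim. The hypothetical helpers maskSum and checkInvariant replay the same recursion without a memo. They confirm that every reached state satisfies currSum == SUM(mask) % targetSum and groupsFilled == SUM(mask) / targetSum. The sketch assumes positive elements and targetSum > 0.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: SUM(mask), the sum of nums[i] over every set bit i.
int maskSum(int mask, const std::vector<int>& nums) {
    int s = 0;
    for (int i = 0; i < (int)nums.size(); i++)
        if (mask & (1 << i)) s += nums[i];
    return s;
}

// Hypothetical checker: explores every path of the recursion (no memo) and
// returns false if any state disagrees with the values derived from mask.
bool checkInvariant(int mask, int currSum, int groupsFilled, int targetSum,
                    const std::vector<int>& nums) {
    int s = maskSum(mask, nums);
    if (currSum != s % targetSum || groupsFilled != s / targetSum) return false;
    for (int i = 0; i < (int)nums.size(); i++) {
        if (mask & (1 << i)) continue;                // element already used
        if (currSum + nums[i] > targetSum) continue;  // would overflow the group
        bool filled = (currSum + nums[i] == targetSum);
        if (!checkInvariant(mask | (1 << i),
                            filled ? 0 : currSum + nums[i],
                            filled ? groupsFilled + 1 : groupsFilled,
                            targetSum, nums))
            return false;
    }
    return true;
}
```

Take nums = {8, 8, 8, 3, 5} (sum 32, k = 4, targetSum = 8). Mask 01111 covers 8 + 8 + 8 + 3 = 27, which gives 3 complete groups and currSum = 3. This matches the worked example earlier in these notes.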
- */
- // IMPORTANT *************
- // This reduces the memo table size from O(2^N × targetSum × k) to O(2^N)
- // Because (currSum, k) are derivable from mask, we only memoize once per mask
- // This drastically reduces space
- // Each mask only ever occurs with one (currSum, k), so the larger table never held separate answers for the same mask,
- // and memo[mask] gets exactly the same hits while storing one entry per mask
- // What is the size of the memo table? It is the range of mask. What does mask represent? A subset/CONFIGURATION in the INDIVIDUAL ELEMENT CHOICE DIAGRAM (the recursion tree of include/exclude choices for each element)
- // IMPORTANT ************* RANGE = mask ∈ [0, 2^n)
- //Min value of mask = 0, no elements selected
- // Max value of mask = FULL MASK (1 << n) - 1: all elements from index 0 to index (n - 1) in nums are selected, i.e. bits 0 through (n - 1) are all set, which is the largest possible mask
- // e.g. for n = 4, (1 << 4) - 1 = 15, so there are 16 mask values, 0 through 15.
- // So the range is 0 to FULL MASK => 0 to (1 << n) - 1 => all 2^N choices
- // so for n = 4
- // Number of choices = 2^4 = 16 different CONFIGURATIONS/subsets (include or exclude each element), so the min and max values of mask are
- // Min value = 0 [0 0 0 0 0]
- // Max value = 15 [0 1 1 1 1]
- //Total 16 different values
- // int shift = 4;
- // //1 left shifted by four positions, Initially [0 0 0 0 1] -> [1 0 0 0 0]
- // // IMPORTANT: use a bitset of width n + 1 TO DEBUG; a bitset prints the INTEGER IN BIT FORMAT (the width must be a compile-time constant, e.g. bitset<5> for n = 4)
- // int mask = (1 << shift);
- // bitset<5> b1(mask);
- // bitset<5> b2(mask - 1);
- // cout << b1 << '\n' << b2 << '\n';
- // Memo table of size (1 << n), the number of subsets of N elements, so it covers every mask value from 0 to (1 << n) - 1
- vector<long long> memo((1 << n), -1);
- cout << generateKGroupsWithTargetSum(0, 0, 0, targetSum, k, nums, memo) << '\n';
- /*
- OPTIMIZATIONS AND OTHER IMPORTANT NOTES
- //SORTING NUMS IN DESCENDING ORDER CAN BE DONE TO ACHIEVE BETTER PRUNING, as the larger elements are then used to fill the first groups, which helps those groups reach targetSum sooner.
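In this counting version, sorting still leaves every reachable mask to be visited, because reordering nums only relabels the masks. The most concrete payoff of the descending order is therefore a cheap up-front exit. This is a minimal sketch. The name triviallyImpossible is hypothetical, and treating empty input or k <= 0 as impossible is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical pre-check run before the recursion. nums is taken by value
// so the caller's order (and therefore its mask layout) is left untouched.
bool triviallyImpossible(std::vector<int> nums, int k) {
    if (nums.empty() || k <= 0) return true;  // degenerate input (assumption)
    int total = 0;
    for (int x : nums) total += x;
    if (total % k != 0) return true;          // cannot split into k equal sums
    int target = total / k;
    // Descending order puts the largest element first; if it alone exceeds
    // targetSum, no group can ever contain it.
    std::sort(nums.begin(), nums.end(), std::greater<int>());
    return nums[0] > target;
}
```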
- // The code counts the total number of permutations.
- // Essentially every ordering is counted: the order of the groups, the order of elements within each group, and groups with equal values but built from different indices are counted separately.
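The sketch below restates the same memoized recursion in compact form, so the permutation count can be seen concretely. The names countOrderedFrom and countOrdered are hypothetical. groupsFilled is dropped because, with positive elements, a full mask with currSum == 0 already implies exactly totalSum / targetSum = k groups. For nums = {1, 2, 2, 1} and k = 2 there are only 2 partitions by index: {0, 1} | {2, 3} and {0, 2} | {1, 3}. Yet the count is 2 × 2! (group orders) × 2! × 2! (orders inside each group) = 16.

```cpp
#include <cassert>
#include <vector>

// Memoized count of ORDERED groupings, keyed on mask alone. currSum is passed
// along but is always SUM(mask) % targetSum, so memo[mask] is a safe key.
long long countOrderedFrom(int mask, int currSum, int targetSum,
                           const std::vector<int>& nums,
                           std::vector<long long>& memo) {
    int n = (int)nums.size();
    if (mask == (1 << n) - 1) return currSum == 0 ? 1 : 0;  // all elements used
    if (memo[mask] != -1) return memo[mask];
    long long count = 0;
    for (int i = 0; i < n; i++) {
        if (mask & (1 << i)) continue;   // already used
        int s = currSum + nums[i];
        if (s > targetSum) continue;     // would overflow the group
        // a completed group resets currSum to 0 for the next group
        count += countOrderedFrom(mask | (1 << i), s == targetSum ? 0 : s,
                                  targetSum, nums, memo);
    }
    return memo[mask] = count;
}

// Hypothetical wrapper: computes targetSum and allocates the memo table.
long long countOrdered(const std::vector<int>& nums, int k) {
    int total = 0;
    for (int x : nums) total += x;
    if (k <= 0 || total % k != 0) return 0;
    std::vector<long long> memo(1 << nums.size(), -1);
    return countOrderedFrom(0, 0, total / k, nums, memo);
}
```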
- // Combinations can be generated simply by tracking lastElementIncludedIndex and starting the next scan right after it, which prevents permutations. It is the same trick used to avoid permutations when placing r items into n boxes, where n > r
- //Suppose we have to place items 1 and 2 into 5 boxes,
- //considering the levels to be the items
- //and the options to be the empty boxes where they can be placed. So at the level of item 1, there are 5 empty boxes to place it in:
- //10000, 01000, 00100, 00010, 00001
- //Then at the level of item 2, each path has 4 empty boxes left, so following path 10000:
- // 12000(*), 10200, 10020, 10002
- // But after backtracking to the path where 1 was placed as 01000, the options at the level of item 2 are
- // 21000(*), 01200, 01020, 01002
- //BOTH STAR-marked placements are permutations of each other.
- // IF we wanted only combinations, where order doesn't matter because the items are identical/not distinct,
- // both would be the same combination, ii000 and ii000, replacing 1 & 2 with i since they are considered identical
- // so to prevent that, see what creates the permutation:
- // allowing 2 to be placed in a box before the one where 1 was already placed.
- //So to prevent permutations, keep track of where the last item was placed (or, in the k partitions with equal sum, which element was taken last), and at the next level start from the index right after it
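The box example can be checked directly. This is a sketch with hypothetical function names. placeOrdered lets every level pick any empty box, which counts permutations. placeUnordered starts each level just after the box used by the previous one, which counts combinations.

```cpp
#include <cassert>

// Ordered: each level (item) may use any still-empty box -> permutations.
long long placeOrdered(int boxes, int itemsLeft, int usedMask) {
    if (itemsLeft == 0) return 1;
    long long count = 0;
    for (int b = 0; b < boxes; b++)
        if (!(usedMask & (1 << b)))
            count += placeOrdered(boxes, itemsLeft - 1, usedMask | (1 << b));
    return count;
}

// Unordered: each level starts after the box used by the previous level,
// so {box 0, box 1} is produced once instead of twice -> combinations.
long long placeUnordered(int boxes, int itemsLeft, int start) {
    if (itemsLeft == 0) return 1;
    long long count = 0;
    for (int b = start; b < boxes; b++)
        count += placeUnordered(boxes, itemsLeft - 1, b + 1);
    return count;
}
```

For 2 items and 5 boxes this gives 5 × 4 = 20 placements versus C(5, 2) = 10.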
- // With the mask, permutations arise like this: suppose the current mask is 0100 (the element at index 2 is used).
- // Looping over the OFF bits AT THE CURRENT LEVEL, we first take the element at index 0, updating the mask to 0101, and recurse to some NEWLEVEL, where we then take the element at index 3, updating the mask to 1101.
- // When we backtrack to the current level, the mask is restored to 0100, and we take the element at index 3 instead of index 0, updating the mask to 1100 and recursing further. Now at that same NEWLEVEL, taking the element at index 0 updates the mask from 1100 to 1101, the same mask as before.
- // We already have the answer for it, since this mask was visited before, so we just return the number of groupings down this branch. That is exactly counting permutations: the same mask is reached with a different picking order, and the caller counts it again.
- // To avoid this, start the loop from lastElementIncludedIndex + 1. When index 0 is picked at the CURRENT LEVEL and we recurse, lastElementIncludedIndex = 0, so at NEWLEVEL index 3 can still be picked, enumerating the combination. But when index 3 is picked at the CURRENT LEVEL, lastElementIncludedIndex = 3, and at NEWLEVEL we can only pick elements AFTER index 3, so index 0 is unavailable, which removes the permutation.
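Here is a minimal sketch of the combination version. It rests on two assumptions. First, the lastElementIncludedIndex + 1 rule applies only inside a group. Second, every new group starts with the lowest unused index.

The reset at group boundaries matters. If the increasing-index rule ran across groups, nums = {1, 2, 2, 1} with k = 2 would miss the partition {0, 2} | {1, 3}, whose groups interleave by index.

Because the start index now affects the answer, memo[mask] alone is no longer a valid key. The sketch (hypothetical names countCombinations and countPartitions) therefore skips memoization. It also assumes positive elements.

```cpp
#include <cassert>
#include <vector>

// Hypothetical counter of UNORDERED partitions into groups summing to targetSum.
// start = lastElementIncludedIndex + 1 inside the current group.
long long countCombinations(int mask, int currSum, int start, int targetSum,
                            const std::vector<int>& nums) {
    int n = (int)nums.size();
    if (mask == (1 << n) - 1) return currSum == 0 ? 1 : 0;
    if (currSum == 0) {
        // Starting a new group: anchor it at the lowest unused index, so the
        // same set of groups cannot be produced in a different group order.
        int first = 0;
        while (mask & (1 << first)) first++;
        int s = nums[first];
        if (s > targetSum) return 0;
        int m = mask | (1 << first);
        return (s == targetSum) ? countCombinations(m, 0, 0, targetSum, nums)
                                : countCombinations(m, s, first + 1, targetSum, nums);
    }
    long long count = 0;
    // Inside a group, only indices after the last one taken are options.
    for (int i = start; i < n; i++) {
        if (mask & (1 << i)) continue;
        int s = currSum + nums[i];
        if (s > targetSum) continue;
        int m = mask | (1 << i);
        count += (s == targetSum) ? countCombinations(m, 0, 0, targetSum, nums)
                                  : countCombinations(m, s, i + 1, targetSum, nums);
    }
    return count;
}

// Hypothetical wrapper: computes targetSum = totalSum / k.
long long countPartitions(const std::vector<int>& nums, int k) {
    int total = 0;
    for (int x : nums) total += x;
    if (k <= 0 || total % k != 0) return 0;
    return countCombinations(0, 0, 0, total / k, nums);
}
```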
- */
- }
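The helpers below are a small sketch for debugging the mask range and the bit test used in the recursion. The names fullMask, isUsed and toBits, and the fixed display width of 8 bits, are assumptions for illustration. std::bitset needs its width as a compile-time constant, which is why the commented-out snippet uses bitset<5> rather than bitset<n + 1>.

```cpp
#include <bitset>
#include <cassert>
#include <string>

// Display width for debugging; must be a compile-time constant for std::bitset.
constexpr int kWidth = 8;

// Hypothetical helper: bits 0 .. n-1 all set, the "every element used" mask.
int fullMask(int n) { return (1 << n) - 1; }

// Hypothetical helper: the same test the recursion uses, true iff bit i is set.
bool isUsed(int mask, int i) { return (mask & (1 << i)) != 0; }

// Hypothetical helper: fixed-width binary rendering of a mask.
std::string toBits(int mask) { return std::bitset<kWidth>(mask).to_string(); }
```

Here toBits(fullMask(4)) yields 00001111, with the 4 low bits set, and toBits(1 << 4) yields 00010000. The value 1 << 4 = 16 is also the memo size for n = 4.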