Post

Bit Mask

Bit Mask

Tricks

  • (x >> i) & 1: get the i-th bit in state x
  • (x & y) == x: check if x is a subset of y
  • (x & (x >> 1)) == 0: check that no two adjacent bits are both set in x

Groups of Strings

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
// letter-set bit mask -> number of words that have exactly this letter set
private Map<Integer, Integer> map = new HashMap<>();
// masks that have already been assigned to a group
private Set<Integer> visited = new HashSet<>();

/**
 * Groups words whose letter sets are connected by adding, deleting or replacing one letter.
 *
 * @param words words to group; each character is expected to be in 'a'..'z'
 * @return {@code {number of groups, size of the largest group}}
 */
public int[] groupStrings(String[] words) {
    // resets the per-call state so the same instance can be reused safely
    map.clear();
    visited.clear();

    for (String w : words) {
        int mask = 0;
        for (char ch : w.toCharArray()) {
            mask |= 1 << (ch - 'a');
        }
        map.merge(mask, 1, Integer::sum);
    }

    int groups = 0, maxSize = 0;
    for (int k : map.keySet()) {
        // size == 0 means k was already absorbed by an earlier group
        int size = dfs(k);
        if (size > 0) {
            groups++;
            maxSize = Math.max(maxSize, size);
        }
    }
    return new int[]{groups, maxSize};
}

/**
 * Counts the words in the connected component of {@code mask}, marking every mask it reaches as visited.
 * The traversal uses an explicit stack, so large components cannot overflow the call stack.
 *
 * @return the number of words in the component, or 0 if mask is unknown or already visited
 */
private int dfs(int mask) {
    if (!map.containsKey(mask) || !visited.add(mask)) {
        return 0;
    }

    // every mask is pushed at most once, so map.size() slots are enough
    int[] stack = new int[map.size()];
    int top = 0, count = 0;
    stack[top++] = mask;

    while (top > 0) {
        int curr = stack[--top];
        count += map.get(curr);

        for (int i = 0; i < 26; i++) {
            // add/delete: flips one bit
            int flipped = curr ^ (1 << i);
            if (map.containsKey(flipped) && visited.add(flipped)) {
                stack[top++] = flipped;
            }

            for (int j = i + 1; j < 26; j++) {
                // replace: flips two bits with different values
                if (((curr >> i) & 1) != ((curr >> j) & 1)) {
                    int replaced = flipped ^ (1 << j);
                    if (map.containsKey(replaced) && visited.add(replaced)) {
                        stack[top++] = replaced;
                    }
                }
            }
        }
    }
    return count;
}

Enumeration

Iterator for Combination

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
/**
 * Enumerates every combination of {@code combinationLength} characters of a string.
 * A combination is encoded as an n-bit mask whose most significant bit maps to the
 * first character; masks are visited in decreasing numeric order.
 */
class CombinationIterator {
    private int bitmask = 0, n = 0, k = 0;
    private String characters;

    public CombinationIterator(String characters, int combinationLength) {
        this.characters = characters;
        n = characters.length();
        k = combinationLength;

        // first mask: k ones followed by (n - k) zeros
        bitmask = (1 << n) - (1 << (n - k));
    }

    /** Returns the combination encoded by the current mask and advances to the next one. */
    public String next() {
        // decodes the mask, e.g. with "abc": 110 --> "ab", 101 --> "ac", 011 --> "bc"
        StringBuilder combination = new StringBuilder();
        for (int idx = 0, bit = 1 << (n - 1); idx < n; idx++, bit >>>= 1) {
            if ((bitmask & bit) != 0) {
                combination.append(characters.charAt(idx));
            }
        }

        // steps down to the next smaller mask with exactly k bits set (or to 0 when exhausted)
        do {
            bitmask--;
        } while (bitmask > 0 && Integer.bitCount(bitmask) != k);

        return combination.toString();
    }

    /** Returns true while a mask with k set bits is still pending. */
    public boolean hasNext() {
        return bitmask > 0;
    }
}

Minimum Number of Lines to Cover Points

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/**
 * Finds the minimum number of straight lines needed to cover all points.
 *
 * Lines are identified by exact integer coefficients (A * x + B * y + C = 0, reduced by their gcd
 * and sign-normalized), so collinear pairs always map to the same key; no floating-point slope
 * or intercept is involved.
 *
 * @param points array of {x, y} pairs
 * @return the minimum number of lines; 1 when there are fewer than 3 points
 */
public int minimumLines(int[][] points) {
    int n = points.length;
    if (n < 3) {
        return 1;
    }

    // {"A#B#C" : indexes of the points on that line}
    Map<String, Set<Integer>> map = new HashMap<>();
    for (int i = 0; i < n; i++) {
        int x1 = points[i][0], y1 = points[i][1];
        for (int j = i + 1; j < n; j++) {
            int x2 = points[j][0], y2 = points[j][1];
            long a = (long) y2 - y1, b = (long) x1 - x2;
            if (a == 0 && b == 0) {
                // identical points do not define a line
                continue;
            }

            long c = -(a * x1 + b * y1);
            long g = lineGcd(lineGcd(Math.abs(a), Math.abs(b)), Math.abs(c));
            // canonical sign: first non-zero of (A, B) is positive
            if (a < 0 || (a == 0 && b < 0)) {
                g = -g;
            }

            String key = (a / g) + "#" + (b / g) + "#" + (c / g);
            Set<Integer> onLine = map.computeIfAbsent(key, unused -> new HashSet<>());
            onLine.add(i);
            onLine.add(j);
        }
    }

    // keeps only lines that cover >= 3 points; any other line covers at most 2 points,
    // which is accounted for by pairing up the leftover points below
    List<Set<Integer>> lines = new ArrayList<>();
    for (Set<Integer> onLine : map.values()) {
        if (onLine.size() > 2) {
            lines.add(onLine);
        }
    }

    // max possible answer is Math.ceil(n / 2d)
    int min = (n + 1) / 2, m = lines.size();
    Set<Integer> covered = new HashSet<>();

    // tries every combination of the >= 3-point lines
    for (int mask = 1; mask < (1 << m); mask++) {
        covered.clear();
        for (int j = 0; j < m; j++) {
            if (((mask >> j) & 1) == 1) {
                covered.addAll(lines.get(j));
            }
        }

        // the uncovered points are paired up, one extra line per pair
        int rest = n - covered.size();
        min = Math.min(min, Integer.bitCount(mask) + (rest + 1) / 2);
    }
    return min;
}

// greatest common divisor of two non-negative values (gcd(x, 0) == x)
private static long lineGcd(long a, long b) {
    while (b != 0) {
        long t = a % b;
        a = b;
        b = t;
    }
    return a;
}

Partition Array Into Two Arrays to Minimize Sum Difference

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/**
 * Splits nums into two halves of equal size so that the absolute difference of their sums is minimal.
 * Meet-in-the-middle: subset sums of the left and right half are enumerated separately and matched
 * by subset size.
 *
 * @param nums array of even length
 * @return the minimum possible absolute difference
 */
public int minimumDifference(int[] nums) {
    int half = nums.length / 2;
    List<Integer>[] left = subsetSums(nums, 0);
    List<Integer>[] right = subsetSums(nums, half);

    int total = 0;
    for (int value : nums) {
        total += value;
    }

    // With p1 the sum of partition 1: d = |p1 - (total - p1)| = |total - 2 * p1|,
    // so p1 should be as close to total / 2 as possible.
    int best = Integer.MAX_VALUE;
    for (int picked = 0; picked <= half; picked++) {
        // `picked` elements come from the left half, the remaining ones from the right half
        List<Integer> candidates = right[half - picked];
        for (int l : left[picked]) {
            // p1 = l + r  =>  the ideal r is total / 2 - l
            int r = binarySearch(candidates, total / 2 - l);
            best = Math.min(best, Math.abs(total - 2 * (l + r)));
        }
    }
    return best;
}

/**
 * Enumerates the subset sums of the half of nums that starts at index {@code start}.
 *
 * @return sums[c]: ascending list of the sums of all subsets with exactly c elements
 */
private List<Integer>[] subsetSums(int[] nums, int start) {
    int n = nums.length / 2;

    @SuppressWarnings("unchecked")
    List<Integer>[] sums = new List[n + 1];
    for (int c = 0; c <= n; c++) {
        sums[c] = new ArrayList<>();
    }

    // total[s]: sum of the subset s, derived from the subset without its lowest set bit
    int[] total = new int[1 << n];
    sums[0].add(0);
    for (int s = 1; s < (1 << n); s++) {
        total[s] = total[s & (s - 1)] + nums[start + Integer.numberOfTrailingZeros(s)];
        sums[Integer.bitCount(s)].add(total[s]);
    }

    for (List<Integer> bucket : sums) {
        Collections.sort(bucket);
    }
    return sums;
}

/** Returns the element of the ascending list that is closest to target (the smaller one on ties). */
private int binarySearch(List<Integer> list, int target) {
    int lo = 0, hi = list.size() - 1;
    while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        int curr = list.get(mid), next = list.get(mid + 1);
        if (next - target < target - curr) {
            // the right neighbour is strictly closer, so the answer lies to the right
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    return list.get(lo);
}

Number of Ways to Build Sturdy Brick Wall

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
private static final int MOD = (int)1e9 + 7;
// not referenced by the code below
private static final int MAX_BRICK_LENGTH = 10;

/**
 * Counts the ways to build a wall of the given size such that adjacent rows never
 * join bricks at the same inner position.
 *
 * @return the number of sturdy walls; reduced modulo 1e9 + 7 when height > 1
 */
public int buildWall(int height, int width, int[] bricks) {
    // every way to lay out one single row, encoded as a mask of its inner joints
    List<Integer> ways = new ArrayList<>();
    backtrack(width, bricks, ways, 0, 0);

    int m = ways.size();
    if (height == 1) {
        return m;
    }

    // graph.get(i): indexes of the rows that may be stacked next to ways.get(i),
    // i.e. rows that share no joint position with it
    List<List<Integer>> graph = new ArrayList<>(m);
    for (int split : ways) {
        List<Integer> compatible = new ArrayList<>();
        for (int i = 0; i < m; i++) {
            if ((split & ways.get(i)) == 0) {
                compatible.add(i);
            }
        }
        graph.add(compatible);
    }

    // dp[i]: number of walls built so far whose top row is ways.get(i)
    int[] dp = new int[m];
    Arrays.fill(dp, 1);

    // stacks the remaining height - 1 rows
    for (int layer = 1; layer < height; layer++) {
        int[] next = new int[m];
        for (int i = 0; i < m; i++) {
            int walls = dp[i];
            for (int neighbor : graph.get(i)) {
                next[neighbor] = (next[neighbor] + walls) % MOD;
            }
        }
        dp = next;
    }

    int count = 0;
    for (int walls : dp) {
        count = (count + walls) % MOD;
    }
    return count;
}

/**
 * Collects the joint masks of all rows that fill exactly {@code width}.
 * Bit p of a mask is set when two bricks meet at offset p, e.g. for width 6:
 * [1,2,3] joins at offsets 1 and 3, [3,1,2] at 3 and 4, [2,3,1] at 2 and 5.
 */
private void backtrack(int width, int[] bricks, List<Integer> ways, int mask, int currWidth) {
    for (int brick : bricks) {
        int end = currWidth + brick;
        if (end < width) {
            // the row continues: records the joint after this brick
            backtrack(width, bricks, ways, mask | (1 << end), end);
        } else if (end == width) {
            // the row is complete; the right edge of the wall is not a joint
            ways.add(mask);
        }
    }
}

Set Cover Problem

Set cover problem: NP-complete

Smallest Sufficient Team

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
/**
 * Finds a smallest set of people that together own every required skill.
 *
 * @param req_skills the required skills
 * @param people people.get(i) lists the skills of person i
 * @return the indexes of the chosen people
 */
public int[] smallestSufficientTeam(String[] req_skills, List<List<String>> people) {
    int n = req_skills.length, m = 1 << n;

    // skill name -> bit position
    Map<String, Integer> skillBit = new HashMap<>();
    for (int s = 0; s < n; s++) {
        skillBit.put(req_skills[s], s);
    }

    // dp[set]: smallest known team covering the skill set `set`, or null if none found yet
    @SuppressWarnings("unchecked")
    List<Integer>[] dp = new List[m];
    dp[0] = new ArrayList<>();

    int person = 0;
    for (List<String> owned : people) {
        // required skills of the current person
        int skills = 0;
        for (String name : owned) {
            Integer bit = skillBit.get(name);
            if (bit != null) {
                skills |= 1 << bit;
            }
        }

        // tries to extend every team found so far with this person
        for (int state = 0; state < m; state++) {
            List<Integer> team = dp[state];
            int merged = state | skills;
            // nothing to extend, or the person adds no new skill
            if (team == null || merged == state) {
                continue;
            }

            List<Integer> known = dp[merged];
            if (known == null || known.size() > team.size() + 1) {
                List<Integer> grown = new ArrayList<>(team);
                grown.add(person);
                dp[merged] = grown;
            }
        }
        person++;
    }

    List<Integer> best = dp[m - 1];
    int[] result = new int[best.size()];
    for (int t = 0; t < result.length; t++) {
        result[t] = best.get(t);
    }
    return result;
}

Subsets of Mask

Distribute Repeating Integers

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/**
 * Decides whether every customer i can receive quantity[i] copies of one single value from nums.
 * A greedy pass first settles the easy cases; the remaining values and customers are
 * matched by a memoized search over customer subsets.
 * Note: quantity is sorted in place.
 */
public boolean canDistribute(int[] nums, int[] quantity) {
    // occurrences of each distinct value, ascending
    Map<Integer, Long> counter = new HashMap<>();
    for (int value : nums) {
        counter.merge(value, 1L, Long::sum);
    }
    List<Long> freq = new ArrayList<>(counter.values());
    Collections.sort(freq);

    // requested amounts, ascending
    Arrays.sort(quantity);
    List<Integer> q = new ArrayList<>(quantity.length);
    for (int amount : quantity) {
        q.add(amount);
    }

    // what is left undecided after the greedy pass
    List<Long> freqList = new ArrayList<>();
    List<Integer> quantityList = new ArrayList<>();
    int i = 0;
    for (long f : freq) {
        if (i == q.size()) {
            // every quantity has been looked at: keeps the value for the search
            freqList.add(f);
        } else if (f < q.get(i)) {
            // too small for q[i] and everything after it: the value is dropped
        } else if (i == q.size() - 1 || f == q.get(i)) {
            // serves q[i] with this value and discards both
            i++;
        } else if (f < q.get(i) + q.get(i + 1)) {
            // can serve at most one of q[i...]: gives it to the largest quantity that fits
            int index = Collections.binarySearch(q, (int) f);
            q.remove(index < 0 ? ~index - 1 : index);
        } else {
            // to be determined by the search
            freqList.add(f);
            quantityList.add(q.get(i++));
        }
    }

    for (; i < q.size(); i++) {
        quantityList.add(q.get(i));
    }

    // all customers are satisfied
    if (quantityList.isEmpty()) {
        return true;
    }

    // no values available for the remaining customers
    if (freqList.isEmpty()) {
        return false;
    }

    int m = quantityList.size();
    return dfs(0, (1 << m) - 1, freqList, quantityList, new Boolean[freqList.size()][1 << m]);
}

/**
 * Checks whether the customers in {@code mask} can be served with freq[index...].
 *
 * @param mask the customers that have not been assigned a value yet
 * @param memo memo[index][mask]: cached answer, null if not computed
 */
private boolean dfs(int index, int mask, List<Long> freq, List<Integer> quantity, Boolean[][] memo) {
    // all customers are satisfied
    if (mask == 0) {
        return true;
    }

    if (index == freq.size()) {
        return false;
    }

    Boolean cached = memo[index][mask];
    if (cached != null) {
        return cached;
    }

    long available = freq.get(index);
    boolean feasible = false;

    // walks all subsets of mask, from mask itself down to the empty set;
    // the empty set means the value at this index is skipped
    int state = mask;
    while (true) {
        // total amount requested by the customers in this subset
        int sum = 0;
        for (int rest = state; rest != 0; rest &= rest - 1) {
            sum += quantity.get(Integer.numberOfTrailingZeros(rest));
        }

        // serves the subset with this value; mask ^ state are the customers still waiting
        if (sum <= available && dfs(index + 1, mask ^ state, freq, quantity, memo)) {
            feasible = true;
            break;
        }

        if (state == 0) {
            break;
        }
        state = (state - 1) & mask;
    }

    memo[index][mask] = feasible;
    return feasible;
}

Minimum Cost to Connect Two Groups of Points

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
private List<List<Integer>> cost;
// memo[i][mask]: cached result of dfs(i, ., mask); a negative entry means "not computed"
private int[][] memo;

/**
 * Computes the minimum total cost to connect two groups so that every point of
 * either group is connected to at least one point of the other group.
 *
 * @param cost cost.get(i).get(j): cost of connecting point i of group 1 with point j of group 2
 * @return the minimum total cost
 */
public int connectTwoGroups(List<List<Integer>> cost) {
    this.cost = cost;

    int n1 = cost.size(), n2 = cost.get(0).size();
    this.memo = new int[n1 + 1][1 << n2];
    // 0 is a legitimate result (zero-cost edges), so -1 marks the unvisited states
    for (int[] row : memo) {
        Arrays.fill(row, -1);
    }

    // rightMinCost[j]: cheapest connection from any point of group 1 to point j of group 2
    int[] rightMinCost = new int[n2];
    Arrays.fill(rightMinCost, Integer.MAX_VALUE);
    for (int j = 0; j < n2; j++) {
        for (int i = 0; i < n1; i++) {
            rightMinCost[j] = Math.min(rightMinCost[j], cost.get(i).get(j));
        }
    }
    return dfs(0, rightMinCost, 0);
}

/**
 * Minimum cost to connect points i... of group 1, given that the points of group 2
 * in {@code mask} are already connected.
 */
private int dfs(int i, int[] rightMinCost, int mask) {
    if (memo[i][mask] >= 0) {
        return memo[i][mask];
    }

    int min = 0;
    if (i < cost.size()) {
        // connects group1[i] and group2[j]
        min = Integer.MAX_VALUE;
        for (int j = 0; j < cost.get(0).size(); j++) {
            min = Math.min(min, cost.get(i).get(j) + dfs(i + 1, rightMinCost, mask | (1 << j)));
        }
    } else {
        // all points in group 1 are connected;
        // each still unmatched point of group 2 takes its cheapest connection
        for (int j = 0; j < cost.get(0).size(); j++) {
            if ((mask & (1 << j)) == 0) {
                min += rightMinCost[j];
            }
        }
    }
    return memo[i][mask] = min;
}

Minimum Time to Kill All Monsters

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
/**
 * Computes the minimum number of days needed to kill all monsters.
 * Mana grows by `gain` per day, starting with gain == 1; killing a monster needs
 * mana >= its power, resets the mana to 0 and increases gain by 1.
 *
 * @param power power[k] is the mana required to kill monster k
 * @return the minimum number of days; 0 when there are no monsters
 */
public long minimumTime(int[] power) {
    int n = power.length, full = (1 << n) - 1;

    // dp[mask]: min days to kill exactly the monsters in mask
    long[] dp = new long[1 << n];
    Arrays.fill(dp, Long.MAX_VALUE);
    dp[0] = 0;

    // every proper sub-mask is numerically smaller, so dp[mask] is final when it is read
    for (int mask = 0; mask < full; mask++) {
        // one gain point per monster killed so far, plus the initial one
        int gain = Integer.bitCount(mask) + 1;
        for (int k = 0; k < n; k++) {
            int curr = 1 << k;
            if ((mask & curr) == 0) {
                // gain * days >= power[k]  =>  days = ceil(power[k] / gain)
                long days = (power[k] + (long) gain - 1) / gain;
                dp[mask | curr] = Math.min(dp[mask | curr], dp[mask] + days);
            }
        }
    }
    return dp[full];
}

Maximum Students Taking Exam

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
/**
 * Computes the maximum number of students that can be seated so that nobody has a
 * neighbour directly to the left, right, upper left or upper right.
 *
 * @param seats '.' marks a usable seat; any other character is treated as unusable
 * @return the maximum number of students
 */
public int maxStudents(char[][] seats) {
    int rows = seats.length, cols = seats[0].length, states = 1 << cols;

    // dp[r][mask]: max students in rows 0..r when row r is seated as mask; -1 if mask is not allowed
    int[][] dp = new int[rows][states];
    int best = 0;

    for (int r = 0; r < rows; r++) {
        // usable seats of this row, first column in the most significant bit
        int available = 0;
        for (int c = 0; c < cols; c++) {
            available = (available << 1) | (seats[r][c] == '.' ? 1 : 0);
        }

        Arrays.fill(dp[r], -1);
        for (int mask = 0; mask < states; mask++) {
            boolean onUsableSeats = (mask & ~available) == 0;
            boolean noNeighbours = (mask & (mask >> 1)) == 0;
            if (!onUsableSeats || !noNeighbours) {
                continue;
            }

            int seated = Integer.bitCount(mask);
            if (r == 0) {
                dp[r][mask] = seated;
            } else {
                for (int prev = 0; prev < states; prev++) {
                    // the previous row must be seated in an allowed way
                    if (dp[r - 1][prev] == -1) {
                        continue;
                    }
                    // nobody diagonally above any student of this row
                    if ((mask & (prev >> 1)) != 0 || ((mask >> 1) & prev) != 0) {
                        continue;
                    }
                    dp[r][mask] = Math.max(dp[r][mask], dp[r - 1][prev] + seated);
                }
            }
            best = Math.max(best, dp[r][mask]);
        }
    }
    return best;
}

Partitioning

Partition a dataset into groups.

Generally, each mask represents a state where k groups are complete and at most one group is incomplete. If there exists one incomplete group, then the new bit of the next mask will land in this group; otherwise, the new bit of the next mask will start a new group.

Iteration

Minimum Incompatibility

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/**
 * Distributes nums into k groups of equal size without duplicates inside a group and
 * minimizes the sum of (max - min) over all groups.
 * Note: nums is sorted in place.
 *
 * @param nums the numbers to distribute
 * @param k the number of groups; the group size is computed as nums.length / k
 * @return the minimum sum of incompatibilities, or -1 if no valid distribution was found
 */
public int minimumIncompatibility(int[] nums, int k) {
    int n = nums.length, groupSize = n / k;
    // single-element groups have max == min
    if (groupSize == 1) {
        return 0;
    }

    // ensures the elements of a group are picked in ascending order
    Arrays.sort(nums);

    // dp[mask][i]: let b = Integer.bitCount(mask), r = b % groupSize, m = b / groupSize
    //   - if r == 0, all picked numbers form m complete groups and nums[i] is the last number of the newest group.
    //     dp[mask][i] is the min sum of incompatibilities of these groups, where i is the index of the last picked number
    //   - otherwise, there are m complete groups and one incomplete group, and nums[i] is the last number of the incomplete group.
    //     dp[mask][i] is the min sum of incompatibilities of the m complete groups, minus the minimum number of the incomplete group,
    //     where i is the index of the last picked number
    //
    // dp[mask][n]: min of dp[mask][0...(n - 1)]
    int[][] dp = new int[1 << n][n + 1];
    // `max` serves as the "unreachable" marker: the table is filled with it and the final answer is compared against it
    int max = nums[n - 1] * n;
    for (int i = 0; i < dp.length; i++) {
        Arrays.fill(dp[i], max);
    }
    dp[0][n] = 0;

    for (int mask = 1; mask < (1 << n); mask++) {
        // NOTE(review): `min` is declared once per mask, so it keeps accumulating across the values of i below
        // (i.e. across different `prev` states) — looks harmless because only larger indexes can follow, but confirm
        int bitCount = Integer.bitCount(mask), min = max;

        for (int i = 0; i < n; i++) {
            // picks set bits only, so that removing bit i yields the previous state
            if ((mask & (1 << i)) == 0) {
                continue;
            }

            int prev = mask ^ (1 << i);

            if (bitCount % groupSize == 1) {
                // nums[i] is the first and minimum number of a new group
                // prev consists of complete groups only, so any of its elements could have been picked last;
                // the min sum of incompatibilities of prev is therefore simply dp[prev][n]
                // the minimum is subtracted now, the maximum is added when the group is completed
                dp[mask][i] = dp[prev][n] - nums[i];
            } else {
                // the last element of prev (i.e. nums[j]) and nums[i] are in the same group
                // finds the min value of prev among all of its possible last elements
                // nums is sorted, so only indexes j < i with nums[j] != nums[i] (i.e. nums[j] < nums[i]) qualify;
                // the loop stops at the first j whose value equals nums[i]
                for (int j = 0; j < i && nums[j] != nums[i]; j++) {
                    min = Math.min(min, dp[prev][j]);
                }

                // nums[i] is the last (maximum) number of the group: the group is now complete
                if (bitCount % groupSize == 0) {
                    dp[mask][i] = min + nums[i];
                }

                // the group is still incomplete: nothing to add yet
                if (bitCount % groupSize > 1) {
                    dp[mask][i] = min;
                }
            }

            dp[mask][n] = Math.min(dp[mask][n], dp[mask][i]);
        }
    }

    // a full mask that still holds the marker value means no valid distribution was found
    return dp[(1 << n) - 1][n] == max ? -1 : dp[(1 << n) - 1][n];
}

DFS

Each DFS level iterates the mask bits and tentatively adds each unused bit.

Partition to K Equal Sum Subsets

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// visited[mask]: the search from this set of used elements has already failed
private boolean[] visited;

/**
 * Decides whether nums can be split into k non-empty subsets with equal sums.
 * Assumes the numbers are positive. Note: nums is sorted in place.
 *
 * @param nums the numbers to distribute
 * @param k the number of subsets
 * @return true if such a partition exists; false otherwise, including k <= 0
 */
public boolean canPartitionKSubsets(int[] nums, int k) {
    // no partition into zero or fewer subsets (also avoids dividing by zero)
    if (k <= 0) {
        return false;
    }

    int sum = 0, max = 0;
    for (int num : nums) {
        sum += num;
        max = Math.max(max, num);
    }

    int target = sum / k;
    if (sum % k != 0 || max > target) {
        return false;
    }

    // ascending order lets the search stop scanning as soon as one candidate
    // no longer fits into the current subset
    Arrays.sort(nums);

    this.visited = new boolean[1 << nums.length];
    return dfs(nums, 0, 0, target);
}

/**
 * Tries to complete the partition from the given state.
 *
 * @param mask the elements that are already assigned
 * @param subsetSum the sum of the subset currently being filled (always < target)
 */
private boolean dfs(int[] nums, int mask, int subsetSum, int target) {
    int n = nums.length;
    if (mask == (1 << n) - 1) {
        return true;
    }

    if (visited[mask]) {
        return false;
    }

    for (int i = 0; i < n; i++) {
        if ((mask & (1 << i)) != 0) {
            continue;
        }

        // nums is ascending: if nums[i] overflows the subset, so does every later element
        if (nums[i] + subsetSum > target) {
            break;
        }

        // assigns i-th bit; reaching target closes the subset (sum wraps to 0)
        if (dfs(nums, mask | (1 << i), (subsetSum + nums[i]) % target, target)) {
            return true;
        }
    }

    visited[mask] = true;
    return false;
}

Minimum Number of Work Sessions to Finish the Tasks

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
private int sessionTime;
// memo[mask][remaining time in the current session]; allocated per call
private Integer[][] memo;

/**
 * Computes the minimum number of work sessions needed to finish all tasks.
 * Assumes every task fits into one session (tasks[i] <= sessionTime).
 *
 * @param tasks the duration of each task
 * @param sessionTime the maximum duration of one session
 * @return the minimum number of sessions
 */
public int minSessions(int[] tasks, int sessionTime) {
    this.sessionTime = sessionTime;
    // sized from the actual input so that nothing is cached across calls
    this.memo = new Integer[1 << tasks.length][sessionTime + 1];
    return dfs(tasks, 0, 0);
}

/**
 * Minimum number of additional sessions needed for the tasks not in {@code mask},
 * with {@code remainTime} still available in the current session.
 */
private int dfs(int[] tasks, int mask, int remainTime) {
    int n = tasks.length;
    if (mask == (1 << n) - 1) {
        return 0;
    }

    if (memo[mask][remainTime] != null) {
        return memo[mask][remainTime];
    }

    // upper bound: one session per task
    int min = n;
    for (int i = 0; i < n; i++) {
        if ((mask & (1 << i)) == 0) {
            // assigns i-th bit
            int next = mask | (1 << i);
            if (tasks[i] <= remainTime) {
                // the task fits into the current session
                min = Math.min(min, dfs(tasks, next, remainTime - tasks[i]));
            } else {
                // opens a new session for the task
                min = Math.min(min, dfs(tasks, next, sessionTime - tasks[i]) + 1);
            }
        }
    }
    return memo[mask][remainTime] = min;
}

Maximize Score After N Operations

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// memo[k][mask]: best score from operation k onwards when the elements in mask are used; 0 means "not computed"
private int[][] memo;

/**
 * Computes the maximum score after nums.length / 2 operations, where the k-th
 * operation (1-indexed) removes two elements x, y and scores k * gcd(x, y).
 * Assumes positive numbers, so every operation scores at least 1.
 *
 * @param nums array of even length
 * @return the maximum total score
 */
public int maxScore(int[] nums) {
    this.memo = new int[nums.length / 2 + 1][1 << nums.length];
    return dfs(nums, 1, 0);
}

/** Best score obtainable from operation k onwards, given the used elements in mask. */
private int dfs(int[] nums, int k, int mask) {
    int n = nums.length / 2;
    if (k > n) {
        return 0;
    }

    if (memo[k][mask] > 0) {
        return memo[k][mask];
    }

    int max = 0;
    for (int i = 0; i < 2 * n; i++) {
        for (int j = i + 1; j < 2 * n; j++) {
            int pair = (1 << i) | (1 << j);
            // both elements must still be unused
            if ((mask & pair) == 0) {
                max = Math.max(max, k * gcd(nums[i], nums[j]) + dfs(nums, k + 1, mask | pair));
            }
        }
    }
    return memo[k][mask] = max;
}

// greatest common divisor (Euclidean algorithm)
private static int gcd(int a, int b) {
    while (b != 0) {
        int t = a % b;
        a = b;
        b = t;
    }
    return a;
}

The Number of Good Subsets

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
private static final int MOD = (int)1e9 + 7, MAX = 30;
private static final int[] PRIMES = new int[]{2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
// a bad number has at least one prime factor that appears at least twice
private static final int[] BADS = new int[]{4, 8, 12, 16, 20, 24, 28, 9, 18, 27, 25};
private static final Set<Integer> BAD_SET = new HashSet<>();
private static final int[] MASKS = new int[MAX + 1];
static {
    Arrays.stream(BADS).forEach(BAD_SET::add);

    // MASKS[num]: mask representing the prime factors of num
    for (int i = 0; i < PRIMES.length; i++) {
        for (int num = 1; num <= MAX; num++) {
            if (num % PRIMES[i] == 0) {
                MASKS[num] |= 1 << i;
            }
        }
    }
}

// freqs[num]: occurrences of num in the current input
private int[] freqs = new int[MAX + 1];
// memo[mask][num]: number of good subsets if the initial prime set is represented by `mask`,
//   and the candidate numbers are >= `num`
//   although empty set is not a good subset, it's included to simplify the computation
private Long[][] memo = new Long[1 << PRIMES.length][MAX + 1];

/**
 * Counts the subsets of nums whose product is a product of distinct primes.
 *
 * @param nums values in 1..30
 * @return the number of good subsets modulo 1e9 + 7
 */
public int numberOfGoodSubsets(int[] nums) {
    // starts from a clean state so that repeated calls on one instance stay correct
    Arrays.fill(freqs, 0);
    memo = new Long[1 << PRIMES.length][MAX + 1];

    for (int num : nums) {
        freqs[num]++;
    }

    // computes `2 ^ freq[1]` with modulo: every subset of the 1s can be added to a good subset
    // e.g. [1, 1, 2]: (2 ^ 2) * 1
    int ones = 1;
    for (int i = 0; i < freqs[1]; i++) {
        ones *= 2;
        ones %= MOD;
    }

    // -1 because of the empty set
    return (int)(((dfs(0, 2) - 1 + MOD) % MOD * ones) % MOD);
}

/** Number of ways (modulo MOD, empty pick included) to extend the prime set `mask` with numbers >= num. */
private long dfs(int mask, int num) {
    if (num == MAX + 1) {
        return 1;
    }

    if (memo[mask][num] != null) {
        return memo[mask][num];
    }

    // doesn't pick num
    long count = dfs(mask, num + 1);

    // picks num
    // if num doesn't have a common prime factor with the current mask;
    // any of its freqs[num] occurrences can be the one picked
    if (!BAD_SET.contains(num) && (mask & MASKS[num]) == 0) {
        count += dfs(mask | MASKS[num], num + 1) * freqs[num];
    }
    return memo[mask][num] = count % MOD;
}

Minimum XOR Sum of Two Arrays

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// memo[mask]: best result once the elements of nums2 in mask are used; Integer.MAX_VALUE = not computed
private int[] memo;

/**
 * Computes the minimum of sum(nums1[i] ^ nums2[p(i)]) over all rearrangements p of nums2.
 *
 * @return the minimum XOR sum
 */
public int minimumXORSum(int[] nums1, int[] nums2) {
    memo = new int[1 << nums2.length];
    Arrays.fill(memo, Integer.MAX_VALUE);

    return dfs(nums1, nums2, 0, 0);
}

/** Minimum XOR sum for nums1[i...], given that the elements of nums2 in mask are taken. */
private int dfs(int[] nums1, int[] nums2, int i, int mask) {
    if (i == nums1.length) {
        return 0;
    }

    if (memo[mask] != Integer.MAX_VALUE) {
        return memo[mask];
    }

    // pairs nums1[i] with every element of nums2 that hasn't been chosen before
    int best = Integer.MAX_VALUE;
    for (int j = 0; j < nums2.length; j++) {
        int bit = 1 << j;
        if ((mask & bit) != 0) {
            continue;
        }
        best = Math.min(best, (nums1[i] ^ nums2[j]) + dfs(nums1, nums2, i + 1, mask | bit));
    }

    memo[mask] = best;
    return best;
}

Number of Ways to Wear Different Hats to Each Other

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
private static final int MOD = (int)1e9 + 7;
private static final int NUM_HATS = 40;

/**
 * Counts the ways to give every person a different hat out of the hats that person likes.
 *
 * @param hats hats.get(i) lists the hat labels (1..40) preferred by person i
 * @return the number of ways modulo 1e9 + 7
 */
public int numberWays(List<List<Integer>> hats) {
    int n = hats.size();

    // h2p[h]: people who like hat h (labels start at 1, slot 0 stays unused)
    @SuppressWarnings("unchecked")
    List<Integer>[] h2p = new List[NUM_HATS + 1];
    for (int h = 1; h <= NUM_HATS; h++) {
        h2p[h] = new ArrayList<>();
    }

    int person = 0;
    for (List<Integer> preferred : hats) {
        for (int h : preferred) {
            h2p[h].add(person);
        }
        person++;
    }

    return dfs(h2p, n, 1, 0, new Integer[NUM_HATS + 1][1 << n]);
}

/**
 * Counts the ways to hand out hats[hat...] so that everybody ends up with a hat.
 * @param hat current hat label
 * @param mask people that already wear a hat
 * @param memo memo[i][j] number of ways for the people mask j and hats[i...], null if not computed
 */
int dfs(List<Integer>[] h2p, int n, int hat, int mask, Integer[][] memo) {
    // everybody wears a hat
    if (mask == (1 << n) - 1) {
        return 1;
    }

    // ran out of hats
    if (hat > NUM_HATS) {
        return 0;
    }

    Integer cached = memo[hat][mask];
    if (cached != null) {
        return cached;
    }

    // option 1: this hat stays unused
    int count = dfs(h2p, n, hat + 1, mask, memo);

    // option 2: one of the still hatless people who like this hat wears it
    for (int p : h2p[hat]) {
        int bit = 1 << p;
        if ((mask & bit) == 0) {
            count = (count + dfs(h2p, n, hat + 1, mask | bit, memo)) % MOD;
        }
    }

    memo[hat][mask] = count;
    return count;
}

Permutation

The idea is similar to Partitioning problems: each DFS level iterates the mask bits and tentatively adds each unused bit.

Special Permutations

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
private static final int MOD = (int)1e9 + 7;
// memo[mask][lastIndex]; lastIndex == n stands for "nothing placed yet"
private Integer[][] memo;

/**
 * Counts the permutations of nums in which, for every two neighbours,
 * one of them divides the other.
 *
 * @return the number of such permutations modulo 1e9 + 7
 */
public int specialPerm(int[] nums) {
    int n = nums.length;
    memo = new Integer[1 << n][n + 1];
    return dfs(nums, 0, n);
}

/** Number of ways to place the elements not in mask after the element at lastIndex. */
private int dfs(int[] nums, int mask, int lastIndex) {
    int n = nums.length;
    if (mask == (1 << n) - 1) {
        return 1;
    }

    Integer cached = memo[mask][lastIndex];
    if (cached != null) {
        return cached;
    }

    int perms = 0;
    for (int i = 0; i < n; i++) {
        // skips elements that are already placed
        if (((mask >> i) & 1) != 0) {
            continue;
        }

        // the first element is unconstrained; later ones must divide or be divided by their predecessor
        boolean compatible = lastIndex == n
            || nums[i] % nums[lastIndex] == 0
            || nums[lastIndex] % nums[i] == 0;
        if (compatible) {
            perms = (perms + dfs(nums, mask | (1 << i), i)) % MOD;
        }
    }

    memo[mask][lastIndex] = perms;
    return perms;
}

Multi-dimension

Maximize Grid Happiness

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// dp[p][in][ex][inMask][exMask], see getMaxGridHappiness for the meaning of each dimension
private Integer[][][][][] dp;

/**
 * Computes the maximum total happiness of an m x n grid.
 * An introvert starts at 120 and loses 30 per neighbour; an extrovert starts at 40 and gains 20 per neighbour.
 *
 * @param m number of rows
 * @param n number of columns
 * @param introvertsCount available introverts (not all have to be placed)
 * @param extrovertsCount available extroverts (not all have to be placed)
 * @return the maximum grid happiness
 */
public int getMaxGridHappiness(int m, int n, int introvertsCount, int extrovertsCount) {
    // no more people than cells can ever be placed
    int cells = m * n;
    int intro = Math.min(introvertsCount, cells), extro = Math.min(extrovertsCount, cells);

    // iteration order: top -> bottom, left -> right (the order how p increases)
    // mask order: the most recently visited cell is the lsb

    // dimensions:
    //   position in the grid (p = i * n + j)
    //   remaining introverts
    //   remaining extroverts
    //   introverts mask of the last n positions (could be wrapped from the previous row)
    //   extroverts mask of the last n positions (could be wrapped from the previous row)
    this.dp = new Integer[cells][intro + 1][extro + 1][1 << n][1 << n];
    return dfs(m, n, 0, intro, extro, 0, 0);
}

// calculates the sum of happiness change caused by all the neighbors of (i, j)
// d is the change of happiness of the person in the current cell, caused by each of its neighbors
private int calculateHappinessDiff(int m, int n, int i, int j, int inMask, int exMask, int d) {
    // bit 0 is the cell visited just before (left neighbor), bit n - 1 the cell n steps back (up neighbor)
    int diff = 0, upBit = 1 << (n - 1);
    // left neighbor is introvert
    if (j > 0 && (inMask & 1) != 0) {
        diff += d - 30;
    }

    // up neighbor is introvert
    if (i > 0 && (inMask & upBit) != 0) {
        diff += d - 30;
    }

    // left neighbor is extrovert
    if (j > 0 && (exMask & 1) != 0) {
        diff += d + 20;
    }

    // up neighbor is extrovert
    if (i > 0 && (exMask & upBit) != 0) {
        diff += d + 20;
    }
    return diff;
}

/**
 * Maximum happiness obtainable from cell p onwards.
 *
 * @param in remaining introverts
 * @param ex remaining extroverts
 * @param inMask introverts among the last n visited cells
 * @param exMask extroverts among the last n visited cells
 */
private int dfs(int m, int n, int p, int in, int ex, int inMask, int exMask) {
    int i = p / n, j = p % n;
    if (i >= m) {
        return 0;
    }

    if (dp[p][in][ex][inMask][exMask] != null) {
        return dp[p][in][ex][inMask][exMask];
    }

    // shifts the window by one cell and drops the cell that falls out of the last n positions
    int nextInMask = (inMask << 1) & ((1 << n) - 1), nextExMask = (exMask << 1) & ((1 << n) - 1);
    // leaves the current cell empty
    int max = dfs(m, n, p + 1, in, ex, nextInMask, nextExMask);

    // introvert person lives in the current cell
    if (in > 0) {
        int diff = 120 + calculateHappinessDiff(m, n, i, j, inMask, exMask, -30);
        // +1 is the current cell bit
        max = Math.max(max, diff + dfs(m, n, p + 1, in - 1, ex, nextInMask + 1, nextExMask));
    }

    // extrovert person lives in the current cell
    if (ex > 0) {
        int diff = 40 + calculateHappinessDiff(m, n, i, j, inMask, exMask, 20);
        // +1 is the current cell bit
        max = Math.max(max, diff + dfs(m, n, p + 1, in, ex - 1, nextInMask, nextExMask + 1));
    }

    return dp[p][in][ex][inMask][exMask] = max;
}

Radix

Maximum AND Sum of Array

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/**
 * Places every number into one of the slots 1..numSlots (at most two numbers per slot)
 * and maximizes the sum of (number AND slot label).
 *
 * @param nums the numbers to place
 * @param numSlots the number of slots
 * @return the maximum AND sum
 */
public int maximumANDSum(int[] nums, int numSlots) {
    // base-3 mask: digit s holds the remaining capacity (0..2) of slot s + 1
    int states = 1;
    for (int s = 0; s < numSlots; s++) {
        states *= 3;
    }

    // 0 is a legitimate AND sum, so -1 marks the states that are not computed yet
    int[] memo = new int[states];
    java.util.Arrays.fill(memo, -1);

    // states - 1 is the mask with all digits == 2, i.e. all slots empty
    return dp(nums, numSlots, 0, states - 1, memo);
}

/** Maximum AND sum for nums[index...], given the remaining slot capacities in mask. */
private int dp(int[] nums, int numSlots, int index, int mask, int[] memo) {
    if (index == nums.length) {
        return 0;
    }

    if (memo[mask] >= 0) {
        return memo[mask];
    }

    int max = 0;
    // bit == 3 ^ (slot - 1) is the weight of the slot's base-3 digit
    for (int slot = 1, bit = 1; slot <= numSlots; slot++, bit *= 3) {
        // the slot still has room
        if (mask / bit % 3 > 0) {
            max = Math.max(max, (nums[index] & slot) + dp(nums, numSlots, index + 1, mask - bit, memo));
        }
    }
    return memo[mask] = max;
}
This post is licensed under CC BY 4.0 by the author.