Bit Mask
Tricks
(x >> i) & 1: get the i-th bit of state x
(x & y) == x: check if x is a subset of y
(x & (x >> 1)) == 0: check that no two adjacent bits are both set in x
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
// letter-set mask -> number of words having exactly that set of letters
private Map<Integer, Integer> map = new HashMap<>();
// masks that have already been assigned to a group
private Set<Integer> visited = new HashSet<>();

/**
 * Groups words whose letter sets are connected by adding, deleting or replacing one letter.
 *
 * <p>Each word is reduced to a 26-bit mask of the letters it contains ('a' is bit 0).
 * Two masks are neighbours when they differ by one bit (add/delete) or by one set bit
 * and one clear bit (replace).
 *
 * @param words words consisting of lowercase letters 'a'..'z'
 * @return {@code {number of groups, size of the largest group}}; {@code {0, 0}} for no words
 */
public int[] groupStrings(String[] words) {
    // the fields are per-call scratch state; clear them so the instance can be reused
    map.clear();
    visited.clear();
    for (String w : words) {
        int mask = 0;
        for (int c = 0; c < w.length(); c++) {
            mask |= 1 << (w.charAt(c) - 'a');
        }
        map.merge(mask, 1, Integer::sum);
    }
    int groups = 0, maxSize = 0;
    for (int k : map.keySet()) {
        // dfs returns 0 when k already belongs to a previously counted group
        int size = dfs(k);
        if (size > 0) {
            groups++;
            maxSize = Math.max(maxSize, size);
        }
    }
    return new int[]{groups, maxSize};
}

/**
 * Counts the words in the connected component of {@code mask} and marks the component visited.
 *
 * <p>The traversal uses an explicit stack instead of recursion so that a long chain of
 * connected masks cannot overflow the call stack.
 *
 * @param mask a letter-set mask
 * @return number of words in the component, or 0 if {@code mask} is absent or already visited
 */
private int dfs(int mask) {
    if (!map.containsKey(mask) || !visited.add(mask)) {
        return 0;
    }
    // every mask is pushed at most once over the whole run, so map.size() slots suffice
    int[] stack = new int[map.size()];
    int top = 0;
    stack[top++] = mask;
    int count = 0;
    while (top > 0) {
        int curr = stack[--top];
        count += map.get(curr);
        for (int i = 0; i < 26; i++) {
            // add/delete: flips one bit
            int flipped = curr ^ (1 << i);
            if (map.containsKey(flipped) && visited.add(flipped)) {
                stack[top++] = flipped;
            }
            for (int j = i + 1; j < 26; j++) {
                // replace: flips two bits with different values
                if (((curr >> i) & 1) != ((curr >> j) & 1)) {
                    int replaced = flipped ^ (1 << j);
                    if (map.containsKey(replaced) && visited.add(replaced)) {
                        stack[top++] = replaced;
                    }
                }
            }
        }
    }
    return count;
}
Enumeration
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
/**
 * Iterates over all combinations of a fixed length taken from a string, in the order
 * induced by decreasing bitmask value (leftmost character is the most significant bit).
 */
class CombinationIterator {
    private final String characters;
    private final int n;
    private final int k;
    // current selection; bit (n - 1 - i) set means characters[i] is chosen; 0 means exhausted
    private int bitmask;

    public CombinationIterator(String characters, int combinationLength) {
        this.characters = characters;
        this.n = characters.length();
        this.k = combinationLength;
        // k ones followed by (n - k) zeros
        this.bitmask = (1 << n) - (1 << (n - k));
    }

    /** Returns the combination encoded by the current mask, then advances to the next one. */
    public String next() {
        // e.g. for "abc": 110 --> "ab", 101 --> "ac", 011 --> "bc"
        StringBuilder combination = new StringBuilder();
        for (int idx = 0; idx < n; idx++) {
            boolean chosen = ((bitmask >> (n - idx - 1)) & 1) == 1;
            if (chosen) {
                combination.append(characters.charAt(idx));
            }
        }
        // step down to the next smaller mask that has exactly k bits set (or to 0)
        do {
            bitmask--;
        } while (bitmask > 0 && Integer.bitCount(bitmask) != k);
        return combination.toString();
    }

    /** Returns whether another combination is available. */
    public boolean hasNext() {
        return bitmask > 0;
    }
}
Minimum Number of Lines to Cover Points
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/**
 * Computes the minimum number of straight lines needed to cover all points.
 *
 * <p>Every pair of points defines a line keyed by "slope#intercept" (vertical lines use
 * "x#NaN"). Lines through at least three points are enumerated as subsets; whatever a
 * subset leaves uncovered is paired up, two points per extra line.
 *
 * @param points array of {@code {x, y}} coordinates
 * @return the minimum number of lines; 1 when there are fewer than three points
 */
public int minimumLines(int[][] points) {
    int n = points.length;
    if (n < 3) {
        return 1;
    }
    // line key -> indexes of the points of every pair that produced this key
    Map<String, List<Integer>> pointsByLine = new HashMap<>();
    for (int i = 0; i < n; i++) {
        for (int j = i + 1; j < n; j++) {
            // y = a * x + b
            double a;
            double b;
            if (points[i][0] == points[j][0]) {
                a = points[i][0];
                b = Double.NaN;
            } else {
                a = (double) (points[j][1] - points[i][1]) / (points[j][0] - points[i][0]);
                b = points[i][1] - a * points[i][0];
            }
            List<Integer> members = pointsByLine.computeIfAbsent(a + "#" + b, key -> new ArrayList<>());
            members.add(i);
            members.add(j);
        }
    }
    // a key produced by a single pair holds exactly two entries; keep only the richer ones
    List<List<Integer>> richLines = new ArrayList<>();
    for (List<Integer> members : pointsByLine.values()) {
        if (members.size() > 2) {
            richLines.add(members);
        }
    }
    // upper bound: pair the points up arbitrarily
    int best = (n + 1) / 2;
    int lineCount = richLines.size();
    Set<Integer> covered = new HashSet<>();
    for (int mask = 1; mask < (1 << lineCount); mask++) {
        covered.clear();
        // walk the set bits of mask, lowest first
        for (int bits = mask; bits != 0; bits &= bits - 1) {
            covered.addAll(richLines.get(Integer.numberOfTrailingZeros(bits)));
        }
        int uncovered = n - covered.size();
        best = Math.min(best, Integer.bitCount(mask) + (uncovered + 1) / 2);
    }
    return best;
}
Partition Array Into Two Arrays to Minimize Sum Difference
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/**
 * Splits {@code nums} (length 2n) into two arrays of length n minimising the absolute
 * difference of their sums, using meet-in-the-middle over the two halves.
 *
 * <p>With p1 the sum of the first partition, the difference is |sum - 2 * p1|, so p1 should
 * be as close as possible to sum / 2. p1 = l + r where l comes from the left half and r from
 * the right half; when l uses i elements, r must use n - i.
 *
 * @param nums array of even length
 * @return the minimum possible absolute difference
 */
public int minimumDifference(int[] nums) {
    int half = nums.length / 2;
    List<Integer>[] left = subsetSums(nums, 0);
    List<Integer>[] right = subsetSums(nums, half);
    int total = 0;
    for (int value : nums) {
        total += value;
    }
    int best = Integer.MAX_VALUE;
    for (int picked = 0; picked <= half; picked++) {
        List<Integer> candidates = right[half - picked];
        for (int l : left[picked]) {
            // the ideal right-hand contribution is total / 2 - l
            int r = binarySearch(candidates, total / 2 - l);
            best = Math.min(best, Math.abs(total - 2 * (l + r)));
        }
    }
    return best;
}

/**
 * Enumerates the subset sums of the half of {@code nums} that begins at {@code start}.
 *
 * @return an array where entry i is the ascending list of sums of all subsets of size i
 */
private List<Integer>[] subsetSums(int[] nums, int start) {
    int half = nums.length / 2;
    int states = 1 << half;
    @SuppressWarnings("unchecked")
    List<Integer>[] buckets = new List[half + 1];
    for (int size = 0; size <= half; size++) {
        buckets[size] = new ArrayList<>();
    }
    // sumOf[s] is built from the state with the lowest set bit of s removed
    int[] sumOf = new int[states];
    buckets[0].add(0);
    for (int s = 1; s < states; s++) {
        int lowest = Integer.numberOfTrailingZeros(s);
        sumOf[s] = sumOf[s & (s - 1)] + nums[start + lowest];
        buckets[Integer.bitCount(s)].add(sumOf[s]);
    }
    for (List<Integer> bucket : buckets) {
        Collections.sort(bucket);
    }
    return buckets;
}

/**
 * Returns the element of the ascending {@code list} closest to {@code target};
 * on a tie the smaller element wins.
 */
private int binarySearch(List<Integer> list, int target) {
    int lo = 0;
    int hi = list.size() - 1;
    while (lo < hi) {
        int mid = (lo + hi) >>> 1;
        int below = target - list.get(mid);
        int above = list.get(mid + 1) - target;
        if (above < below) {
            // the right neighbour is strictly closer, so the answer lies to the right
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    return list.get(lo);
}
Number of Ways to Build Sturdy Brick Wall
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
private static final int MOD = (int) 1e9 + 7;
private static final int MAX_BRICK_LENGTH = 10;

/**
 * Counts the ways to build a wall of the given height and width such that no two
 * vertically adjacent rows have a brick joint at the same position (edges excluded).
 *
 * @param height number of rows
 * @param width  width of every row
 * @param bricks available brick lengths, each usable any number of times
 * @return the number of sturdy walls modulo 1e9 + 7
 */
public int buildWall(int height, int width, int[] bricks) {
    // every way of tiling a single row, encoded as a mask of its inner joint positions
    List<Integer> layouts = new ArrayList<>();
    backtrack(width, bricks, layouts, 0, 0);
    int total = layouts.size();
    if (height == 1) {
        return total;
    }
    // neighbours[a]: indexes of the layouts that may sit directly next to layout a
    int[][] neighbours = new int[total][];
    for (int a = 0; a < total; a++) {
        int[] buffer = new int[total];
        int used = 0;
        for (int b = 0; b < total; b++) {
            // rows are compatible when they share no joint position
            if ((layouts.get(a) & layouts.get(b)) == 0) {
                buffer[used++] = b;
            }
        }
        neighbours[a] = Arrays.copyOf(buffer, used);
    }
    // counts[a]: number of walls built so far whose top row is layout a
    int[] counts = new int[total];
    Arrays.fill(counts, 1);
    for (int layer = 1; layer < height; layer++) {
        int[] stacked = new int[total];
        for (int top = 0; top < total; top++) {
            int ways = 0;
            // compatibility is symmetric, so summing over neighbours of `top` pulls the same values
            for (int below : neighbours[top]) {
                ways = (ways + counts[below]) % MOD;
            }
            stacked[top] = ways;
        }
        counts = stacked;
    }
    int answer = 0;
    for (int ways : counts) {
        answer = (answer + ways) % MOD;
    }
    return answer;
}

/**
 * Enumerates every sequence of bricks that fills a row exactly and records its joint mask.
 * Bit p of the mask is set when a joint lies at distance p from the left edge, e.g. for
 * width 6: [1,2,3] has joints at 1 and 3, [3,1,2] at 3 and 4, [2,3,1] at 2 and 5.
 */
private void backtrack(int width, int[] bricks, List<Integer> ways, int mask, int currWidth) {
    for (int brick : bricks) {
        int end = currWidth + brick;
        if (end > width) {
            continue;
        }
        if (end == width) {
            ways.add(mask);
        } else {
            backtrack(width, bricks, ways, mask | (1 << end), end);
        }
    }
}
Set Cover Problem
Set cover problem: NP-complete
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
/**
 * Finds a smallest team of people that together cover every required skill.
 *
 * <p>Skill sets are bitmasks over {@code req_skills}. For every reachable skill set the
 * table keeps one smallest team found so far; people are processed in order and each one
 * extends all currently known skill sets.
 *
 * @param req_skills the required skills
 * @param people     skills of each person; skills not in {@code req_skills} are ignored
 * @return indexes of the people in the chosen team
 */
public int[] smallestSufficientTeam(String[] req_skills, List<List<String>> people) {
    Map<String, Integer> skillBit = new HashMap<>();
    for (int s = 0; s < req_skills.length; s++) {
        skillBit.put(req_skills[s], s);
    }
    int fullMask = (1 << req_skills.length) - 1;
    // teams[mask]: a smallest known team covering exactly the skills in mask, or null
    @SuppressWarnings("unchecked")
    List<Integer>[] teams = new List[fullMask + 1];
    teams[0] = new ArrayList<>();
    for (int person = 0; person < people.size(); person++) {
        int has = 0;
        for (String skill : people.get(person)) {
            Integer bit = skillBit.get(skill);
            if (bit != null) {
                has |= 1 << bit;
            }
        }
        for (int covered = 0; covered <= fullMask; covered++) {
            List<Integer> team = teams[covered];
            if (team == null) {
                // this skill set has not been reached yet
                continue;
            }
            int merged = covered | has;
            if (merged == covered) {
                // the person adds nothing to this skill set
                continue;
            }
            List<Integer> existing = teams[merged];
            if (existing == null || existing.size() > team.size() + 1) {
                List<Integer> extended = new ArrayList<>(team);
                extended.add(person);
                teams[merged] = extended;
            }
        }
    }
    List<Integer> answer = teams[fullMask];
    int[] result = new int[answer.size()];
    for (int idx = 0; idx < result.length; idx++) {
        result[idx] = answer.get(idx);
    }
    return result;
}
Subsets of Mask
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/**
 * Decides whether the customer orders in {@code quantity} can all be served from {@code nums},
 * where each customer is served out of the occurrences of one single value.
 *
 * <p>The occurrence counts of the distinct values are computed and sorted, and so is
 * {@code quantity} (note: the caller's {@code quantity} array is sorted in place). A greedy
 * pass then drops counts that are too small, settles customers that can be matched
 * immediately, and collects the rest into {@code freqList} / {@code quantityList}, which are
 * resolved by the subset-enumeration {@code dfs}.
 *
 * @param nums     the available integers
 * @param quantity the amount each customer orders
 * @return {@code true} if every customer can be satisfied
 */
public boolean canDistribute(int[] nums, int[] quantity) {
// occurrence count of every distinct value in nums
List<Long> freq = new ArrayList<>(Arrays.stream(nums).boxed()
.collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
.values());
// preprocessing: both lists ascending
Collections.sort(freq);
Arrays.sort(quantity);
List<Integer> q = Arrays.stream(quantity).boxed().collect(Collectors.toList());
// counts and orders that the greedy pass could not settle
List<Long> freqList = new ArrayList<>();
List<Integer> quantityList = new ArrayList<>();
// index of the smallest order in q that has not been looked at yet
int i = 0;
for (long f : freq) {
// every order has been looked at: keep the count for the deferred orders
if (i == q.size()) {
freqList.add(f);
continue;
}
// the count is smaller than the smallest pending order: it is dropped
if (f < q.get(i)) {
continue;
}
// the count serves order i and both are discarded
// (taken when order i is the last one in q, or when it fits exactly)
if (i == q.size() - 1 || f == q.get(i)) {
i++;
continue;
}
// the count is too small for the two smallest pending orders together,
// so it is given to the largest order in q that it can hold, and that order is removed
if (f < q.get(i) + q.get(i + 1)) {
int index = Collections.binarySearch(q, (int)f);
if (index < 0) {
// not found: ~index is the insertion point, the element before it is the largest one <= f
index = ~index;
q.remove(index - 1);
} else {
q.remove(index);
}
continue;
}
// to be determined by the search below
freqList.add(f);
quantityList.add(q.get(i++));
}
// orders never reached by the loop are deferred as well
while (i < q.size()) {
quantityList.add(q.get(i++));
}
// all customers are satisfied
if (quantityList.isEmpty()) {
return true;
}
// no integers available for the remaining customers
if (freqList.isEmpty()) {
return false;
}
int m = quantityList.size();
// start with every deferred customer unassigned
return dfs(0, (1 << m) - 1, freqList, quantityList, new Boolean[freqList.size()][1 << m]);
}
/**
 * Checks whether the customers in {@code mask} can be served by the counts
 * {@code freq[index...]}.
 *
 * @param index    index of the count currently being handed out
 * @param mask     the customers that are not yet assigned (bit i = customer i of {@code quantity})
 * @param freq     the remaining occurrence counts
 * @param quantity the remaining orders
 * @param memo     memo[index][mask] caches the result; {@code null} means not computed
 * @return {@code true} if all customers in {@code mask} can be satisfied
 */
private boolean dfs(int index, int mask, List<Long> freq, List<Integer> quantity, Boolean[][] memo) {
// all customers are satisfied
if (mask == 0) {
return true;
}
// customers remain but no counts are left
if (index == freq.size()) {
return false;
}
if (memo[index][mask] != null) {
return memo[index][mask];
}
// enumerates all subsets of mask in decreasing order, ending with the empty subset
for (int state = mask; ; state = (state - 1) & mask) {
// total amount ordered by the customers in this subset
int sum = 0;
for (int i = 0; i < quantity.size(); i++) {
// i-th customer belongs to the subset
if (((1 << i) & state) > 0) {
sum += quantity.get(i);
}
}
// assigns the freq to the customers in this subset
// mask ^ state is the new mask
// if state == 0, we are skipping the integer at this index
if (sum <= freq.get(index) && dfs(index + 1, mask ^ state, freq, quantity, memo)) {
return memo[index][mask] = true;
}
// the empty subset was the last one to try
if (state == 0) {
break;
}
}
return memo[index][mask] = false;
}
Minimum Cost to Connect Two Groups of Points
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
private List<List<Integer>> cost;
// memo[i][mask]: min cost to finish from group1[i] with group2 coverage `mask`; -1 = not computed
private int[][] memo;

/**
 * Computes the minimum total cost of connecting two groups of points so that every point
 * of either group has at least one connection to the other group.
 *
 * @param cost {@code cost.get(i).get(j)} is the cost of connecting group1[i] with group2[j];
 *             must be non-empty and rectangular
 * @return the minimum total connection cost
 */
public int connectTwoGroups(List<List<Integer>> cost) {
    this.cost = cost;
    int n1 = cost.size(), n2 = cost.get(0).size();
    this.memo = new int[n1 + 1][1 << n2];
    // 0 is a legitimate answer (all-zero costs), so "not computed" needs its own sentinel
    for (int[] row : memo) {
        Arrays.fill(row, -1);
    }
    // rightMinCost[j]: cheapest connection from any group1 point to group2[j]
    int[] rightMinCost = new int[n2];
    Arrays.fill(rightMinCost, Integer.MAX_VALUE);
    for (int j = 0; j < n2; j++) {
        for (int i = 0; i < n1; i++) {
            rightMinCost[j] = Math.min(rightMinCost[j], cost.get(i).get(j));
        }
    }
    return dfs(0, rightMinCost, 0);
}

/**
 * Returns the minimum cost to connect group1[i...] given that the group2 points in
 * {@code mask} are already connected.
 *
 * @param i            index of the next group1 point to connect
 * @param rightMinCost cheapest connection for every group2 point
 * @param mask         group2 points that already have a connection
 */
private int dfs(int i, int[] rightMinCost, int mask) {
    if (memo[i][mask] >= 0) {
        return memo[i][mask];
    }
    int min = 0;
    int n2 = cost.get(0).size();
    if (i < cost.size()) {
        // connects group1[i] with one group2[j]; extra links for group1[i] are never needed
        // here because uncovered group2 points are handled in the base case
        min = Integer.MAX_VALUE;
        for (int j = 0; j < n2; j++) {
            min = Math.min(min, cost.get(i).get(j) + dfs(i + 1, rightMinCost, mask | (1 << j)));
        }
    } else {
        // all points in group1 are connected;
        // each still-unmatched group2 point takes its cheapest connection
        for (int j = 0; j < n2; j++) {
            if ((mask & (1 << j)) == 0) {
                min += rightMinCost[j];
            }
        }
    }
    return memo[i][mask] = min;
}
Minimum Time to Kill All Monsters
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
/**
 * Computes the minimum number of days needed to kill all monsters.
 *
 * <p>The daily gain starts at 1 and grows by 1 after each kill, so it is fully determined
 * by how many monsters are already dead. A single pass over the kill-set masks in
 * increasing order is therefore enough: every predecessor of a mask is numerically smaller.
 *
 * @param power power of each monster
 * @return the minimum number of days; 0 when there are no monsters
 */
public long minimumTime(int[] power) {
    int n = power.length;
    int full = (1 << n) - 1;
    // dp[mask]: minimum days to kill exactly the monsters in mask
    long[] dp = new long[1 << n];
    Arrays.fill(dp, Long.MAX_VALUE);
    dp[0] = 0;
    for (int mask = 0; mask < full; mask++) {
        // every mask is reachable from 0, so dp[mask] is final and finite here
        int gain = Integer.bitCount(mask) + 1;
        for (int k = 0; k < n; k++) {
            int curr = 1 << k;
            if ((mask & curr) == 0) {
                // gain * days >= power[k]  =>  days = ceil(power[k] / gain), in integer arithmetic
                long days = (power[k] + gain - 1L) / gain;
                dp[mask | curr] = Math.min(dp[mask | curr], dp[mask] + days);
            }
        }
    }
    return dp[full];
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
/**
 * Computes the maximum number of students that can be seated so that nobody has another
 * student directly left, right, upper-left or upper-right of them.
 *
 * @param seats grid where '.' marks a usable seat; any other character is a broken seat
 * @return the maximum number of students
 */
public int maxStudents(char[][] seats) {
    int rows = seats.length;
    int cols = seats[0].length;
    int limit = 1 << cols;
    // usable[r]: mask of the good seats in row r (column 0 is the most significant bit)
    int[] usable = new int[rows];
    for (int r = 0; r < rows; r++) {
        int bits = 0;
        for (int c = 0; c < cols; c++) {
            bits = (bits << 1) + (seats[r][c] == '.' ? 1 : 0);
        }
        usable[r] = bits;
    }
    int best = 0;
    // previous[mask]: best total for the rows so far ending with `mask`, or -1 if invalid
    int[] previous = null;
    for (int r = 0; r < rows; r++) {
        int[] current = new int[limit];
        Arrays.fill(current, -1);
        for (int mask = 0; mask < limit; mask++) {
            boolean onGoodSeats = (mask & usable[r]) == mask;
            boolean noNeighbours = (mask & (mask >> 1)) == 0;
            if (!onGoodSeats || !noNeighbours) {
                continue;
            }
            int seated = Integer.bitCount(mask);
            if (previous == null) {
                current[mask] = seated;
            } else {
                for (int above = 0; above < limit; above++) {
                    if (previous[above] == -1) {
                        continue;
                    }
                    // nobody diagonally in front on either side
                    boolean clash = (mask & (above >> 1)) != 0 || ((mask >> 1) & above) != 0;
                    if (!clash) {
                        current[mask] = Math.max(current[mask], previous[above] + seated);
                    }
                }
            }
            best = Math.max(best, current[mask]);
        }
        previous = current;
    }
    return best;
}
Partitioning
Partition a dataset into groups.
Generally, each mask represents a state where k groups are complete and at most one group is incomplete. If an incomplete group exists, the next bit added to the mask lands in that group; otherwise, the next bit starts a new group.
Iteration
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/**
 * Distributes {@code nums} into {@code k} groups of equal size without equal numbers inside
 * a group, minimising the sum over groups of (max - min).
 *
 * <p>Groups are built one after another, each in ascending order: the first number of a
 * group subtracts its value from the running total and the last number adds its value.
 * Note that {@code nums} is sorted in place.
 *
 * @param nums the numbers to distribute
 * @param k    the number of groups
 * @return the minimum sum of incompatibilities, or -1 if the full state keeps the
 *         "unreachable" sentinel value
 */
public int minimumIncompatibility(int[] nums, int k) {
int n = nums.length, groupSize = n / k;
// groups of a single number have max == min
if (groupSize == 1) {
return 0;
}
// ensures inner-group elements are sorted
Arrays.sort(nums);
// dp[mask][i]: let b = Integer.bitCount(mask), r = b % groupSize, m = b / groupSize
// - if r == 0, there are m complete groups, and nums[i] is the last number of the newest group.
// dp[mask][i] is the min sum of incompatibilities of the m groups where i is the index of the last picked number
// - otherwise, there are m complete groups and one incomplete group, and nums[i] is the last number of the incomplete group.
// dp[mask][i] is the min sum of incompatibilities of the m groups, minus the minimum number of the incomplete group, where i is the index of the last picked number
//
// dp[mask][n]: min of dp[mask][0...(n - 1)]
int[][] dp = new int[1 << n][n + 1];
// sentinel used as "unreachable"
int max = nums[n - 1] * n;
for (int i = 0; i < dp.length; i++) {
Arrays.fill(dp[i], max);
}
dp[0][n] = 0;
for (int mask = 1; mask < (1 << n); mask++) {
// NOTE(review): min is declared once per mask and is not reset for each i, so in the
// else-branch below it also carries the minima collected for smaller set bits i
int bitCount = Integer.bitCount(mask), min = max;
for (int i = 0; i < n; i++) {
// picks set bits only so as to find the previous state
if ((mask & (1 << i)) == 0) {
continue;
}
int prev = mask ^ (1 << i);
if (bitCount % groupSize == 1) {
// nums[i] is the first and minimum number of the new group
// finds the min sum of incompatibilities of prev among all of its possible last elements
// since prev has complete groups only, any of its elements could be the last element
// so the min sum of incompatibilities of prev is simply dp[prev][n]
dp[mask][i] = dp[prev][n] - nums[i];
} else {
// the last element of prev (i.e. nums[j]) and nums[i] are in the same incomplete group
// finds the min sum of incompatibilities of prev among all of its possible last elements
// nums is sorted, so the scan over j < i stops at the first j with nums[j] == nums[i];
// every j before that has nums[j] < nums[i]
for (int j = 0; j < i && nums[j] != nums[i]; j++) {
min = Math.min(min, dp[prev][j]);
}
// nums[i] completes the group: its value is added as the group maximum
if (bitCount % groupSize == 0) {
dp[mask][i] = min + nums[i];
}
// nums[i] is a middle element: the total does not change
if (bitCount % groupSize > 1) {
dp[mask][i] = min;
}
}
dp[mask][n] = Math.min(dp[mask][n], dp[mask][i]);
}
}
return dp[(1 << n) - 1][n] == max ? -1 : dp[(1 << n) - 1][n];
}
DFS
Each DFS level iterates the mask bits and tentatively adds each unused bit.
Partition to K Equal Sum Subsets
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// visited[mask]: the set of used elements `mask` is known to lead nowhere
private boolean[] visited;

/**
 * Decides whether {@code nums} can be split into {@code k} non-empty subsets of equal sum.
 *
 * <p>Note that {@code nums} is sorted in place (ascending) before the search.
 *
 * @param nums the numbers to partition
 * @param k    the number of subsets
 * @return {@code true} if such a partition exists
 */
public boolean canPartitionKSubsets(int[] nums, int k) {
    int total = 0;
    int largest = 0;
    for (int value : nums) {
        total += value;
        largest = Math.max(largest, value);
    }
    int target = total / k;
    boolean divisible = total % k == 0;
    if (!divisible || largest > target) {
        return false;
    }
    Arrays.sort(nums);
    this.visited = new boolean[1 << nums.length];
    return dfs(nums, 0, 0, target);
}

/**
 * Tries to place the unused elements, filling one subset at a time.
 *
 * @param mask      elements already placed
 * @param subsetSum sum of the subset currently being filled (0 right after one is completed)
 * @param target    required sum of every subset
 */
private boolean dfs(int[] nums, int mask, int subsetSum, int target) {
    int n = nums.length;
    boolean allPlaced = mask == (1 << n) - 1;
    if (allPlaced) {
        return true;
    }
    if (visited[mask]) {
        return false;
    }
    for (int i = 0; i < n; i++) {
        int bit = 1 << i;
        if ((mask & bit) != 0) {
            continue;
        }
        int filled = subsetSum + nums[i];
        if (filled > target) {
            continue;
        }
        // reaching the target wraps the running sum back to 0, i.e. starts the next subset
        if (dfs(nums, mask | bit, filled % target, target)) {
            return true;
        }
    }
    visited[mask] = true;
    return false;
}
Minimum Number of Work Sessions to Finish the Tasks
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
private int sessionTime;
// memo[mask][remaining time in the current session]; allocated per call
private Integer[][] memo;

/**
 * Computes the minimum number of work sessions needed to finish all tasks, where a task
 * must be finished within the session it was started in.
 *
 * @param tasks       duration of each task; every task must fit into one session
 * @param sessionTime length of a session
 * @return the minimum number of sessions; 0 when there are no tasks
 * @throws IllegalArgumentException if a task is longer than {@code sessionTime}
 */
public int minSessions(int[] tasks, int sessionTime) {
    for (int task : tasks) {
        if (task > sessionTime) {
            throw new IllegalArgumentException(
                "task of length " + task + " does not fit into a session of length " + sessionTime);
        }
    }
    this.sessionTime = sessionTime;
    // sized from the actual input so that any task count / session length works and
    // results of an earlier call on this instance cannot leak into this one
    this.memo = new Integer[1 << tasks.length][Math.max(sessionTime, 0) + 1];
    return dfs(tasks, 0, 0);
}

/**
 * Returns the number of additional sessions needed to finish the tasks not in {@code mask}.
 *
 * @param mask       tasks already finished
 * @param remainTime time left in the currently open session (0 before the first session)
 */
private int dfs(int[] tasks, int mask, int remainTime) {
    int n = tasks.length;
    if (mask == (1 << n) - 1) {
        return 0;
    }
    if (memo[mask][remainTime] != null) {
        return memo[mask][remainTime];
    }
    // one session per task is always enough
    int min = n;
    for (int i = 0; i < n; i++) {
        if ((mask & (1 << i)) == 0) {
            // assigns i-th bit
            int next = mask | (1 << i);
            if (tasks[i] <= remainTime) {
                // fits into the open session
                min = Math.min(min, dfs(tasks, next, remainTime - tasks[i]));
            } else {
                // opens a new session for this task
                min = Math.min(min, dfs(tasks, next, sessionTime - tasks[i]) + 1);
            }
        }
    }
    return memo[mask][remainTime] = min;
}
Maximize Score After N Operations
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// memo[k][mask]: best score obtainable from operation k onwards when `mask` is already used;
// 0 means "not computed"
private int[][] memo;

/**
 * Computes the maximum total score of pairing up all elements, where the k-th pair
 * (1-based) scores k * gcd(x, y).
 *
 * @param nums array of even length
 * @return the maximum achievable score
 */
public int maxScore(int[] nums) {
    int operations = nums.length / 2;
    this.memo = new int[operations + 1][1 << nums.length];
    return dfs(nums, 1, 0);
}

/**
 * Returns the best score for operations {@code k...} given the used elements in {@code mask}.
 */
private int dfs(int[] nums, int k, int mask) {
    int operations = nums.length / 2;
    if (k > operations) {
        return 0;
    }
    int cached = memo[k][mask];
    if (cached > 0) {
        return cached;
    }
    int length = 2 * operations;
    int best = 0;
    for (int first = 0; first < length; first++) {
        if ((mask & (1 << first)) != 0) {
            continue;
        }
        for (int second = first + 1; second < length; second++) {
            if ((mask & (1 << second)) != 0) {
                continue;
            }
            int score = k * gcd(nums[first], nums[second])
                + dfs(nums, k + 1, mask | (1 << first) | (1 << second));
            best = Math.max(best, score);
        }
    }
    memo[k][mask] = best;
    return best;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
private static final int MOD = (int) 1e9 + 7, MAX = 30;
private static final int[] PRIMES = new int[]{2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
// a bad number has at least one prime factor that appears at least twice
private static final int[] BADS = new int[]{4, 8, 12, 16, 20, 24, 28, 9, 18, 27, 25};
private static final Set<Integer> BAD_SET = new HashSet<>();
// MASKS[num]: mask representing the prime factors of num (bit i = PRIMES[i])
private static final int[] MASKS = new int[MAX + 1];

static {
    Arrays.stream(BADS).forEach(BAD_SET::add);
    for (int i = 0; i < PRIMES.length; i++) {
        for (int num = 1; num <= MAX; num++) {
            if (num % PRIMES[i] == 0) {
                MASKS[num] |= 1 << i;
            }
        }
    }
}

// freqs[num]: occurrences of num in the current input
private final int[] freqs = new int[MAX + 1];
// memo[mask][num]: number of good subsets if the initial prime set is represented by `mask`,
// and the candidate numbers are >= `num`
// although empty set is not a good subset, it's included to simplify the computation
private Long[][] memo;

/**
 * Counts the subsets whose product is a product of distinct primes.
 *
 * @param nums values in the range 1..30
 * @return the number of good subsets modulo 1e9 + 7
 */
public int numberOfGoodSubsets(int[] nums) {
    // per-call state: reset so that a reused instance does not accumulate counts
    Arrays.fill(freqs, 0);
    memo = new Long[1 << PRIMES.length][MAX + 1];
    for (int num : nums) {
        freqs[num]++;
    }
    // computes `2 ^ freqs[1]` with modulo: every good subset may take any subset of the ones
    // e.g. [1, 1, 2]: (2 ^ 2) * 1
    int ones = 1;
    for (int i = 0; i < freqs[1]; i++) {
        ones *= 2;
        ones %= MOD;
    }
    // -1 because of the empty set
    return (int) (((dfs(0, 2) - 1 + MOD) % MOD * ones) % MOD);
}

/**
 * Counts the selections (including the empty one) among the numbers {@code num..MAX}
 * whose prime factors are pairwise disjoint and disjoint from {@code mask}.
 */
private long dfs(int mask, int num) {
    if (num == MAX + 1) {
        return 1;
    }
    if (memo[mask][num] != null) {
        return memo[mask][num];
    }
    // doesn't pick num
    long count = dfs(mask, num + 1);
    // picks num if it is square-free and shares no prime factor with the current mask;
    // any of its freqs[num] occurrences can be the one picked
    if (!BAD_SET.contains(num) && (mask & MASKS[num]) == 0) {
        count += dfs(mask | MASKS[num], num + 1) * freqs[num];
    }
    return memo[mask][num] = count % MOD;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// memo[mask]: minimum XOR sum for the remaining positions when the nums2 elements in `mask`
// are already taken; Integer.MAX_VALUE means "not computed"
private int[] memo;

/**
 * Computes the minimum possible sum of {@code nums1[i] ^ nums2[p(i)]} over all
 * rearrangements {@code p} of {@code nums2}.
 */
public int minimumXORSum(int[] nums1, int[] nums2) {
    int[] table = new int[1 << nums2.length];
    Arrays.fill(table, Integer.MAX_VALUE);
    this.memo = table;
    return dfs(nums1, nums2, 0, 0);
}

/**
 * Returns the minimum XOR sum for positions {@code i...} of {@code nums1}, given the
 * already used elements of {@code nums2} in {@code mask}.
 */
private int dfs(int[] nums1, int[] nums2, int i, int mask) {
    if (i == nums1.length) {
        return 0;
    }
    if (memo[mask] != Integer.MAX_VALUE) {
        return memo[mask];
    }
    int best = Integer.MAX_VALUE;
    for (int j = 0; j < nums2.length; j++) {
        int bit = 1 << j;
        if ((mask & bit) != 0) {
            // nums2[j] is already paired
            continue;
        }
        int candidate = (nums1[i] ^ nums2[j]) + dfs(nums1, nums2, i + 1, mask | bit);
        best = Math.min(best, candidate);
    }
    memo[mask] = best;
    return best;
}
Number of Ways to Wear Different Hats to Each Other
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
private static final int MOD = (int) 1e9 + 7;
private static final int NUM_HATS = 40;

/**
 * Counts the ways to give every person a different hat from their own preference list.
 *
 * @param hats {@code hats.get(i)} lists the hat labels (1..40) person i is willing to wear
 * @return the number of assignments modulo 1e9 + 7
 */
public int numberWays(List<List<Integer>> hats) {
    int people = hats.size();
    // wearers[h]: the people who would wear hat h
    @SuppressWarnings("unchecked")
    List<Integer>[] wearers = new List[NUM_HATS + 1];
    for (int hat = 1; hat <= NUM_HATS; hat++) {
        wearers[hat] = new ArrayList<>();
    }
    for (int person = 0; person < people; person++) {
        for (int hat : hats.get(person)) {
            wearers[hat].add(person);
        }
    }
    Integer[][] memo = new Integer[NUM_HATS + 1][1 << people];
    return dfs(wearers, people, 1, 0, memo);
}

/**
 * Counts the ways to finish the assignment using the hats {@code hat..NUM_HATS}.
 *
 * @param h2p  people willing to wear each hat
 * @param n    number of people
 * @param hat  current hat label
 * @param mask people that already wear a hat
 * @param memo memo[i][j]: result for hat i and people mask j; {@code null} if not computed
 */
int dfs(List<Integer>[] h2p, int n, int hat, int mask, Integer[][] memo) {
    boolean everyoneServed = mask == (1 << n) - 1;
    if (everyoneServed) {
        return 1;
    }
    if (hat > NUM_HATS) {
        // hats ran out while somebody is still bare-headed
        return 0;
    }
    Integer cached = memo[hat][mask];
    if (cached != null) {
        return cached;
    }
    // option 1: nobody takes this hat
    int ways = dfs(h2p, n, hat + 1, mask, memo);
    // option 2: one still hatless person who likes it takes it
    for (int person : h2p[hat]) {
        int bit = 1 << person;
        if ((mask & bit) == 0) {
            ways = (ways + dfs(h2p, n, hat + 1, mask | bit, memo)) % MOD;
        }
    }
    memo[hat][mask] = ways;
    return ways;
}
Permutation
The idea is similar to Partitioning problems: each DFS level iterates the mask bits and tentatively adds each unused bit.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
private static final int MOD = (int) 1e9 + 7;
// memo[mask][last]: completions when `mask` is used and nums[last] was placed last
// (last == n means nothing has been placed yet); null means "not computed"
private Integer[][] memo;

/**
 * Counts the permutations of {@code nums} in which, for every two neighbouring elements,
 * one divides the other.
 *
 * @return the number of such permutations modulo 1e9 + 7
 */
public int specialPerm(int[] nums) {
    int size = nums.length;
    this.memo = new Integer[1 << size][size + 1];
    return dfs(nums, 0, size);
}

/**
 * Counts the ways to append the unused elements after {@code nums[lastIndex]}.
 */
private int dfs(int[] nums, int mask, int lastIndex) {
    int size = nums.length;
    if (mask == (1 << size) - 1) {
        return 1;
    }
    Integer cached = memo[mask][lastIndex];
    if (cached != null) {
        return cached;
    }
    int total = 0;
    for (int next = 0; next < size; next++) {
        int bit = 1 << next;
        if ((mask & bit) != 0) {
            continue;
        }
        // the first element is unconstrained; later ones must divide or be divided by the previous
        boolean fits = lastIndex == size
            || nums[next] % nums[lastIndex] == 0
            || nums[lastIndex] % nums[next] == 0;
        if (fits) {
            total = (total + dfs(nums, mask | bit, next)) % MOD;
        }
    }
    memo[mask][lastIndex] = total;
    return total;
}
Multi-dimension
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// dp[p][in][ex][inMask][exMask], see getMaxGridHappiness for the meaning of each dimension
private Integer[][][][][] dp;

/**
 * Computes the maximum total happiness obtainable by placing introverts and extroverts on
 * an {@code m x n} grid (cells may stay empty, not everybody has to be placed).
 *
 * <p>An introvert starts at 120 and loses 30 per neighbour; an extrovert starts at 40 and
 * gains 20 per neighbour. Cells are visited row by row, left to right.
 *
 * @param m               number of rows
 * @param n               number of columns
 * @param introvertsCount available introverts
 * @param extrovertsCount available extroverts
 * @return the maximum grid happiness
 */
public int getMaxGridHappiness(int m, int n, int introvertsCount, int extrovertsCount) {
    // more people than cells can never be placed, so capping keeps the answer and bounds the memo
    int in = Math.min(introvertsCount, m * n), ex = Math.min(extrovertsCount, m * n);
    // dimensions:
    //   position in the grid (p = i * n + j)
    //   remaining introverts
    //   remaining extroverts
    //   introverts mask of the last n positions (may wrap around from the previous row)
    //   extroverts mask of the last n positions (may wrap around from the previous row)
    // mask order: the most recent position is the lsb, the position n cells back is the msb,
    // so both masks need exactly n bits
    this.dp = new Integer[m * n][in + 1][ex + 1][1 << n][1 << n];
    return dfs(m, n, 0, in, ex, 0, 0);
}

/**
 * Sums the happiness changes caused by placing a person at (i, j) next to the already
 * placed left and upper neighbours.
 *
 * @param d the change the newly placed person experiences per neighbour
 *          (-30 for an introvert, +20 for an extrovert); the neighbour's own change
 *          (-30 or +20 depending on the neighbour's type) is added on top
 */
private int calculateHappinessDiff(int m, int n, int i, int j, int inMask, int exMask, int d) {
    // bit 0 is the previous cell (left neighbour), bit n - 1 is the cell n positions back (upper neighbour)
    int diff = 0, upBit = 1 << (n - 1);
    // left neighbour is introvert
    if (j > 0 && (inMask & 1) != 0) {
        diff += d - 30;
    }
    // up neighbour is introvert
    if (i > 0 && (inMask & upBit) != 0) {
        diff += d - 30;
    }
    // left neighbour is extrovert
    if (j > 0 && (exMask & 1) != 0) {
        diff += d + 20;
    }
    // up neighbour is extrovert
    if (i > 0 && (exMask & upBit) != 0) {
        diff += d + 20;
    }
    return diff;
}

/**
 * Returns the best happiness obtainable from cell {@code p} onwards.
 *
 * @param p      linear index of the current cell
 * @param in     introverts still available
 * @param ex     extroverts still available
 * @param inMask introverts among the last n cells
 * @param exMask extroverts among the last n cells
 */
private int dfs(int m, int n, int p, int in, int ex, int inMask, int exMask) {
    int i = p / n, j = p % n;
    if (i >= m) {
        return 0;
    }
    if (dp[p][in][ex][inMask][exMask] != null) {
        return dp[p][in][ex][inMask][exMask];
    }
    // shifting drops the cell that falls out of the n-cell window
    int window = (1 << n) - 1;
    int nextInMask = (inMask << 1) & window, nextExMask = (exMask << 1) & window;
    // leaves the current cell empty
    int max = dfs(m, n, p + 1, in, ex, nextInMask, nextExMask);
    // introvert person lives in the current cell
    if (in > 0) {
        int diff = 120 + calculateHappinessDiff(m, n, i, j, inMask, exMask, -30);
        // +1 is the current cell bit
        max = Math.max(max, diff + dfs(m, n, p + 1, in - 1, ex, nextInMask + 1, nextExMask));
    }
    // extrovert person lives in the current cell
    if (ex > 0) {
        int diff = 40 + calculateHappinessDiff(m, n, i, j, inMask, exMask, 20);
        // +1 is the current cell bit
        max = Math.max(max, diff + dfs(m, n, p + 1, in, ex - 1, nextInMask, nextExMask + 1));
    }
    return dp[p][in][ex][inMask][exMask] = max;
}
Radix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/**
 * Computes the maximum sum of {@code nums[i] & slot} when every number is put into one of
 * the slots 1..numSlots and each slot holds at most two numbers.
 *
 * @param nums     the numbers to place; at most {@code 2 * numSlots} of them
 * @param numSlots the number of slots
 * @return the maximum AND sum
 */
public int maximumANDSum(int[] nums, int numSlots) {
    // base-3 mask: digit (slot - 1) is the remaining capacity of that slot, initially 2
    int states = 1;
    for (int i = 0; i < numSlots; i++) {
        states *= 3;
    }
    // null = not computed; a plain 0 cannot be used because 0 is a legitimate AND sum
    return dp(nums, numSlots, 0, states - 1, new Integer[states]);
}

/**
 * Returns the best AND sum for {@code nums[index...]} given the remaining capacities in
 * {@code mask}. {@code index} is implied by {@code mask}, so the memo is keyed by mask only.
 */
private int dp(int[] nums, int numSlots, int index, int mask, Integer[] memo) {
    if (index == nums.length) {
        return 0;
    }
    if (memo[mask] != null) {
        return memo[mask];
    }
    int max = 0;
    for (int slot = 1, bit = 1; slot <= numSlots; slot++, bit *= 3) {
        // the slot still has room
        if (mask / bit % 3 > 0) {
            max = Math.max(max, (nums[index] & slot) + dp(nums, numSlots, index + 1, mask - bit, memo));
        }
    }
    return memo[mask] = max;
}