Fundamentals
Disjoint-set data structure, also called union-find data structure, stores a collection of disjoint (non-overlapping) sets. Equivalently, it stores a partition of a set into disjoint subsets.
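Time Complexity (amortized, when both path compression and union by rank/size are used):
- find: O(α(n))
- union: O(α(n))
where α(n) is the extremely slow-growing inverse Ackermann function. With only one of the two optimizations the amortized cost per operation is O(log n); with neither, a single operation can take O(n).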
Redundant Connection
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
| private int[] parents;
public int[] findRedundantConnection(int[][] edges) {
parent = new int[edges.length + 1];
for (int[] edge : edges) {
if (!union(edge[0], edge[1])) {
return edge;
}
}
return null;
}
private int find(int u) {
return parents[u] == 0 ? u : find(parents[u]);
}
private boolean union(int u, int v) {
int pu = find(u), pv = find(v);
if (pu == pv) {
return false;
}
parents[pu] = pv;
return true;
}
|
Path compression
1
2
3
4
| private int find(int u) {
// path compression
return parents[u] == 0 ? u : (parents[u] = find(parents[u]));
}
|
Union by rank
1
2
3
4
5
6
7
8
9
10
11
12
13
14
| void unionSets(int u, int v) {
int pu = find(u), pv = find(v);
if (pu != pv) {
if (ranks[pu] < ranks[pv]) {
parents[pu] = pv;
} else if (ranks[pu] > ranks[pv]) {
parents[pv] = pu;
} else {
parents[pu] = pv;
ranks[pv]++;
}
}
}
|
Graph Connectivity With Threshold
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
| public List<Boolean> areConnected(int n, int threshold, int[][] queries) {
this.parents = new int[n + 1];
for (int z = threshold + 1; z < n; z++) {
// unions all multiples of z
for (int k = 2; k * z <= n; k++) {
union(z, k * z);
}
}
List<Boolean> answer = new ArrayList<>();
for (int[] q : queries) {
answer.add(find(q[0]) == find(q[1]));
}
return answer;
}
|
Depending on the specific problem, the `parents` elements can be initialized to `0`, `-1`, `null`, each element's own index, etc.
Checking Existence of Edge Length Limited Paths
Offline queries:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
| private int[] parents;
public boolean[] distanceLimitedPathsExist(int n, int[][] edgeList, int[][] queries) {
int m = queries.length;
Integer[] indices = new Integer[m];
for (int i = 0; i < m; i++) {
indices[i] = i;
}
// sorts queries index by limit
Arrays.sort(indices, Comparator.comparingInt(i -> queries[i][2]));
// sorts edgeList by distance
Arrays.sort(edgeList, Comparator.comparingInt(e -> e[2]));
// union-find
this.parents = new int[n];
Arrays.fill(parents, -1);
boolean[] answer = new boolean[m];
int i = 0, j = 0;
while (j < m) {
int[] q = queries[indices[j]];
// unions all nodes whose edge distance < q[2]
// when q is updated, the existing disjoint sets remain the same
// we just need to add new edges to the proper set
while (i < edgeList.length && edgeList[i][2] < q[2]) {
union(edgeList[i][0], edgeList[i][1]);
i++;
}
// true if q nodes are in the same set
if (find(q[0]) == find(q[1])) {
answer[indices[j]] = true;
}
j++;
}
return answer;
}
|
Redundant Connection II
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
| public int[] findRedundantDirectedConnection(int[][] edges) {
// there are two cases where the graph is invalid
// - a node has two parents
// - a cycle exists
// parent of each node
int[] parents = new int[edges.length + 1];
// there's at most one node which has more than one parent
// and this node has at most two parents
// the two edges from the parents to this node are the two candidates
// note the order of these two candidates matters
int[] candidate1 = null, candidate2 = null;
for (int[] e : edges) {
if (parents[e[1]] == 0) {
parents[e[1]] = e[0];
} else {
// there are two parents of e[1]
candidate2 = new int[] {e[0], e[1]};
candidate1 = new int[] {parents[e[1]], e[1]};
// sets the edge of candidate2 to 0
// so that later when we construct the graph
// candidate2 is skipped
e[1] = 0;
}
}
// now uses the parents array as the roots for union-find
Arrays.fill(parents, 0);
// adds edges to the graph in order
for (int[] e : edges) {
if (e[1] == 0) {
continue;
}
// found cycle
if (!union(parents, e[0], e[1])) {
// if there's no candidate edge, this edge is redundant
// else candiate1 is redundant
// because candidate2 is not in the graph yet
// in other words, if there's a cycle, candidate1 has precedence over other edges
return candidate1 == null ? e : candidate1;
}
}
// there's no cycle
// removes the candiate2
return candidate2;
}
private int find(int[] parents, int u) {
return parents[u] == 0 ? u : find(parents, parents[u]);
}
private boolean union(int[] parents, int u, int v) {
int pu = find(parents, u), pv = find(parents, v);
if (pu == pv) {
return false;
}
parents[pu] = pv;
return true;
}
|
Sometimes the `parents` structure is better represented as a map, e.g. when the elements are strings rather than small integers:
Evaluate Division
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
| private Map<String, String> parents = new HashMap<>();
private Map<String, Double> ratios = new HashMap<>(); // parent / node
public double[] calcEquation(List<List<String>> equations, double[] values, List<List<String>> queries) {
for (int i = 0; i < values.length; i++) {
List<String> e = equations.get(i);
union(e.get(0), e.get(1), values[i]);
}
int m = queries.size();
double[] answers = new double[m];
for (int i = 0; i < m; i++) {
List<String> q = queries.get(i);
answers[i] = query(q.get(0), q.get(1));
}
return answers;
}
private String find(String u) {
if (!parents.containsKey(u) || parents.get(u).equals(u)) {
ratios.put(u, 1.0);
return u;
}
// path compression
String p = parents.get(u), gp = find(p);
parents.put(u, gp);
ratios.put(u, ratios.get(u) * ratios.get(p)); // gp = p * ratio(p) = u * ratio(u) * ratio(p)
return gp;
}
// u / v = value
private void union(String u, String v, double value) {
String pu = find(u), pv = find(v);
parents.put(pv, pu);
// ratio = pu / pv
// = u * ratios(u) / (v * ratios(v))
// = value * ratios(u) / ratios(v)
ratios.put(pv, value * ratios.get(u) / ratios.get(v));
}
private double query(String s1, String s2) {
if (!ratios.containsKey(s1) || !ratios.containsKey(s2)) {
return -1.0;
}
String p1 = find(s1), p2 = find(s2);
if (!p1.equals(p2)) {
return -1.0;
}
// s1 / s2 = p / ratios(s1) / (p / ratios(s2))
// = ratios(s2) / ratios(s1)
return ratios.get(s2) / ratios.get(s1);
}
|
Another approach is to run BFS/DFS on the weighted graph.
Checking Existence of Edge Length Limited Paths II
Online queries:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// snapshots[i]: parent history of node i as {snapId, parent} pairs in ascending snapId order;
// a parent of -1 means node i is a root
private List<int[]>[] snapshots;
// snapTime.get(k): snapshot k reflects exactly the edges whose distance is <= snapTime.get(k)
private List<Integer> snapTime = new ArrayList<>();
// id of the snapshot currently being written
private int snapId = 0;
@SuppressWarnings("unchecked") // generic array creation
public DistanceLimitedPathsExist(int n, int[][] edgeList) {
    snapshots = new List[n];
    for (int node = 0; node < n; node++) {
        List<int[]> history = new ArrayList<>();
        // at snapshot 0 every node is its own root
        history.add(new int[]{0, -1});
        snapshots[node] = history;
    }
    Arrays.sort(edgeList, Comparator.comparingInt(e -> e[2]));
    // All snapshots are built before any query. Edges arrive by ascending distance, so groups
    // only grow; a new snapshot is opened each time the distance strictly increases.
    int currentDis = 0;
    for (int[] edge : edgeList) {
        int edgeDis = edge[2];
        if (edgeDis > currentDis) {
            snapTime.add(currentDis);
            currentDis = edgeDis;
            snap();
        }
        union(edge[0], edge[1], snapId);
    }
    // closes the final snapshot
    snapTime.add(currentDis);
    snap();
}
// Returns the root of u as of snapshot id (no path compression).
private int find(int u, int id) {
    int node = u;
    for (int p = get(node, id); p >= 0; p = get(node, id)) {
        node = p;
    }
    return node;
}
// Joins the groups of u and v in snapshot id.
// Only the root of u's group is rewritten (no path compression), so each union adds at
// most one history entry instead of touching every node on the path.
private void union(int u, int v, int id) {
    int rootU = find(u, id), rootV = find(v, id);
    if (rootU == rootV) {
        return;
    }
    set(rootU, rootV);
}
// 1146. Snapshot Array: value of node index as of snapshot id
private int get(int index, int id) {
    List<int[]> history = snapshots[index];
    int pos = Collections.binarySearch(history, new int[]{id, 0}, Comparator.comparingInt(a -> a[0]));
    // not found: take the latest entry recorded before snapshot id (~pos - 1 == -pos - 2)
    int slot = pos >= 0 ? pos : -pos - 2;
    return history.get(slot)[1];
}
// Records val for node index in the current snapshot, overwriting an entry made in the same snapshot.
private void set(int index, int val) {
    List<int[]> history = snapshots[index];
    int[] latest = history.get(history.size() - 1);
    if (latest[0] != snapId) {
        history.add(new int[]{snapId, val});
    } else {
        latest[1] = val;
    }
}
// Closes the current snapshot and returns its id.
private int snap() {
    int closed = snapId;
    snapId++;
    return closed;
}
// True if p and q are connected using only edges whose distance is strictly less than limit.
public boolean query(int p, int q, int limit) {
    // the latest snapshot whose distance is <= limit - 1
    int pos = Collections.binarySearch(snapTime, limit - 1);
    // snapTime.get(0) is 0, so for limit >= 1 the fallback index is never negative
    int id = pos >= 0 ? pos : -pos - 2;
    return find(p, id) == find(q, id);
}
|
Number of Connected Components
Number of Connected Components in an Undirected Graph
1
2
3
4
5
6
7
8
9
10
11
12
| public int countComponents(int n, int[][] edges) {
this.parent = new int[n];
Arrays.fill(this.parent, -1);
for (int[] e : edges) {
if (union(e[0], e[1])) {
n--;
}
}
return n;
}
|
Most Stones Removed with Same Row or Column
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
| private Map<Integer, Integer> parent = new HashMap<>();
private int count = 0;
// counts the connected components
public int removeStones(int[][] stones) {
for (int[] s : stones) {
// ~ to distinguish r and c
union(s[0], ~s[1]);
}
return stones.length - count;
}
private int find(int u) {
// a new component
if (parent.putIfAbsent(u, u) == null) {
count++;
}
// if u is not root
if (u != parent.get(u)) {
parent.put(u, find(parent.get(u)));
}
return parent.get(u);
}
private void union(int u, int v) {
int pu = find(u), pv = find(v);
if (pu != pv) {
parent.put(pu, pv);
count--;
}
}
|
Remove Max Number of Edges to Keep Graph Fully Traversable
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
| public int maxNumEdgesToRemove(int n, int[][] edges) {
// prioritizes Type 3
Arrays.sort(edges, Comparator.comparingInt(e -> -e[0]));
UnionFind alice = new UnionFind(n), bob = new UnionFind(n);
// count of added edges
int count = 0;
for (int[] e : edges) {
switch (e[0]) {
case 1:
if (alice.union(e[1], e[2])) {
count++;
}
break;
case 2:
if (bob.union(e[1], e[2])) {
count++;
}
break;
case 3:
// no short-circuit
if (alice.union(e[1], e[2]) | bob.union(e[1], e[2])) {
count++;
}
break;
}
}
return alice.isFullyConnected() && bob.isFullyConnected() ? edges.length - count : -1;
}
class UnionFind {
int[] parents;
int components;
public UnionFind(int n) {
this.parents = new int[n + 1];
Arrays.fill(parents, -1);
this.components = n;
}
private boolean union(int u, int v) {
int pu = find(u), pv = find(v);
if (pu == pv) {
return false;
}
parents[pu] = pv;
components--;
return true;
}
private int find(int u) {
// path compression
int p = parents[u];
if (p < 0) {
return u;
}
return parents[u] = find(p);
}
private boolean isFullyConnected() {
return components == 1;
}
}
|
Rank Transform of a Matrix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
| public int[][] matrixRankTransform(int[][] matrix) {
int m = matrix.length, n = matrix[0].length;
// matrix[i][j] : disjoint set
Map<Integer, DisjointSet> disjointSets = new HashMap<>();
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
int v = matrix[i][j];
disjointSets.putIfAbsent(v, new DisjointSet(m + n));
// unions its row and col to the group
disjointSets.get(v).union(i, j + m);
}
}
// matrix[i][j] : map of groups
// within each group, the members share the map key as the disjoint set root,
// which means they are connected
Map<Integer, Map<Integer, List<int[]>>> groups = new TreeMap<>();
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
int v = matrix[i][j];
groups.computeIfAbsent(v, k -> new HashMap<>())
.computeIfAbsent(disjointSets.get(v).find(i), r -> new ArrayList<>())
.add(new int[]{i, j});
}
}
int[][] answer = new int[m][n];
int[] rowMax = new int[m], colMax = new int[n];
for (var v : groups.values()) {
// updates by connected cells with same value
for (var cells : v.values()) {
int rank = 1;
for (int[] c : cells) {
rank = Math.max(rank, Math.max(rowMax[c[0]], colMax[c[1]]) + 1);
}
for (int[] c : cells) {
answer[c[0]][c[1]] = rank;
rowMax[c[0]] = Math.max(rowMax[c[0]], answer[c[0]][c[1]]);
colMax[c[1]] = Math.max(colMax[c[1]], answer[c[0]][c[1]]);
}
}
}
return answer;
}
class DisjointSet {
int[] parent;
public DisjointSet(int n) {
parent = new int[n];
Arrays.fill(parent, -1);
}
public int find(int u) {
return parent[u] < 0 ? u : find(parent[u]);
}
public void union(int u, int v) {
int pu = find(u), pv = find(v);
if (pu != pv) {
parent[pu] = pv;
}
}
}
|
Regions Cut By Slashes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// parent[x] == x marks a root
private int[] parent;
// regions: current number of disjoint sets; n: grid side length
private int regions, n;
// each square is cut by both diagonals into 4 triangles
private enum Triangle {
    TOP,
    RIGHT,
    BOTTOM,
    LEFT
}
// Counts the regions of an n x n grid whose cells are '/', '\\' or ' '.
// Every triangle starts as its own region; each union of two separate regions removes one.
public int regionsBySlashes(String[] grid) {
    n = grid.length;
    int triangles = 4 * n * n;
    regions = triangles;
    parent = new int[triangles];
    Arrays.setAll(parent, t -> t);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            int top = indexOf(i, j, Triangle.TOP);
            int right = indexOf(i, j, Triangle.RIGHT);
            int bottom = indexOf(i, j, Triangle.BOTTOM);
            int left = indexOf(i, j, Triangle.LEFT);
            // the square above shares an edge with this square's top triangle
            if (i > 0) {
                union(indexOf(i - 1, j, Triangle.BOTTOM), top);
            }
            // the square to the left shares an edge with this square's left triangle
            if (j > 0) {
                union(indexOf(i, j - 1, Triangle.RIGHT), left);
            }
            char cell = grid[i].charAt(j);
            // '\\' or ' ': no '/' separates top from right or bottom from left
            if (cell != '/') {
                union(top, right);
                union(bottom, left);
            }
            // '/' or ' ': no '\\' separates top from left or bottom from right
            if (cell != '\\') {
                union(top, left);
                union(bottom, right);
            }
        }
    }
    return regions;
}
private int find(int u) {
    int node = u;
    while (parent[node] != node) {
        node = parent[node];
    }
    return node;
}
private void union(int u, int v) {
    int rootU = find(u), rootV = find(v);
    if (rootU == rootV) {
        return;
    }
    parent[rootU] = rootV;
    regions--;
}
// Index of triangle t of square (i, j): 4 consecutive slots per square.
private int indexOf(int i, int j, Triangle t) {
    return 4 * (i * n + j) + t.ordinal();
}
|
Size of Each Set
Minimize Malware Spread
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
| private int[] parents;
public int minMalwareSpread(int[][] graph, int[] initial) {
int n = graph.length;
this.parents = new int[n];
Arrays.fill(parents, -1);
}
private int find(int u) {
return parents[u] < 0 ? u : find(parents[u]);
}
private void union(int u, int v) {
int pu = find(u), pv = find(v);
if (pu != pv) {
// -parents[i] is the size of set i
parents[pv] += parents[pu];
parents[pu] = pv;
}
}
|
Minimize Malware Spread II
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
| public int minMalwareSpread(int[][] graph, int[] initial) {
int n = graph.length;
this.parents = new int[n];
Arrays.fill(parents, -1);
Set<Integer> initialSet = Arrays.stream(initial).boxed().collect(Collectors.toSet());
// unions non-malware nodes
for (int i = 0; i < n; i++) {
if (initialSet.contains(i)) {
continue;
}
for (int j = i + 1; j < n; j++) {
if (!initialSet.contains(j) && graph[i][j] == 1) {
union(i, j);
}
}
}
// finds the infected nodes by each initial malware
Map<Integer, Set<Integer>> infected = new HashMap<>();
// singleSource[p]: group p were infected by a single initial malware
int[] numOfSources = new int[n];
for (int i : initial) {
for (int j = 0; j < n; j++) {
if (!initialSet.contains(j) && graph[i][j] == 1) {
int p = find(j);
infected.computeIfAbsent(i, k -> new HashSet<>()).add(p);
}
}
if (infected.containsKey(i)) {
for (int p : infected.get(i)) {
numOfSources[p]++;
}
}
}
int max = 0, index = -1;
for (int i : initial) {
if (infected.containsKey(i)) {
int count = 0;
for (var p : infected.get(i)) {
if (numOfSources[p] == 1) {
count -= parents[p];
}
}
if (count > max || (count == max && i < index)) {
max = count;
index = i;
}
}
}
return index >= 0 ? index : Arrays.stream(initial).min().getAsInt();
}
|
Disconnected Components
Process Restricted Friend Requests
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
| public boolean[] friendRequests(int n, int[][] restrictions, int[][] requests) {
this.parents = new Integer[n];
int m = requests.length;
boolean[] result = new boolean[m];
for (int i = 0; i < m; i++) {
int px = find(requests[i][0]), py = find(requests[i][1]);
result[i] = true;
if (px != py) {
for (int[] r : restrictions) {
int rx = find(r[0]), ry = find(r[1]);
// connecting x and y is restricted
if ((px == rx && py == ry) || (px == ry && py == rx)) {
result[i] = false;
break;
}
}
}
// unions two friends if request is valid
if (result[i]) {
union(px, py);
}
}
return result;
}
|
Find All People With Secret
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
| vector<int> findAllPeople(int n, vector<vector<int>>& meetings, int firstPerson) {
parents = vector<int>(n, -1);
ranks = vector<int>(n);
// Share the secret initially with firstPerson
unionSets(0, firstPerson);
// Sort meetings by time
ranges::sort(meetings, {}, [&](const vector<int>& m){ return m[2]; });
int m = meetings.size();
// Track people involved in meetings at the current time
unordered_set<int> people;
for (int i = 0; i < m; i++) {
// Share secrets during the meeting
unionSets(meetings[i][0], meetings[i][1]);
people.insert(meetings[i][0]);
people.insert(meetings[i][1]);
// Process at the end of each time frame
if (i == m - 1 || meetings[i][2] != meetings[i + 1][2]) {
for (int p : people) {
// Disconnect people not knowing the secret from Person 0
if (find(p) != find(0)) {
parents[p] = -1;
}
}
people.clear();
}
}
vector<int> v;
for (int i = 0; i < n; i++) {
if (find(i) == find(0)) {
v.push_back(i);
}
}
return v;
}
|
Number of Good Paths
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
// parents[r] < 0 marks a root; vals holds the node values
private int[] parents, vals;
// maxCounts[r] = {largest value in the set rooted at r, how many nodes of that set hold it}
private int[][] maxCounts;
// Counts good paths: single nodes plus paths whose two endpoints share the path's maximum value.
public int numberOfGoodPaths(int[] vals, int[][] edges) {
    int n = vals.length;
    this.vals = vals;
    parents = new int[n];
    Arrays.fill(parents, -1);
    maxCounts = new int[n][];
    for (int i = 0; i < n; i++) {
        maxCounts[i] = new int[]{vals[i], 1};
    }
    // builds the tree by adding edges in ascending order of their larger endpoint value,
    // so no set's maximum ever exceeds the value of the edge being added
    Arrays.sort(edges, Comparator.comparingInt(e -> Math.max(vals[e[0]], vals[e[1]])));
    // every single node is a good path by itself
    int goodPaths = n;
    for (int[] e : edges) {
        goodPaths += union(e[0], e[1]);
    }
    return goodPaths;
}
// Returns the root of u, pointing every node on the path directly at it.
private int find(int u) {
    int root = u;
    while (parents[root] >= 0) {
        root = parents[root];
    }
    int node = u;
    while (parents[node] >= 0) {
        int next = parents[node];
        parents[node] = root;
        node = next;
    }
    return root;
}
// Merges the sets of u and v and returns the number of new good paths created by the edge.
private int union(int u, int v) {
    int rootU = find(u), rootV = find(v);
    if (rootU == rootV) {
        return 0;
    }
    int edgeMax = Math.max(vals[u], vals[v]);
    // only nodes holding edgeMax can be endpoints of new good paths through this edge
    int fromU = maxCounts[rootU][0] == edgeMax ? maxCounts[rootU][1] : 0;
    int fromV = maxCounts[rootV][0] == edgeMax ? maxCounts[rootV][1] : 0;
    parents[rootU] = rootV;
    maxCounts[rootV][0] = edgeMax;
    maxCounts[rootV][1] = fromU + fromV;
    // combinatorics multiplication principle: one endpoint from each side
    return fromU * fromV;
}
|