Union Find

Use Case | Description | Example
Connected Components | Determine if two nodes are in the same connected component in a graph. | In a social network, check if two users are in the same group.
Network Connectivity | Monitor the connectivity of a network as connections are added; plain Union-Find cannot undo a union, so handling removed connections requires rebuilding the structure. | In a computer network, determine if two devices can communicate after a new connection is made.
Kruskal's Minimum Spanning Tree | Use Union-Find to detect cycles while constructing a minimum spanning tree. | In a graph of cities, find the minimum cost to connect all cities without any cycles.

Code Template

Initializing variables

  1. self.parent = [i for i in range(size)]: Creates a parent array where each element points to itself, indicating that each element is its own set initially.
  2. self.rank = [1] * size: Initializes a rank array, where each element starts with a rank of 1, used to optimize union operations by keeping track of tree depth.

Find function

  1. This method is designed to find the root or representative of the set containing element p.
  2. if self.parent[p] != p: This line checks if the element p is its own parent. If not, it means p is not the root of its set.
  3. self.parent[p] = self.find(self.parent[p]): If p is not the root, this line recursively calls find on p's parent. The result is then assigned back to self.parent[p], effectively flattening the structure (path compression). This means that all nodes in the path from p to its root will now point directly to the root, speeding up future queries.

Union function

  1. rootP = self.find(p) and rootQ = self.find(q): Uses the find method to determine the roots of the sets containing p and q.
  2. The root with the higher rank becomes the parent, maintaining a balanced tree structure. If ranks are equal, one root becomes the parent, and its rank is incremented.

Check connection

  1. Compares the roots of p and q. If the roots are the same, it means p and q are connected.
class UnionFind:
    """Disjoint-set structure with path compression and union by rank.

    Elements are the integers ``0 .. size - 1``. ``parent[i]`` is the parent
    of ``i`` in its tree (a root is its own parent) and ``rank[r]`` is the
    rank stored for root ``r``.
    """

    def __init__(self, size):
        """Create ``size`` singleton sets, each element being its own root."""
        self.parent = [i for i in range(size)]
        self.rank = [1] * size

    def find(self, p):
        """Return the root of the set containing ``p``.

        Every element visited on the way up is re-pointed directly at the
        root (path compression).
        """
        if self.parent[p] != p:
            self.parent[p] = self.find(self.parent[p])
        return self.parent[p]

    def union(self, p, q):
        """Merge the sets containing ``p`` and ``q``; no-op if already merged."""
        rootP = self.find(p)
        rootQ = self.find(q)
        if rootP == rootQ:
            return
        if self.rank[rootP] > self.rank[rootQ]:
            self.parent[rootQ] = rootP
        elif self.rank[rootP] < self.rank[rootQ]:
            self.parent[rootP] = rootQ
        else:
            # Equal ranks: pick rootP as the parent and bump only its rank.
            self.parent[rootQ] = rootP
            self.rank[rootP] += 1

    def connected(self, p, q):
        """Return True if ``p`` and ``q`` have the same root."""
        return self.find(p) == self.find(q)

Common mistakes

  1. Wrong Path Compression Implementation
# Wrong: no path compression
def find(self, p):
    # Follows parent links until it reaches an element that is its own
    # parent. Nothing is written back to self.parent, so every later lookup
    # repeats the same walk.
    while self.parent[p] != p:  # Just traverses up
        p = self.parent[p]
    return p
# Wrong: incorrect recursion
def find(self, p):
    # Recurses up to the root and returns it, but the result is never stored
    # back into self.parent, so the tree stays exactly as deep as before.
    if self.parent[p] != p:
        return self.find(self.parent[p])  # Doesn't compress
    return p
# Correct:
def find(self, p):
    """Return the root of the set containing p, compressing the path."""
    if self.parent[p] != p:
        # Store the recursive result so p points straight at the root.
        self.parent[p] = self.find(self.parent[p])
    return self.parent[p]
  2. Wrong Union by Rank Implementation
# Wrong: not using ranks
def union(self, p, q):
    rootP = self.find(p)
    rootQ = self.find(q)
    # self.rank is never consulted: rootQ is attached under rootP regardless
    # of the ranks stored for the two roots.
    self.parent[rootQ] = rootP  # Always merges to rootP
# Wrong: incorrect rank update
def union(self, p, q):
    rootP = self.find(p)
    rootQ = self.find(q)
    if self.rank[rootP] == self.rank[rootQ]:
        self.parent[rootQ] = rootP
        # Only rootP remains a root after the link, so only its rank should
        # grow; bumping rootQ's rank as well is the mistake shown here.
        self.rank[rootP] += 1
        self.rank[rootQ] += 1  # Should only increment one
# Correct:
def union(self, p, q):
    """Merge the sets containing p and q using union by rank."""
    rootP = self.find(p)
    rootQ = self.find(q)
    if rootP == rootQ:
        return  # already in the same set; nothing to link
    if self.rank[rootP] > self.rank[rootQ]:
        self.parent[rootQ] = rootP
    elif self.rank[rootP] < self.rank[rootQ]:
        self.parent[rootP] = rootQ
    else:
        # Equal ranks: rootP becomes the parent and only its rank grows.
        self.parent[rootQ] = rootP
        self.rank[rootP] += 1
  3. Wrong Connected Check
# Wrong: not using find
def connected(self, p, q):
    # Compares only the stored parent entries of p and q; the links are not
    # followed up to the roots, so two elements of one set whose immediate
    # parents differ are reported as not connected.
    return self.parent[p] == self.parent[q]  # Should use find()
# Wrong: redundant finds
def connected(self, p, q):
    rootP = self.find(p)
    rootQ = self.find(q)
    # rootP and rootQ are already results of find(); the two extra find()
    # calls below are wasted work, assuming find() returns a root.
    return self.find(rootP) == self.find(rootQ)  # Unnecessary finds
# Correct:
def connected(self, p, q):
    """Return True if p and q have the same root, i.e. share a set."""
    return self.find(p) == self.find(q)

Issues and considerations

  • Thread Safety: If the class is intended for concurrent use, consider implementing locking mechanisms to prevent race conditions.

200. Number of Islands

class UnionFind:
    """Disjoint-set structure with path compression and union by rank.

    Elements are the integers ``0 .. size - 1``. ``parent[i]`` is the parent
    of ``i`` in its tree (a root is its own parent) and ``rank[r]`` is the
    rank stored for root ``r``.
    """

    def __init__(self, size):
        """Create ``size`` singleton sets, each element being its own root."""
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, p):
        """Return the root of the set containing ``p``.

        Every element visited on the way up is re-pointed directly at the
        root (path compression).
        """
        if self.parent[p] != p:
            self.parent[p] = self.find(self.parent[p])
        return self.parent[p]

    def union(self, p, q):
        """Merge the sets containing ``p`` and ``q``; no-op if already merged."""
        rootP = self.find(p)
        rootQ = self.find(q)
        if rootP != rootQ:
            if self.rank[rootP] > self.rank[rootQ]:
                self.parent[rootQ] = rootP
            elif self.rank[rootP] < self.rank[rootQ]:
                self.parent[rootP] = rootQ
            else:
                # Equal ranks only: without this `else`, the lines below
                # would also run after the `elif` and link the two roots
                # to each other, forming a cycle.
                self.parent[rootQ] = rootP
                self.rank[rootP] += 1

    def connected(self, p, q):
        """Return True if ``p`` and ``q`` have the same root."""
        return self.find(p) == self.find(q)
class Solution:
    def numIslands(self, grid: List[List[str]]) -> int:
        """Count groups of "1" cells joined horizontally or vertically.

        Every "1" cell is first counted as its own island. Whenever such a
        cell is merged with a right or down neighbour that was still in a
        different set, the running total drops by one.
        """
        if not grid:
            return 0

        height, width = len(grid), len(grid[0])
        # One slot per cell, addressed in row-major order.
        uf = UnionFind(height * width)
        islands = 0

        for r in range(height):
            for c in range(width):
                if grid[r][c] != "1":
                    continue
                islands += 1
                cell = r * width + c
                # Right and down only: the left/up pairs were examined when
                # those earlier cells were visited. Neighbour coordinates
                # never go negative, so only the upper bounds need checking.
                for nr, nc in ((r, c + 1), (r + 1, c)):
                    if nr >= height or nc >= width or grid[nr][nc] != "1":
                        continue
                    neighbour = nr * width + nc
                    if not uf.connected(cell, neighbour):
                        uf.union(cell, neighbour)
                        islands -= 1

        return islands

547. Number of Provinces

class UnionFind:
    """Disjoint-set structure with path compression and union by rank.

    Elements are the integers ``0 .. size - 1``. ``parent[i]`` is the parent
    of ``i`` in its tree (a root is its own parent) and ``rank[r]`` is the
    rank stored for root ``r``.
    """

    def __init__(self, size):
        """Create ``size`` singleton sets, each element being its own root."""
        self.parent = list(range(size))
        self.rank = [1] * size

    def find(self, p):
        """Return the root of the set containing ``p``.

        Every element visited on the way up is re-pointed directly at the
        root (path compression).
        """
        if self.parent[p] != p:
            self.parent[p] = self.find(self.parent[p])
        return self.parent[p]

    def union(self, p, q):
        """Merge the sets containing ``p`` and ``q``; no-op if already merged."""
        rootP = self.find(p)
        rootQ = self.find(q)
        if rootP != rootQ:
            if self.rank[rootP] > self.rank[rootQ]:
                self.parent[rootQ] = rootP
            elif self.rank[rootP] < self.rank[rootQ]:
                self.parent[rootP] = rootQ
            else:
                # Equal ranks only: without this `else`, the lines below
                # would also run after the `elif` and link the two roots
                # to each other, forming a cycle.
                self.parent[rootQ] = rootP
                self.rank[rootP] += 1

    def connected(self, p, q):
        """Return True if ``p`` and ``q`` have the same root."""
        return self.find(p) == self.find(q)
class Solution:
    def findCircleNum(self, isConnected: List[List[int]]) -> int:
        """Return the number of provinces (connected groups of cities).

        Args:
            isConnected: square adjacency matrix; ``isConnected[i][j] == 1``
                marks a direct connection between cities ``i`` and ``j``.
                Only entries above the diagonal are read.

        Returns:
            The count of distinct sets left after all direct connections
            have been merged; 0 for an empty matrix.
        """
        if not isConnected:
            return 0

        n = len(isConnected)
        uf = UnionFind(n)

        # Connect cities based on the upper triangle of the matrix.
        for i in range(n):
            for j in range(i + 1, n):
                if isConnected[i][j] == 1:
                    uf.union(i, j)

        # Each distinct root identifies exactly one province.
        provinces = {uf.find(i) for i in range(n)}
        return len(provinces)
