Union-find, sometimes called disjoint-set union, is a data structure that stores equivalence classes quickly and compactly. It has a bunch of uses:
I think about it mostly in from a compilers perspective because that’s all I do all day, every day: rewrite those graphs.
It’s especially interesting to me because the core of the simple implementation is so small:
# Implementation courtesy Phil Zucker
uf = {}
def find(x): # Find the set representative
while x in uf:
x = uf[x]
return x
def union(x, y): # Join two sets
x = find(x)
y = find(y)
if x is not y:
uf[x] = y
This makes it so easy to drop this into any existing or new project. No library, just 10 lines of code. Instant find-and-replace when optimizing your compiler IR.
This implementation does not do path compression or union-by-rank, the features that get peak performance, but adding those features in can be done incrementally and without changing the API.
Another neat thing is that this implementation does not specify what types x
and y
, the elements, are. They may be integers, the usual type, but they
could also be any other type that can be hashed and compared. I think most
people end up using a dense representation—an array—with indices, though.
Or maybe the inline embedded pointers approach (see below).
You should always refer to the set representative when using union-find. This
means you have to treat find
kind of like a read barrier in a garbage
collector: using a stale pointer in an operation may be undefined behavior.
The pointer doesn’t go bad, exactly—holding onto it is fine—you just need
to call find
before doing things to it.
This allows other pieces of your infrastructure—say, a type inference pass—to only store information for and only update the set representative.
…todo about setting representative to self
…todo
…todo
def find(x):
# The same walk as before
result = x
while result in uf:
result = uf[result]
# Walk the chain again, updating each node to point directly to the end
while x in uf:
current = uf[x]
uf[x] = result
x = current
return result
…todo
…todo
For this demonstration of union-find embedded into your existing data structures, we will use a barebones imaginary compiler IR. This may look familiar; it is a mashup of Phil Zucker’s above union-find and CF Bolz-Tereick’s Toy Optimizer union-find.
class Node:
forwarded: Node|None = None
def find(self):
result = self
while result.forwarded is not None:
result = result.forwarded
return result
def union(self, other):
self = self.find()
other = other.find()
if self is not other:
self.forwarded = other
class Add(Node):
...
Fil Pizlo approach; take no extra space
class Node:
def find(self):
result = self
while isinstance(result, Identity):
result = result.forwarded
return result
def union(self, other):
self = self.find()
other = other.find()
if self is not other:
self.__class__ = Identity
self.forwarded = other
class Identity(Node):
forwarded: Node|None = None
class Add(Node):
...
This also works without classes; set opcode
field or something. Just requires
all rewritable nodes be at least as big as Identity
node, but they probably
are due to having other fields, or alignment, or something.