Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 50 additions & 17 deletions include/dsu.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#pragma once
#include <vector>
#include <atomic>
#include <cassert>
#include <vector>

template<class T>
template <class T>
struct DSUMergeRet {
bool merged; // true if a merge actually occured
T root; // new root
T child; // new child of root
bool merged; // true if a merge actually occured
T root; // new root
T child; // new child of root
};

template <class T>
Expand All @@ -18,7 +18,24 @@ class DisjointSetUnion {
// parent and size arrays
T* parent;
T* size;
public:

// Order based on size and break ties with
// a simple fixed ordering of the edge
// if sum even, smaller first
// if sum odd, larger first
inline void order_edge(T& a, T& b) {
if (size[a] < size[b])
std::swap(a, b);
else if (size[a] == size[b]) {
if ((a + b) % 2 == 0 && a > b) {
std::swap(a, b);
} else if ((a + b) % 2 == 1 && a < b) {
std::swap(a, b);
}
}
}

public:
DisjointSetUnion(T n) : n(n), parent(new T[n]), size(new T[n]) {
for (T i = 0; i < n; i++) {
parent[i] = i;
Expand All @@ -32,25 +49,25 @@ class DisjointSetUnion {
}

// make a copy of the DSU
DisjointSetUnion(const DisjointSetUnion &oth) : n(oth.n), parent(new T[n]), size(new T[n]) {
DisjointSetUnion(const DisjointSetUnion& oth) : n(oth.n), parent(new T[n]), size(new T[n]) {
for (T i = 0; i < n; i++) {
parent[i] = oth.parent[i];
size[i] = oth.size[i];
}
}

// move the DSU to a new object
DisjointSetUnion(DisjointSetUnion &&oth) : n(oth.n), parent(oth.parent), size(oth.size) {
DisjointSetUnion(DisjointSetUnion&& oth) : n(oth.n), parent(oth.parent), size(oth.size) {
oth.n = 0;
oth.parent = nullptr;
oth.size = nullptr;
}

DisjointSetUnion operator=(const DisjointSetUnion &oth) = delete;
DisjointSetUnion operator=(const DisjointSetUnion& oth) = delete;

inline T find_root(T u) {
assert(0 <= u && u < n);
while(parent[parent[u]] != u) {
while (parent[parent[u]] != u) {
parent[u] = parent[parent[u]];
u = parent[u];
}
Expand All @@ -64,7 +81,7 @@ class DisjointSetUnion {
assert(0 <= b && b < n);
if (a == b) return {false, 0, 0};

if (size[a] < size[b]) std::swap(a,b);
order_edge(a, b);
parent[b] = a;
size[a] += size[b];
return {true, a, b};
Expand All @@ -82,14 +99,31 @@ class DisjointSetUnion {
// thus is a little slower for single threaded use cases
template <class T>
class DisjointSetUnion_MT {
private:
private:
// number of items in the DSU
T n;

// parent and size arrays
std::atomic<T>* parent;
std::atomic<T>* size;
public:

// Order based on size and break ties with
// a simple fixed ordering of the edge
// if sum even, smaller first
// if sum odd, larger first
inline void order_edge(T& a, T& b) {
if (size[a] < size[b])
std::swap(a, b);
else if (size[a] == size[b]) {
if ((a + b) % 2 == 0 && a > b) {
std::swap(a, b);
} else if ((a + b) % 2 == 1 && a < b) {
std::swap(a, b);
}
}
}

public:
DisjointSetUnion_MT(T n) : n(n), parent(new std::atomic<T>[n]), size(new std::atomic<T>[n]) {
for (T i = 0; i < n; i++) {
parent[i] = i;
Expand All @@ -103,15 +137,15 @@ class DisjointSetUnion_MT {
}

// make a copy of the DSU
DisjointSetUnion_MT(const DisjointSetUnion_MT &oth) : n(oth.n), parent(new T[n]), size(new T[n]) {
DisjointSetUnion_MT(const DisjointSetUnion_MT& oth) : n(oth.n), parent(new T[n]), size(new T[n]) {
for (T i = 0; i < n; i++) {
parent[i] = oth.parent[i].load();
size[i] = oth.size[i].load();
}
}

// move the DSU to a new object
DisjointSetUnion_MT(DisjointSetUnion_MT &&oth) : n(oth.n), parent(oth.parent), size(oth.size) {
DisjointSetUnion_MT(DisjointSetUnion_MT&& oth) : n(oth.n), parent(oth.parent), size(oth.size) {
oth.n = 0;
oth.parent = nullptr;
oth.size = nullptr;
Expand All @@ -131,8 +165,7 @@ class DisjointSetUnion_MT {
while ((a = find_root(a)) != (b = find_root(b))) {
assert(0 <= a && a < n);
assert(0 <= b && b < n);
if (size[a] < size[b])
std::swap(a, b);
order_edge(a, b);

// if parent of b has not been modified by another thread -> replace with a
if (parent[b].compare_exchange_weak(b, a)) {
Expand Down