diff --git a/include/dsu.h b/include/dsu.h index d59c0b12..ef38d451 100644 --- a/include/dsu.h +++ b/include/dsu.h @@ -1,13 +1,13 @@ #pragma once -#include #include #include +#include -template +template struct DSUMergeRet { - bool merged; // true if a merge actually occured - T root; // new root - T child; // new child of root + bool merged; // true if a merge actually occured + T root; // new root + T child; // new child of root }; template @@ -18,7 +18,24 @@ class DisjointSetUnion { // parent and size arrays T* parent; T* size; -public: + + // Order based on size and break ties with + // a simple fixed ordering of the edge + // if sum even, smaller first + // if sum odd, larger first + inline void order_edge(T& a, T& b) { + if (size[a] < size[b]) + std::swap(a, b); + else if (size[a] == size[b]) { + if ((a + b) % 2 == 0 && a > b) { + std::swap(a, b); + } else if ((a + b) % 2 == 1 && a < b) { + std::swap(a, b); + } + } + } + + public: DisjointSetUnion(T n) : n(n), parent(new T[n]), size(new T[n]) { for (T i = 0; i < n; i++) { parent[i] = i; @@ -32,7 +49,7 @@ class DisjointSetUnion { } // make a copy of the DSU - DisjointSetUnion(const DisjointSetUnion &oth) : n(oth.n), parent(new T[n]), size(new T[n]) { + DisjointSetUnion(const DisjointSetUnion& oth) : n(oth.n), parent(new T[n]), size(new T[n]) { for (T i = 0; i < n; i++) { parent[i] = oth.parent[i]; size[i] = oth.size[i]; @@ -40,17 +57,17 @@ class DisjointSetUnion { } // move the DSU to a new object - DisjointSetUnion(DisjointSetUnion &&oth) : n(oth.n), parent(oth.parent), size(oth.size) { + DisjointSetUnion(DisjointSetUnion&& oth) : n(oth.n), parent(oth.parent), size(oth.size) { oth.n = 0; oth.parent = nullptr; oth.size = nullptr; } - DisjointSetUnion operator=(const DisjointSetUnion &oth) = delete; + DisjointSetUnion operator=(const DisjointSetUnion& oth) = delete; inline T find_root(T u) { assert(0 <= u && u < n); - while(parent[parent[u]] != u) { + while (parent[parent[u]] != u) { parent[u] = parent[parent[u]]; u = parent[u]; } @@ -64,7 +81,7 @@ class DisjointSetUnion { assert(0 <= b && b < n); if (a == b) return {false, 0, 0}; - if (size[a] < size[b]) std::swap(a,b); + order_edge(a, b); parent[b] = a; size[a] += size[b]; return {true, a, b}; @@ -82,14 +99,31 @@ class DisjointSetUnion { // thus is a little slower for single threaded use cases template class DisjointSetUnion_MT { -private: + private: // number of items in the DSU T n; // parent and size arrays std::atomic* parent; std::atomic* size; -public: + + // Order based on size and break ties with + // a simple fixed ordering of the edge + // if sum even, smaller first + // if sum odd, larger first + inline void order_edge(T& a, T& b) { + if (size[a] < size[b]) + std::swap(a, b); + else if (size[a] == size[b]) { + if ((a + b) % 2 == 0 && a > b) { + std::swap(a, b); + } else if ((a + b) % 2 == 1 && a < b) { + std::swap(a, b); + } + } + } + + public: DisjointSetUnion_MT(T n) : n(n), parent(new std::atomic[n]), size(new std::atomic[n]) { for (T i = 0; i < n; i++) { parent[i] = i; @@ -103,7 +137,7 @@ class DisjointSetUnion_MT { } // make a copy of the DSU - DisjointSetUnion_MT(const DisjointSetUnion_MT &oth) : n(oth.n), parent(new T[n]), size(new T[n]) { + DisjointSetUnion_MT(const DisjointSetUnion_MT& oth) : n(oth.n), parent(new T[n]), size(new T[n]) { for (T i = 0; i < n; i++) { parent[i] = oth.parent[i].load(); size[i] = oth.size[i].load(); @@ -111,7 +145,7 @@ class DisjointSetUnion_MT { } // move the DSU to a new object - DisjointSetUnion_MT(DisjointSetUnion_MT &&oth) : n(oth.n), parent(oth.parent), size(oth.size) { + DisjointSetUnion_MT(DisjointSetUnion_MT&& oth) : n(oth.n), parent(oth.parent), size(oth.size) { oth.n = 0; oth.parent = nullptr; oth.size = nullptr; @@ -131,8 +165,7 @@ class DisjointSetUnion_MT { while ((a = find_root(a)) != (b = find_root(b))) { assert(0 <= a && a < n); assert(0 <= b && b < n); - if (size[a] < size[b]) - std::swap(a, b); + order_edge(a, b); // if parent of b has not been modified by another thread -> replace with a if (parent[b].compare_exchange_weak(b, a)) {