11#pragma once
22#include < vector>
33#include < atomic>
4+ #include < cassert>
45
56template <class T >
67struct DSUMergeRet {
@@ -11,16 +12,44 @@ struct DSUMergeRet {
1112
1213template <class T >
1314class DisjointSetUnion {
14- std::vector<T> parent;
15- std::vector<T> size;
15+ // number of items in the DSU
16+ T n;
17+
18+ // parent and size arrays
19+ T* parent;
20+ T* size;
1621public:
17- DisjointSetUnion (T n) : parent (n), size(n, 1 ) {
22+ DisjointSetUnion (T n) : n (n), parent( new T[n]), size( new T[n] ) {
1823 for (T i = 0 ; i < n; i++) {
1924 parent[i] = i;
25+ size[i] = 1 ;
26+ }
27+ }
28+
29+ ~DisjointSetUnion () {
30+ delete[] parent;
31+ delete[] size;
32+ }
33+
34+ // make a copy of the DSU
35+ DisjointSetUnion (const DisjointSetUnion &oth) : n(oth.n), parent(new T[n]), size(new T[n]) {
36+ for (T i = 0 ; i < n; i++) {
37+ parent[i] = oth.parent [i];
38+ size[i] = oth.size [i];
2039 }
2140 }
2241
42+ // move the DSU to a new object
43+ DisjointSetUnion (DisjointSetUnion &&oth) : n(oth.n), parent(oth.parent), size(oth.size) {
44+ oth.n = 0 ;
45+ oth.parent = nullptr ;
46+ oth.size = nullptr ;
47+ }
48+
49+ DisjointSetUnion operator =(const DisjointSetUnion &oth) = delete ;
50+
2351 inline T find_root (T u) {
52+ assert (0 <= u && u < n);
2453 while (parent[parent[u]] != u) {
2554 parent[u] = parent[parent[u]];
2655 u = parent[u];
@@ -31,6 +60,8 @@ class DisjointSetUnion {
3160 inline DSUMergeRet<T> merge (T u, T v) {
3261 T a = find_root (u);
3362 T b = find_root (v);
63+ assert (0 <= a && a < n);
64+ assert (0 <= b && b < n);
3465 if (a == b) return {false , 0 , 0 };
3566
3667 if (size[a] < size[b]) std::swap (a,b);
@@ -40,7 +71,7 @@ class DisjointSetUnion {
4071 }
4172
4273 inline void reset () {
43- for (T i = 0 ; i < parent. size () ; i++) {
74+ for (T i = 0 ; i < n ; i++) {
4475 parent[i] = i;
4576 size[i] = 1 ;
4677 }
@@ -50,41 +81,70 @@ class DisjointSetUnion {
5081// Disjoint set union that uses atomics to be thread safe
5182// thus is a little slower for single threaded use cases
5283template <class T >
53- class MT_DisjoinSetUnion {
84+ class DisjointSetUnion_MT {
5485private:
55- std::vector<std::atomic<T>> parent;
56- std::vector<T> size;
86+ // number of items in the DSU
87+ T n;
88+
89+ // parent and size arrays
90+ std::atomic<T>* parent;
91+ std::atomic<T>* size;
5792public:
58- MT_DisjoinSetUnion (T n) : parent (n), size(n, 1 ) {
59- for (T i = 0 ; i < n; i++)
93+ DisjointSetUnion_MT (T n) : n (n), parent( new std::atomic<T>[n]), size( new std::atomic<T>[n] ) {
94+ for (T i = 0 ; i < n; i++) {
6095 parent[i] = i;
96+ size[i] = 1 ;
97+ }
98+ }
99+
100+ ~DisjointSetUnion_MT () {
101+ delete[] parent;
102+ delete[] size;
103+ }
104+
105+ // make a copy of the DSU
106+ DisjointSetUnion_MT (const DisjointSetUnion_MT &oth) : n(oth.n), parent(new T[n]), size(new T[n]) {
107+ for (T i = 0 ; i < n; i++) {
108+ parent[i] = oth.parent [i].load ();
109+ size[i] = oth.size [i].load ();
110+ }
111+ }
112+
113+ // move the DSU to a new object
114+ DisjointSetUnion_MT (DisjointSetUnion_MT &&oth) : n(oth.n), parent(oth.parent), size(oth.size) {
115+ oth.n = 0 ;
116+ oth.parent = nullptr ;
117+ oth.size = nullptr ;
61118 }
62119
63120 inline T find_root (T u) {
121+ assert (0 <= u && u < n);
64122 while (parent[parent[u]] != u) {
65- parent[u] = parent[parent[u]];
123+ parent[u] = parent[parent[u]]. load () ;
66124 u = parent[u];
67125 }
68126 return u;
69127 }
70128
71129 // use CAS in this function to allow for simultaneous merge calls
72- inline DSUMergeRet<T> merge (T u, T v) {
73- while ((u = find_root (u)) != (v = find_root (v))) {
74- if (size[u] < size[v])
75- std::swap (u, v);
130+ inline DSUMergeRet<T> merge (T a, T b) {
131+ while ((a = find_root (a)) != (b = find_root (b))) {
132+ assert (0 <= a && a < n);
133+ assert (0 <= b && b < n);
134+ if (size[a] < size[b])
135+ std::swap (a, b);
76136
77137 // if parent of b has not been modified by another thread -> replace with a
78- if (std::atomic_compare_exchange_weak (& parent[u], &v, u )) {
79- size[u ] += size[v ];
80- return {true , u, v };
138+ if (parent[b]. compare_exchange_weak (b, a )) {
139+ size[a ] += size[b ];
140+ return {true , a, b };
81141 }
82142 }
83143 return {false , 0 , 0 };
84144 }
85145
86146 inline void reset () {
87- for (T i = 0 ; i < parent. size () ; i++) {
147+ for (T i = 0 ; i < n ; i++) {
88148 parent[i] = i;
89149 size[i] = 1 ;
90150 }
0 commit comments