diff --git a/include/l0_sampling/sketch.h b/include/l0_sampling/sketch.h index fe5cc863..81bc5701 100644 --- a/include/l0_sampling/sketch.h +++ b/include/l0_sampling/sketch.h @@ -9,6 +9,7 @@ #include #include #include +#include #include "../types.h" #include "../util.h" @@ -130,7 +131,7 @@ class Sketch { * Function to query all columns within a sketch to return 1 or more non-zero indices * @return A pair with the result indices and a code indicating the type of result. */ - std::pair, SampleSketchRet> exhaustive_query(); + std::pair, SampleSketchRet> exhaustive_query(); inline uint64_t get_seed() const { return seed; } inline size_t column_seed(size_t column_idx) const { return seed + column_idx*5; } diff --git a/include/supernode.h b/include/supernode.h index 788bf4ac..800ee49c 100644 --- a/include/supernode.h +++ b/include/supernode.h @@ -152,7 +152,7 @@ class Supernode { * if one exists. Additionally, returns a code represnting the sample * result (good, zero, or fail) */ - std::pair, SampleSketchRet> exhaustive_sample(); + std::pair, SampleSketchRet> exhaustive_sample(); /** * In-place merge function. Guaranteed to update the caller Supernode. diff --git a/include/types.h b/include/types.h index e65f572d..dc7cddf3 100644 --- a/include/types.h +++ b/include/types.h @@ -24,7 +24,19 @@ struct Edge { return dst < oth.dst; return src < oth.src; } + bool operator== (const Edge&oth) const { + return src == oth.src && dst == oth.dst; + } }; +namespace std { + template <> + struct hash { + auto operator()(const Edge&edge) const -> size_t { + std::hash h; + return h(edge.dst) + (31 * h(edge.src)); + } + }; +} struct GraphUpdate { Edge edge; diff --git a/src/l0_sampling/sketch.cpp b/src/l0_sampling/sketch.cpp index 407f4009..e1c74d14 100644 --- a/src/l0_sampling/sketch.cpp +++ b/src/l0_sampling/sketch.cpp @@ -115,17 +115,17 @@ std::pair Sketch::query() { return {0, FAIL}; } -std::pair, SampleSketchRet> Sketch::exhaustive_query() { +std::pair, SampleSketchRet> Sketch::exhaustive_query() { unlikely_if (already_queried) throw MultipleQueryException(); - std::vector ret; + std::unordered_set ret; unlikely_if (bucket_a[num_elems - 1] == 0 && bucket_c[num_elems - 1] == 0) return {ret, ZERO}; // the "first" bucket is deterministic so if zero then no edges to return unlikely_if ( Bucket_Boruvka::is_good(bucket_a[num_elems - 1], bucket_c[num_elems - 1], checksum_seed())) { - ret.push_back(bucket_a[num_elems - 1]); + ret.insert(bucket_a[num_elems - 1]); return {ret, GOOD}; } for (unsigned i = 0; i < num_columns; ++i) { @@ -133,8 +133,7 @@ std::pair, SampleSketchRet> Sketch::exhaustive_query() { unsigned bucket_id = i * num_guesses + j; unlikely_if ( Bucket_Boruvka::is_good(bucket_a[bucket_id], bucket_c[bucket_id], checksum_seed())) { - ret.push_back(bucket_a[bucket_id]); - update(bucket_a[bucket_id]); + ret.insert(bucket_a[bucket_id]); } } } diff --git a/src/supernode.cpp b/src/supernode.cpp index d9f45057..13e96fb4 100644 --- a/src/supernode.cpp +++ b/src/supernode.cpp @@ -93,13 +93,14 @@ std::pair Supernode::sample() { return {inv_concat_pairing_fn(non_zero), ret_code}; } -std::pair, SampleSketchRet> Supernode::exhaustive_sample() { +std::pair, SampleSketchRet> Supernode::exhaustive_sample() { if (out_of_queries()) throw OutOfQueriesException(); - std::pair, SampleSketchRet> query_ret = get_sketch(sample_idx++)->exhaustive_query(); - std::vector edges(query_ret.first.size()); - for (size_t i = 0; i < edges.size(); i++) - edges[i] = inv_concat_pairing_fn(query_ret.first[i]); + std::pair, SampleSketchRet> query_ret = get_sketch(sample_idx++)->exhaustive_query(); + std::unordered_set edges(query_ret.first.size()); + for (const auto &query_item: query_ret.first) { + edges.insert(inv_concat_pairing_fn(query_item)); + } SampleSketchRet ret_code = query_ret.second; return {edges, ret_code}; diff --git a/test/sketch_test.cpp b/test/sketch_test.cpp index 13de1e0f..f7b6fdab 100644 --- a/test/sketch_test.cpp +++ b/test/sketch_test.cpp @@ -358,7 +358,7 @@ TEST(SketchTestSuite, TestExhaustiveQuery) { sketch->update(9); sketch->update(10); - std::pair, SampleSketchRet> query_ret = sketch->exhaustive_query(); + std::pair, SampleSketchRet> query_ret = sketch->exhaustive_query(); if (query_ret.second != GOOD) { ASSERT_EQ(query_ret.first.size(), 0) << query_ret.second; } diff --git a/test/supernode_test.cpp b/test/supernode_test.cpp index 1bdc25ce..35959ff6 100644 --- a/test/supernode_test.cpp +++ b/test/supernode_test.cpp @@ -288,7 +288,7 @@ TEST_F(SupernodeTestSuite, ExhaustiveSample) { // do 4 samples for (size_t i = 0; i < 4; i++) { - std::pair, SampleSketchRet> query_ret = s_node->exhaustive_sample(); + std::pair, SampleSketchRet> query_ret = s_node->exhaustive_sample(); if (query_ret.second != GOOD) { ASSERT_EQ(query_ret.first.size(), 0); }