diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index df896cb690eb..b2776a41c50c 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -609,6 +609,8 @@ void PatternGrouper::VisitExprs() { } void PatternGrouper::CreateGroup(const Expr& expr) { + VLOG(1) << "Creating group for:" << std::endl << PrettyPrint(expr); + int var_number = 0; auto node_map = matcher_->GetMemo(); @@ -696,6 +698,7 @@ void PatternGrouper::CreateGroup(const Expr& expr) { auto body = extractor.Mutate(expr); group.function = Function(params, body, NullValue(), Array()); + VLOG(1) << "Candidate extracted function:" << std::endl << PrettyPrint(group.function); group.name = extractor.GetName(); // Check to make sure we aren't overlapping with another group or creating an invalid fusion // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the @@ -708,6 +711,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) { // Similiarly, if interior nodes in a group are used outside of the group fusing to a single // output would create an invalid graph tranformation, so we block the creation of such groups. auto memo = extractor.GetMemo(); + for (auto kv : memo) { + VLOG(1) << "matched index " << matcher_->expr_to_node(kv.first)->index_; + } + for (auto kv : memo) { // Check to ensure that this node isn't an input or a global if (inputs.count(kv.first) == 0 && kv.first.as() == nullptr && @@ -720,16 +727,19 @@ void PatternGrouper::CreateGroup(const Expr& expr) { // if the node isn't the output of the group auto node = matcher_->expr_to_node(kv.first); for (auto* output : node->outputs_) { - // and the node is used by nodes outside of the group if (memo.count(output->ref()) == 0) { - // TODO(mbs): This condition used to also include the following test, which since - // the dominators relation is used back-to-front was always vacuously true. So the - // code is just rejecting the match if a strictly internal node happened to connect - // to an outside node. - ICHECK(!matcher_->expr_to_node(expr)->Dominates(output)); - // Exit because nodes in this pattern's body are used outside the pattern, fusing it - // would be invalid - return; + // A node inside the matched group contributes an output to nodes outside of the matched + // group... + auto root = matcher_->expr_to_node(expr); + if (!root->Dominates(output)) { + // ...and the outside dataflow does not come back to the root of the matched group. + // So reject the match since it would create a cycle. + VLOG(1) << "Rejecting group since would create a cycle with output " << output->index_ + << " for root " << root->index_ << " in graph:" << std::endl + << matcher_->expr_graph().ToString(); + return; + } + // else: We'll allow the output to be included in the matched group. } } } diff --git a/src/relay/ir/dataflow_matcher_impl.h b/src/relay/ir/dataflow_matcher_impl.h index f04190f72e40..a174d8e34eb7 100644 --- a/src/relay/ir/dataflow_matcher_impl.h +++ b/src/relay/ir/dataflow_matcher_impl.h @@ -55,6 +55,7 @@ class DFPatternMatcher : public DFPatternFunctor, ObjectPtrHash, ObjectPtrEqual>& memo() const { return memo_; } + const IndexedGraph& expr_graph() const { return *expr_graph_; } protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index c10597940221..8e5238b17399 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -941,4 +941,4 @@ def test_conv2d_bwd(): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index f0474c911273..ba066e9a438f 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -1458,7 +1458,6 @@ def concat(*args): def test_partition_fuzzy_function_args(): - func_pattern = FunctionPattern(None, wildcard() + wildcard())(None) + wildcard() x = relay.var("x") y = relay.var("y") @@ -1790,5 +1789,56 @@ def callback(self, pre, post, node_map): assert tvm.ir.structural_equal(out, expected) +def test_matched_outside_but_dominated(): + """In this example the pattern matches the nn.conv2d/add/multiply flow. Even though the + add output is consumed by the sigmoid, the sigmoid itself is dominated by the multiply. + So partitioning can proceed, all be it with a duplication of the add.""" + in_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%data: Tensor[(16, 16, 32, 32), float16], %weight: Tensor[(32, 16, 3, 3), float16], %bias: Tensor[(32), float32]) -> Tensor[(16, 32, 32, 32), float32] { + %0 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC"); + %1 = layout_transform(%weight, src_layout="OIHW", dst_layout="OHWI"); + %2 = expand_dims(%bias, axis=1, num_newaxis=2); + %3 = expand_dims(%2, axis=0); + %4 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32"); + %5 = layout_transform(%3, src_layout="NCHW", dst_layout="NHWC"); + %6 = add(%4, %5); + %7 = sigmoid(%6); + %8 = multiply(%6, %7); + layout_transform(%8, src_layout="NHWC", dst_layout="NCHW") + } + """ + ) + expected_mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%data: Tensor[(16, 16, 32, 32), float16], %weight: Tensor[(32, 16, 3, 3), float16], %bias: Tensor[(32), float32]) -> Tensor[(16, 32, 32, 32), float32] { + %2 = expand_dims(%bias, axis=1, num_newaxis=2); + %3 = expand_dims(%2, axis=0); + %4 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC"); + %5 = layout_transform(%weight, src_layout="OIHW", dst_layout="OHWI"); + %6 = nn.conv2d(%4, %5, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32"); + %7 = layout_transform(%3, src_layout="NCHW", dst_layout="NHWC"); + %8 = add(%6, %7); + %9 = sigmoid(%8); + %10 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, %FunctionVar_0_3, PartitionedFromPattern="nn.conv2d_add_multiply_") { + %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32"); + %1 = add(%0, %FunctionVar_0_2); + multiply(%1, %FunctionVar_0_3) + }; + %11 = %10(%4, %5, %7, %9); + layout_transform(%11, src_layout="NHWC", dst_layout="NCHW") + } + """ + ) + pattern = is_op("multiply")( + is_op("add")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard()), wildcard() + ) + actual_mod = tvm.IRModule.from_expr(pattern.partition(in_mod["main"])) + actual_mod = relay.transform.InferType()(actual_mod) + tvm.ir.assert_structural_equal(actual_mod, expected_mod) + + if __name__ == "__main__": tvm.testing.main()