Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/relay/collage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@
MEASURE_REPEAT,
WARMUP_MIN_REPEAT_MS,
CostEstimator,
MockEstimator,
MockCostEstimator,
)
8 changes: 4 additions & 4 deletions python/tvm/relay/collage/collage.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.CostEstimator)


@register_object("relay.collage.MockEstimator")
class MockEstimator(Object):
@register_object("relay.collage.MockCostEstimator")
class MockCostEstimator(Object):
"""MockEstimator class"""

def __init__(self, target_costs):
self.__init_handle_by_constructor__(_ffi_api.MockEstimator, target_costs)
def __init__(self, target_costs, max_estimates=0):
self.__init_handle_by_constructor__(_ffi_api.MockCostEstimator, target_costs, max_estimates)


def arg_for(arg_type, device):
Expand Down
6 changes: 2 additions & 4 deletions src/relay/collage/candidate_partition_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "./candidate_partition_index.h"

#include "./gather_partition_specs.h"
#include "./prune_candidates.h"
#include "./utils.h"

namespace tvm {
Expand All @@ -40,10 +41,7 @@ CandidatePartitionIndex::CandidatePartitionIndex(

void CandidatePartitionIndex::Index(const Array<PartitionSpec>& partition_specs) {
std::vector<CandidatePartition> candidates = Collect(partition_specs);

// (The candidates could be pruned at this point to elliminate those which are heuristically
// unlikely to appear in the optimal partitioning.)

candidates = PruneCandidates(*dataflow_graph_, candidates);
// Index the candidates by their first inside index.
for (auto& candidate : candidates) {
first_inside_index_to_candidates_[candidate->sub_graph_->first_inside_index_].emplace_back(
Expand Down
75 changes: 2 additions & 73 deletions src/relay/collage/cost_estimator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,21 @@

#include "./cost_estimator.h"

#include <math.h>
#include <tvm/relay/expr_functor.h>
#include <cmath>

namespace tvm {
namespace relay {
namespace collage {

TVM_REGISTER_OBJECT_TYPE(CostEstimatorNode);
TVM_REGISTER_OBJECT_TYPE(MockEstimatorNode);

CostEstimator::CostEstimator() {
auto node = make_object<CostEstimatorNode>();
data_ = std::move(node);
}

Cost CostEstimatorNode::Estimate(const IRModule& mod, const Target& target) const {
// TODO(mbs): Eventually should be abstract. For now bounce to the Python local impl.
static const runtime::PackedFunc* estimate_seconds =
runtime::Registry::Get("tvm.relay.collage.estimate_seconds");
ICHECK(estimate_seconds);
Expand All @@ -53,78 +52,8 @@ Cost CostEstimatorNode::Estimate(const IRModule& mod, const Target& target) cons
}
}

/*!
* \brief Visitor to accumulate the costs of all calls to operators in an expression.
*/
class MockEstimationVisitor : private ExprVisitor {
public:
MockEstimationVisitor(double op_cost, double fusion_benefit)
: op_cost_(op_cost), fusion_benefit_(fusion_benefit) {}

double EstimateCost(const Expr& body) {
this->VisitExpr(body);
return cost_;
}

private:
/*! \brief The assumed baseline cost of each operator call. */
double op_cost_;
/*!
* \brief The factor by which each operator call cost is to be changed for every other
* operator call in the same group.
*/
double fusion_benefit_;
/*! \brief The number of operator calls seen so far. */
size_t num_ops_ = 0;
/*! \brief Accumulate overall cost. */
double cost_ = 0.0;

void VisitExpr_(const CallNode* call_node) final {
if (call_node->op->IsInstance<OpNode>()) {
cost_ += op_cost_ * pow(fusion_benefit_, num_ops_);
num_ops_++;
}
ExprVisitor::VisitExpr_(call_node);
}

void VisitExpr_(const FunctionNode* function_node) final {
// No "Compiler" functions can be inlined.
ICHECK(!function_node->GetAttr<String>(attr::kCompiler).defined());
ExprVisitor::VisitExpr_(function_node);
}
};

Cost MockEstimatorNode::Estimate(const IRModule& mod, const Target& target) const {
double op_cost = static_cast<double>(target_costs_.at(target->kind->name)->value);
double cost = 0.0;
for (const auto& kv : mod->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
auto function = GetRef<Function>(function_node);
if (kv.first->name_hint == "main") {
// Only tensor args are allowed to main.
for (const auto& param : function->params) {
ICHECK(param->type_annotation->IsInstance<TensorTypeNode>());
}
}
cost += MockEstimationVisitor(op_cost, /*fusion_benefit=*/0.9).EstimateCost(function->body);
}
}
return Cost::Value(cost);
}

MockEstimator::MockEstimator(Map<String, Integer> target_costs) {
auto node = make_object<MockEstimatorNode>();
node->target_costs_ = std::move(target_costs);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("relay.collage.CostEstimator").set_body_typed([]() { return CostEstimator(); });

TVM_REGISTER_GLOBAL("relay.collage.MockEstimator")
.set_body_typed([](Map<String, Integer> target_costs) {
return MockEstimator(std::move(target_costs));
});

} // namespace collage
} // namespace relay
} // namespace tvm
33 changes: 0 additions & 33 deletions src/relay/collage/cost_estimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,39 +64,6 @@ class CostEstimator : public ObjectRef {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CostEstimator, ObjectRef, CostEstimatorNode);
};

/*!
* \brief A mock cost estimator which can determine the cost of a candidate based on both
* the candidate's target and the number of operator calls inside it.
*
* The estimator also ICHECKs the given module has all "Compiler" functions outlined and @main
* takes only tensor arguments (ie no tuple types).
*
* To support testing only.
*/
class MockEstimatorNode : public CostEstimatorNode {
public:
Cost Estimate(const IRModule& mod, const Target& target) const override;

static constexpr const char* _type_key = "relay.collage.MockEstimator";
TVM_DECLARE_FINAL_OBJECT_INFO(MockEstimatorNode, CostEstimatorNode);

protected:
friend class MockEstimator;

/*!
* \brief Map from target kind name to assumed baseline cost (in integer seconds) for all
* operator calls.
*/
Map<String, Integer> target_costs_;
};

class MockEstimator : public CostEstimator {
public:
explicit MockEstimator(Map<String, Integer> target_costs);

TVM_DEFINE_OBJECT_REF_METHODS(MockEstimator, CostEstimator, MockEstimatorNode);
};

} // namespace collage
} // namespace relay
} // namespace tvm
Expand Down
121 changes: 121 additions & 0 deletions src/relay/collage/mock_cost_estimator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/collage/mock_cost_estimator.cc
* \brief A mock CostEstimator to support unit tests.
*/

#include "./mock_cost_estimator.h"

#include <tvm/relay/expr_functor.h>

namespace tvm {
namespace relay {
namespace collage {

TVM_REGISTER_OBJECT_TYPE(MockCostEstimatorNode);

namespace {

/*!
* \brief Visitor to accumulate the costs of all calls to operators in an expression.
*/
class MockEstimationVisitor : private ExprVisitor {
public:
MockEstimationVisitor(double op_cost, double fusion_benefit)
: op_cost_(op_cost), fusion_benefit_(fusion_benefit) {}

double EstimateCost(const Expr& body) {
VisitExpr(body);
return cost_;
}

private:
/*! \brief The assumed baseline cost of each operator call. */
double op_cost_;
/*!
* \brief The factor by which each operator call cost is to be changed for every other
* operator call in the same group.
*/
double fusion_benefit_;
/*! \brief The number of operator calls seen so far. */
size_t num_ops_ = 0;
/*! \brief Accumulate overall cost. */
double cost_ = 0.0;

void VisitExpr_(const CallNode* call_node) final {
if (call_node->op->IsInstance<OpNode>()) {
// Account for number of ops seens os far.
cost_ += op_cost_ * pow(fusion_benefit_, static_cast<double>(num_ops_));
num_ops_++;
}
ExprVisitor::VisitExpr_(call_node);
}

void VisitExpr_(const FunctionNode* function_node) final {
// No "Compiler" functions can be inlined.
ICHECK(!function_node->GetAttr<String>(attr::kCompiler).defined())
<< "All Compiler functions should have been outlined when preparing to estimate costs";
ExprVisitor::VisitExpr_(function_node);
}
};

} // namespace

Cost MockCostEstimatorNode::Estimate(const IRModule& mod, const Target& target) const {
// Limit the number of estimations.
ICHECK(max_estimates_->value == 0 || num_estimates_ < static_cast<size_t>(max_estimates_->value))
<< "At most " << max_estimates_->value
<< " non-trivial distinct candidates should have been generated.";
++num_estimates_;
double op_cost = static_cast<double>(target_costs_.at(target->kind->name)->value);
double cost = 0.0;
for (const auto& kv : mod->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
auto function = GetRef<Function>(function_node);
if (kv.first->name_hint == "main") {
// Only tensor args are allowed to main.
for (const auto& param : function->params) {
ICHECK(param->type_annotation->IsInstance<TensorTypeNode>())
<< "Any tuple-of-tensor arguments should have been eta-exanded when preparing to "
"estimate costs";
}
}
cost += MockEstimationVisitor(op_cost, /*fusion_benefit=*/0.9).EstimateCost(function->body);
}
}
return Cost::Value(cost);
}

MockCostEstimator::MockCostEstimator(Map<String, Integer> target_costs, Integer max_estimates) {
auto node = make_object<MockCostEstimatorNode>();
node->target_costs_ = std::move(target_costs);
node->max_estimates_ = std::move(max_estimates);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("relay.collage.MockCostEstimator")
.set_body_typed([](Map<String, Integer> target_costs, Integer max_estimates) {
return MockCostEstimator(std::move(target_costs), std::move(max_estimates));
});

} // namespace collage
} // namespace relay
} // namespace tvm
84 changes: 84 additions & 0 deletions src/relay/collage/mock_cost_estimator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/collage/mock_cost_estimator.cc
* \brief A mock CostEstimator to support unit tests.
*/

#ifndef TVM_RELAY_COLLAGE_MOCK_COST_ESTIMATOR_H_
#define TVM_RELAY_COLLAGE_MOCK_COST_ESTIMATOR_H_

#include <tvm/relay/function.h>

#include "./cost.h"
#include "./cost_estimator.h"

namespace tvm {
namespace relay {
namespace collage {

/*!
* \brief A mock cost estimator which can determine the cost of a candidate based on both
* the candidate's target and the number of operator calls inside it.
*
* The help unit tests the estimator also ICHECK fails if:
* - the module has inlined "Compiler" functions
* - @main has non-tensor arguments (eg a tuple)
* - more than the given number of candidate modules are measured
*
* To support unit testing only.
*/
class MockCostEstimatorNode : public CostEstimatorNode {
public:
Cost Estimate(const IRModule& mod, const Target& target) const override;

static constexpr const char* _type_key = "relay.collage.MockCostEstimator";
TVM_DECLARE_FINAL_OBJECT_INFO(MockCostEstimatorNode, CostEstimatorNode);

protected:
/*!
* \brief Map from target kind name to assumed baseline cost (in integer seconds) for all
* operator calls.
*/
Map<String, Integer> target_costs_;

/*!
* \brief If non-zero, the maximum number of distinct modules which may be estimated.
*/
Integer max_estimates_;

/*! \brief Number of calls to Estimate. */
mutable size_t num_estimates_ = 0;

friend class MockCostEstimator;
};

class MockCostEstimator : public CostEstimator {
public:
explicit MockCostEstimator(Map<String, Integer> target_costs, Integer max_estimates = 0);

TVM_DEFINE_OBJECT_REF_METHODS(MockCostEstimator, CostEstimator, MockCostEstimatorNode);
};

} // namespace collage
} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_COLLAGE_MOCK_COST_ESTIMATOR_H_
Loading