Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 6 additions & 16 deletions include/tvm/s_tir/meta_schedule/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,6 @@ class PyCostModelNode : public CostModelNode {
*/
using FPredict = ffi::TypedFunction<void(const TuneContext&, const ffi::Array<MeasureCandidate>&,
void* p_addr)>;
/*!
* \brief Get the cost model as string with name.
* \return The string representation of the cost model.
*/
using FAsString = ffi::TypedFunction<ffi::String()>;

/*! \brief The packed function to the `Load` function. */
FLoad f_load;
/*! \brief The packed function to the `Save` function. */
Expand All @@ -122,8 +116,6 @@ class PyCostModelNode : public CostModelNode {
FUpdate f_update;
/*! \brief The packed function to the `Predict` function. */
FPredict f_predict;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void Load(const ffi::String& path);
void Save(const ffi::String& path);
Expand All @@ -142,19 +134,17 @@ class PyCostModelNode : public CostModelNode {
class CostModel : public ffi::ObjectRef {
public:
/*!
* \brief Create a feature extractor with customized methods on the python-side.
* \brief Create a cost model with customized methods on the python-side.
* \param f_load The packed function of `Load`.
* \param f_save The packed function of `Save`.
* \param f_update The packed function of `Update`.
* \param f_predict The packed function of `Predict`.
* \param f_as_string The packed function of `AsString`.
* \return The feature extractor created.
* \return The cost model created.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a copy-paste typo in the docstring for CostModel::PyCostModel. It refers to a "feature extractor" instead of a "cost model". Please update it to refer to "cost model". Note that the first line of this docstring (which is outside the diff hunk) also contains "Create a feature extractor..." and should be updated to "Create a cost model..." as well.

   * \return The cost model created.

*/
TVM_DLL static CostModel PyCostModel(PyCostModelNode::FLoad f_load, //
PyCostModelNode::FSave f_save, //
PyCostModelNode::FUpdate f_update, //
PyCostModelNode::FPredict f_predict, //
PyCostModelNode::FAsString f_as_string);
TVM_DLL static CostModel PyCostModel(PyCostModelNode::FLoad f_load, //
PyCostModelNode::FSave f_save, //
PyCostModelNode::FUpdate f_update, //
PyCostModelNode::FPredict f_predict);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CostModel, ffi::ObjectRef, CostModelNode);
};

Expand Down
13 changes: 1 addition & 12 deletions include/tvm/s_tir/meta_schedule/feature_extractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,11 @@ class PyFeatureExtractorNode : public FeatureExtractorNode {
*/
using FExtractFrom = ffi::TypedFunction<ffi::Array<tvm::runtime::Tensor>(
const TuneContext& context, const ffi::Array<MeasureCandidate>& candidates)>;
/*!
* \brief Get the feature extractor as string with name.
* \return The string of the feature extractor.
*/
using FAsString = ffi::TypedFunction<ffi::String()>;

/*! \brief The packed function to the `ExtractFrom` function. */
FExtractFrom f_extract_from;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

static void RegisterReflection() {
// `f_extract_from` is not registered
// `f_as_string` is not registered
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PyFeatureExtractorNode>();
}
Expand Down Expand Up @@ -114,12 +105,10 @@ class FeatureExtractor : public ffi::ObjectRef {
/*!
* \brief Create a feature extractor with customized methods on the python-side.
* \param f_extract_from The packed function of `ExtractFrom`.
* \param f_as_string The packed function of `AsString`.
* \return The feature extractor created.
*/
TVM_DLL static FeatureExtractor PyFeatureExtractor(
PyFeatureExtractorNode::FExtractFrom f_extract_from,
PyFeatureExtractorNode::FAsString f_as_string);
PyFeatureExtractorNode::FExtractFrom f_extract_from);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FeatureExtractor, ffi::ObjectRef,
FeatureExtractorNode);
};
Expand Down
13 changes: 1 addition & 12 deletions include/tvm/s_tir/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,11 @@ class PyMeasureCallbackNode : public MeasureCallbackNode {
const ffi::Array<MeasureCandidate>& measure_candidates, //
const ffi::Array<BuilderResult>& builds, //
const ffi::Array<RunnerResult>& results)>;
/*!
* \brief Get the measure callback function as string with name.
* \return The string of the measure callback function.
*/
using FAsString = ffi::TypedFunction<ffi::String()>;

/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

static void RegisterReflection() {
// `f_apply` is not registered
// `f_as_string` is not registered
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PyMeasureCallbackNode>();
}
Expand Down Expand Up @@ -135,11 +126,9 @@ class MeasureCallback : public ffi::ObjectRef {
/*!
* \brief Create a measure callback with customized methods on the python-side.
* \param f_apply The packed function of `Apply`.
* \param f_as_string The packed function of `AsString`.
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply,
PyMeasureCallbackNode::FAsString f_as_string);
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply);
/*! \brief The default list of measure callbacks. */
TVM_DLL static ffi::Array<MeasureCallback, void> Default();
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MeasureCallback, ffi::ObjectRef, MeasureCallbackNode);
Expand Down
12 changes: 1 addition & 11 deletions include/tvm/s_tir/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ class Mutator : public ffi::ObjectRef {
* \return The cloned mutator.
*/
using FClone = ffi::TypedFunction<Mutator()>;
/*!
* \brief Get the mutator as string with name.
* \return The string of the mutator.
*/
using FAsString = ffi::TypedFunction<ffi::String()>;
/*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */
TVM_DLL static Mutator MutateTileSize();
/*!
Expand Down Expand Up @@ -128,11 +123,10 @@ class Mutator : public ffi::ObjectRef {
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The mutator created.
*/
TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context,
FApply f_apply, FClone f_clone, FAsString f_as_string);
FApply f_apply, FClone f_clone);
/*! \brief Create default mutators for LLVM */
TVM_DLL static ffi::Map<Mutator, FloatImm, void> DefaultLLVM();
/*! \brief Create default mutators for CUDA */
Expand All @@ -151,23 +145,19 @@ class PyMutatorNode : public MutatorNode {
using FInitializeWithTuneContext = Mutator::FInitializeWithTuneContext;
using FApply = Mutator::FApply;
using FClone = Mutator::FClone;
using FAsString = Mutator::FAsString;
/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `Clone` function. */
FClone f_clone;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PyMutatorNode>();
// `f_initialize_with_tune_context` is not registered
// `f_apply` is not registered
// `f_clone` is not registered
// `f_as_string` is not registered
}

void InitializeWithTuneContext(const TuneContext& context) final;
Expand Down
13 changes: 1 addition & 12 deletions include/tvm/s_tir/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,23 +91,16 @@ class Postproc : public ffi::ObjectRef {
* \return The cloned postprocessor.
*/
using FClone = ffi::TypedFunction<Postproc()>;
/*!
* \brief Get the postprocessor function as string with name.
* \return The string of the postprocessor function.
*/
using FAsString = ffi::TypedFunction<ffi::String()>;
/*!
* \brief Create a postprocessor with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The postprocessor created.
*/
TVM_DLL static Postproc PyPostproc(FInitializeWithTuneContext f_initialize_with_tune_context, //
FApply f_apply, //
FClone f_clone, //
FAsString f_as_string);
FClone f_clone);
/*!
* \brief Create a postprocessor that checks if all loops are static
* \return The postprocessor created
Expand Down Expand Up @@ -186,21 +179,17 @@ class PyPostprocNode : public PostprocNode {
using FInitializeWithTuneContext = Postproc::FInitializeWithTuneContext;
using FApply = Postproc::FApply;
using FClone = Postproc::FClone;
using FAsString = Postproc::FAsString;
/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `Clone` function. */
FClone f_clone;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

static void RegisterReflection() {
// `f_initialize_with_tune_context` is not registered
// `f_apply` is not registered
// `f_clone` is not registered
// `f_as_string` is not registered
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PyPostprocNode>();
}
Expand Down
13 changes: 1 addition & 12 deletions include/tvm/s_tir/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,6 @@ class ScheduleRule : public ffi::ObjectRef {
*/
using FApply = ffi::TypedFunction<ffi::Array<s_tir::Schedule>(const s_tir::Schedule&,
const s_tir::SBlockRV&)>;
/*!
* \brief Get the schedule rule as string with name.
* \return The string of the schedule rule.
*/
using FAsString = ffi::TypedFunction<ffi::String()>;
/*!
* \brief The function type of `Clone` method.
* \return The cloned schedule rule.
Expand Down Expand Up @@ -290,14 +285,12 @@ class ScheduleRule : public ffi::ObjectRef {
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The schedule rule created.
*/
TVM_DLL static ScheduleRule PyScheduleRule(
FInitializeWithTuneContext f_initialize_with_tune_context, //
FApply f_apply, //
FClone f_clone, //
FAsString f_as_string);
FClone f_clone);

/*! \brief Create default schedule rules for LLVM */
TVM_DLL static ffi::Array<ScheduleRule, void> DefaultLLVM();
Expand All @@ -323,21 +316,17 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
using FInitializeWithTuneContext = ScheduleRule::FInitializeWithTuneContext;
using FApply = ScheduleRule::FApply;
using FClone = ScheduleRule::FClone;
using FAsString = ScheduleRule::FAsString;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;
/*! \brief The packed function to the `Clone` function. */
FClone f_clone;

static void RegisterReflection() {
// `f_initialize_with_tune_context` is not registered
// `f_apply` is not registered
// `f_as_string` is not registered
// `f_clone` is not registered
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PyScheduleRuleNode>();
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def method(*args, **kwargs):
# extract functions that differ from the base class
if not hasattr(base_cls, name):
continue
if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__":
if getattr(base_cls, name) is getattr(inherit_cls, name):
continue
return method

# for task scheduler return None means calling default function
# otherwise it will trigger a TVMError of method not implemented
# on the c++ side when you call the method, __str__ not required
# on the c++ side when you call the method
return None

assert isinstance(cls.__base__, type)
Expand Down
15 changes: 1 addition & 14 deletions python/tvm/s_tir/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
from ..utils import _get_default_str


@register_object("s_tir.meta_schedule.CostModel")
Expand Down Expand Up @@ -169,7 +168,6 @@ def __init__(
f_save: Callable | None = None,
f_update: Callable | None = None,
predict_func: Callable | None = None,
f_as_string: Callable | None = None,
):
"""Constructor."""

Expand All @@ -189,7 +187,6 @@ def f_predict(context: TuneContext, candidates: list[MeasureCandidate], return_p
f_save,
f_update,
f_predict,
f_as_string,
)


Expand All @@ -203,7 +200,7 @@ class PyCostModel:

_tvm_metadata = {
"cls": _PyCostModel,
"methods": ["load", "save", "update", "predict", "__str__"],
"methods": ["load", "save", "update", "predict"],
}

def load(self, path: str) -> None:
Expand Down Expand Up @@ -261,13 +258,3 @@ def predict(self, context: TuneContext, candidates: list[MeasureCandidate]) -> n
The predicted normalized score.
"""
raise NotImplementedError

def __str__(self) -> str:
"""Get the cost model as string with name.

Return
------
result : str
Get the cost model as string with name.
"""
return _get_default_str(self)
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from .. import _ffi_api
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
from ..utils import _get_default_str


@register_object("s_tir.meta_schedule.FeatureExtractor")
Expand Down Expand Up @@ -87,13 +86,12 @@ class _PyFeatureExtractor(FeatureExtractor):
See also: PyFeatureExtractor
"""

def __init__(self, f_extract_from: Callable, f_as_string: Callable | None = None):
def __init__(self, f_extract_from: Callable):
"""Constructor."""

self.__init_handle_by_constructor__(
_ffi_api.FeatureExtractorPyFeatureExtractor, # type: ignore # pylint: disable=no-member
f_extract_from,
f_as_string,
)


Expand All @@ -107,7 +105,7 @@ class PyFeatureExtractor:

_tvm_metadata = {
"cls": _PyFeatureExtractor,
"methods": ["extract_from", "__str__"],
"methods": ["extract_from"],
}

def extract_from(
Expand All @@ -128,6 +126,3 @@ def extract_from(
The feature tvm ndarray extracted.
"""
raise NotImplementedError

def __str__(self) -> str:
return _get_default_str(self)
Loading
Loading