Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/space_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Meta Schedule design space generators that generates design
space for generation of measure candidates.
"""
from .space_generator import SpaceGenerator, PySpaceGenerator
from .space_generator_union import SpaceGeneratorUnion
from .schedule_fn import ScheduleFn
from .post_order_apply import PostOrderApply
from .schedule_fn import SCH_FN_TYPE, ScheduleFn
from .space_generator import PySpaceGenerator, SpaceGenerator
from .space_generator_union import SpaceGeneratorUnion
13 changes: 6 additions & 7 deletions python/tvm/meta_schedule/space_generator/schedule_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,17 @@
if TYPE_CHECKING:
from ..tune_context import TuneContext

SCH_FN_TYPE = Union[ # pylint: disable=invalid-name
Callable[[Schedule], None], # No output
Callable[[Schedule], Schedule], # Single output
Callable[[Schedule], List[Schedule]], # Multiple outputs
]


@derived_object
class ScheduleFn(PySpaceGenerator):
"""A design space generator with design spaces specified by a schedule function."""

# Multiple cases of schedule functions supported
SCH_FN_TYPE = Union[
Callable[[Schedule], None], # No output
Callable[[Schedule], Schedule], # Single output
Callable[[Schedule], List[Schedule]], # Multiple outputs
]

def __init__(self, sch_fn: SCH_FN_TYPE):
"""Constructor.

Expand Down
3 changes: 1 addition & 2 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class TuneConfig(NamedTuple):
search_strategy_config: Optional[Dict[str, Any]] = None
logger_config: Optional[Dict[str, Any]] = None

def create_strategy(self, **kwargs):
def create_strategy(self):
"""Create search strategy from configuration"""
cls_tbl = {
"evolutionary": EvolutionarySearch,
Expand All @@ -111,7 +111,6 @@ def create_strategy(self, **kwargs):
return cls_tbl[self.strategy](
num_trials_per_iter=self.num_trials_per_iter,
max_trials_per_task=max_trials_per_task,
**kwargs,
**config,
)

Expand Down
103 changes: 77 additions & 26 deletions python/tvm/meta_schedule/tune_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Meta Schedule tuning context."""

import logging
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from tvm import IRModule
from tvm._ffi import register_object
Expand All @@ -36,7 +36,8 @@
from .runner import RunnerResult
from .schedule_rule import ScheduleRule
from .search_strategy import MeasureCandidate, SearchStrategy
from .space_generator import SpaceGenerator
from .space_generator import SCH_FN_TYPE, ScheduleFn, SpaceGenerator
from .tune import TuneConfig


@register_object("meta_schedule.TuneContext")
Expand All @@ -54,16 +55,24 @@ class TuneContext(Object):
The workload to be optimized.
target : Optional[Target] = None
The target to be optimized for.
space_generator : Optional[SpaceGenerator] = None
space_generator : Union[None, SCH_FN_TYPE, SpaceGenerator] = None
The design space generator.
search_strategy : Optional[SearchStrategy] = None
search_strategy : Union[None, TuneConfig, SearchStrategy] = None
The search strategy.
sch_rules: Optional[List[ScheduleRule]] = None,
if None, the strategy is left blank.
If TuneConfig, the strategy is initialized with the TuneConfig.create_strategy().
sch_rules: Union[None, str, List[ScheduleRule]] = None,
The schedule rules.
postprocs: Optional[List[Postproc"]] = None,
If None, use an empty list of rules.
if "default", use target-default rules.
postprocs: Union[None, str, List[Postproc"]] = None,
The postprocessors.
mutator_probs: Optional[Dict[Mutator, float]]
If None, use an empty list of rules.
if "default", use target-default rules.
mutator_probs: Union[None, str, Dict[Mutator, float]]
Mutators and their probability mass.
If None, use an empty list of rules.
if "default", use target-default rules.
task_name : Optional[str] = None
The name of the tuning task.
logger : logging.Logger
Expand Down Expand Up @@ -99,24 +108,53 @@ def __init__(
mod: Optional[IRModule] = None,
*,
target: Optional[Target] = None,
space_generator: Optional["SpaceGenerator"] = None,
search_strategy: Optional["SearchStrategy"] = None,
sch_rules: Optional[List["ScheduleRule"]] = None,
postprocs: Optional[List["Postproc"]] = None,
mutator_probs: Optional[Dict["Mutator", float]] = None,
space_generator: Union[None, "SCH_FN_TYPE", "ScheduleFn", "SpaceGenerator"] = None,
search_strategy: Union[None, "SearchStrategy", "TuneConfig"] = None,
sch_rules: Union[None, str, List["ScheduleRule"]] = None,
postprocs: Union[None, str, List["Postproc"]] = None,
mutator_probs: Union[None, str, Dict["Mutator", float]] = None,
task_name: str = "main",
logger: Optional[logging.Logger] = None,
rand_state: int = -1,
num_threads: Optional[int] = None,
):
# pylint: disable=import-outside-toplevel
from . import default_config
from .space_generator import ScheduleFn
from .tune import TuneConfig

# pylint: enable=import-outside-toplevel
if isinstance(mod, PrimFunc):
mod = IRModule.from_expr(mod)
if num_threads is None:
num_threads = cpu_count()
if callable(space_generator):
space_generator = ScheduleFn(space_generator)
if isinstance(search_strategy, TuneConfig):
search_strategy = search_strategy.create_strategy()

@Kathryn-cat Kathryn-cat Jun 15, 2022

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess create_strategy hasn't been implemented yet in search strategy?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops sorry! It's good for now.

if isinstance(sch_rules, str):
if sch_rules == "default":
if target is None:
raise ValueError("target is required when sch_rules is 'default'")
sch_rules = default_config.schedule_rules(None, target)
else:
raise ValueError("sch_rules should be a list of ScheduleRule or 'default'")
if isinstance(postprocs, str):
if postprocs == "default":
if target is None:
raise ValueError("target is required when postprocs is 'default'")
postprocs = default_config.postproc(None, target)
else:
raise ValueError("postprocs should be a list of Postproc or 'default'")
if isinstance(mutator_probs, str):
if mutator_probs == "default":
if target is None:
raise ValueError("target is required when mutator_probs is 'default'")
mutator_probs = default_config.mutator_probs(None, target)
if logger is None:
self.logger = logging.getLogger(__name__)
else:
self.logger = None
if num_threads is None:
num_threads = cpu_count()
self.__init_handle_by_constructor__(
_ffi_api.TuneContext, # type: ignore # pylint: disable=no-member
mod,
Expand All @@ -131,9 +169,6 @@ def __init__(
rand_state,
num_threads,
)

def initialize(self):
"""Initialize the tuning context"""
_ffi_api.TuneContextInitialize(self) # type: ignore # pylint: disable=no-member

def generate_design_space(self) -> List[Schedule]:
Expand All @@ -157,7 +192,7 @@ def generate_design_space(self) -> List[Schedule]:

def pre_tuning(
self,
design_spaces: List[Schedule],
design_spaces: Optional[List[Schedule]] = None,
database: Optional["Database"] = None,
cost_model: Optional["CostModel"] = None,
) -> None:
Expand All @@ -167,18 +202,38 @@ def pre_tuning(

Parameters
----------
design_spaces : List[Schedule]
design_spaces : Optional[List[Schedule]]
The design spaces used during tuning process.
If None, use the outcome of `self.generate_design_space()`.
database : Optional[Database] = None
The database used during tuning process.
If None, and the search strategy is `EvolutionarySearch`,
then use `tvm.meta_schedule.database.MemoryDatabase`.
cost_model : Optional[CostModel] = None
The cost model used during tuning process.
If None, and the search strategy is `EvolutionarySearch`,
then use `tvm.meta_schedule.cost_model.RandomModel`.
"""
# pylint: disable=import-outside-toplevel
from .cost_model import RandomModel
from .database import MemoryDatabase
from .search_strategy import EvolutionarySearch

# pylint: enable=import-outside-toplevel

if self.search_strategy is None:
raise ValueError(
"search_strategy is not provided."
"Please construct TuneContext with search_strategy"
)
if design_spaces is None:
design_spaces = self.generate_design_space()
if database is None:
if isinstance(self.search_strategy, EvolutionarySearch):
database = MemoryDatabase() # type: ignore
if cost_model is None:
if isinstance(self.search_strategy, EvolutionarySearch):
cost_model = RandomModel() # type: ignore
return self.search_strategy.pre_tuning(design_spaces, database, cost_model)
Comment thread
junrushao marked this conversation as resolved.

def post_tuning(self) -> None:
Expand All @@ -191,7 +246,7 @@ def post_tuning(self) -> None:
"search_strategy is not provided."
"Please construct TuneContext with search_strategy"
)
_ffi_api.SearchStrategyPostTuning(self) # type: ignore # pylint: disable=no-member
return self.search_strategy.post_tuning()

def generate_measure_candidates(self) -> Optional[List["MeasureCandidate"]]:
"""Generate a batch of measure candidates from design spaces for measurement.
Expand All @@ -208,7 +263,7 @@ def generate_measure_candidates(self) -> Optional[List["MeasureCandidate"]]:
"search_strategy is not provided."
"Please construct TuneContext with search_strategy"
)
return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # type: ignore # pylint: disable=no-member
return self.search_strategy.generate_measure_candidates()

def notify_runner_results(
self,
Expand All @@ -231,8 +286,4 @@ def notify_runner_results(
"search_strategy is not provided."
"Please construct TuneContext with search_strategy"
)
_ffi_api.SearchStrategyNotifyRunnerResults( # type: ignore # pylint: disable=no-member
self,
measure_candidates,
results,
)
return self.search_strategy.notify_runner_results(measure_candidates, results)
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def test_conv2d_winograd_cpu():
target,
),
)
context.initialize()
post_order_apply = context.space_generator
(sch,) = post_order_apply.generate_design_space(mod)
decisions = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,6 @@ def test_conv2d_winograd_cuda():
None, Target("cuda")
),
)
context.initialize()
post_order_apply = context.space_generator
(sch,) = post_order_apply.generate_design_space(mod)
decisions = dict(
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def traverse(t):
mod,
target="llvm",
params=params,
filter_func=filter_func,
te_filter_func=filter_func,
)
expected_task_names = [
"fused_" + s
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def _make_mutator(target: Target) -> Mutator:
MutateComputeLocation(): 1.0,
},
)
ctx.initialize()
return list(ctx.mutator_probs.keys())[0]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator:
MutateParallel(max_jobs_per_core): 1.0,
},
)
ctx.initialize()
return list(ctx.mutator_probs.keys())[0]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def _make_mutator(target: Target) -> Mutator:
MutateThreadBinding(): 1.0,
},
)
ctx.initialize()
return list(ctx.mutator_probs.keys())[0]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def _make_mutator(target: Target) -> Mutator:
target=target,
mutator_probs={MutateTileSize(): 1.0},
)
ctx.initialize()
return list(ctx.mutator_probs.keys())[0]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def _make_mutator(target: Target) -> Mutator:
MutateUnroll(): 1.0,
},
)
ctx.initialize()
return list(ctx.mutator_probs.keys())[0]


Expand Down
6 changes: 0 additions & 6 deletions tests/python/unittest/test_meta_schedule_post_order_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ def test_meta_schedule_post_order_apply():
space_generator=PostOrderApply(),
sch_rules=[WowSoFancyScheduleRule()],
)
context.initialize()
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 1
Expand All @@ -240,7 +239,6 @@ def test_meta_schedule_post_order_apply_double():
space_generator=PostOrderApply(),
sch_rules=[DoubleScheduleRule()],
)
context.initialize()
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 2
Expand All @@ -258,7 +256,6 @@ def test_meta_schedule_post_order_apply_multiple():
space_generator=PostOrderApply(),
sch_rules=[DoubleScheduleRule(), ReorderScheduleRule()],
)
context.initialize()
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 4
Expand All @@ -276,7 +273,6 @@ def test_meta_schedule_post_order_apply_duplicate_matmul():
space_generator=PostOrderApply(),
sch_rules=[WowSoFancyScheduleRule()],
)
context.initialize()
post_order_apply = context.space_generator
with pytest.raises(
TVMError,
Expand Down Expand Up @@ -348,7 +344,6 @@ def correct_trace(a, b, c, d):
space_generator=PostOrderApply(),
sch_rules=[RemoveBlock(), TrinityDouble()],
)
context.initialize()
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 4
Expand Down Expand Up @@ -376,7 +371,6 @@ def test_meta_schedule_custom_search_space():
space_generator=PostOrderApply(),
sch_rules=[],
)
context.initialize()
post_order_apply = context.space_generator
post_order_apply.generate_design_space(mod)
called = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def _create_context(mod, target) -> TuneContext:
],
task_name="test",
)
ctx.initialize()
return ctx


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def _create_context(mod, target) -> TuneContext:
],
task_name="test",
)
ctx.initialize()
return ctx


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def _create_context(mod, target) -> TuneContext:
],
task_name="test",
)
ctx.initialize()
return ctx


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,6 @@ def _create_context(mod, target, postprocs):
postprocs=postprocs,
task_name="test",
)
ctx.initialize()
return ctx


Expand Down
Loading