diff --git a/docs/deep_dive/tensor_ir/index.rst b/docs/deep_dive/tensor_ir/index.rst index 95a6a3a402cc..2f8bd07c1b0c 100644 --- a/docs/deep_dive/tensor_ir/index.rst +++ b/docs/deep_dive/tensor_ir/index.rst @@ -39,3 +39,5 @@ In TVMScript, both modules are accessed via learning tutorials/tir_creation tutorials/tir_transformation + tutorials/dlight_gpu_scheduling + tutorials/meta_schedule diff --git a/docs/deep_dive/tensor_ir/tutorials/dlight_gpu_scheduling.py b/docs/deep_dive/tensor_ir/tutorials/dlight_gpu_scheduling.py new file mode 100644 index 000000000000..9c5fe1ff4c7c --- /dev/null +++ b/docs/deep_dive/tensor_ir/tutorials/dlight_gpu_scheduling.py @@ -0,0 +1,316 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ruff: noqa: E402, E501 + +""" +.. _dlight_gpu_scheduling: + +DLight: Rule-Based GPU Scheduling +================================== +TIR functions produced by Relax legalization need GPU-specific scheduling — thread binding, +loop tiling, shared memory usage — before they can run efficiently on a GPU. There are two +main approaches in TVM: + +- **MetaSchedule**: explores a search space to find the best schedule. High quality, but + compilation takes minutes to hours. 
+- **DLight**: applies pre-defined scheduling rules deterministically. No tuning required, + compilation completes in seconds. Performance is excellent for well-known patterns + (e.g., GEMM, GEMV in LLM workloads) and fair for the rest. + +This tutorial covers how DLight works, what rules are available, how to diagnose scheduling +quality, and how to write custom rules. + +.. contents:: Table of Contents + :local: + :depth: 1 +""" + +###################################################################### +# Prepare a Model +# --------------- +# We build a small model with ``nn.Module`` that is rich enough to trigger multiple DLight +# rules: ``Linear`` layers produce GEMM (matrix multiplication) kernels, ``LayerNorm`` +# produces a general-reduction kernel, and ``ReLU`` is a simple elementwise op. + +import tvm +from tvm import relax, tirx +from tvm.relax.frontend import nn +from tvm.s_tir import dlight as dl + + +class DemoModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(768, 768) + self.relu = nn.ReLU() + self.norm = nn.LayerNorm(768) + self.fc2 = nn.Linear(768, 256) + + def forward(self, x): + x = self.norm(self.relu(self.fc1(x))) + return self.fc2(x) + + +mod, params = DemoModel().export_tvm({"forward": {"x": nn.spec.Tensor((1, 768), "float32")}}) + +###################################################################### +# Legalize Relax operators into TIR functions so that DLight has concrete kernels to schedule. + +device = tvm.cuda(0) +target = tvm.target.Target.from_device(device) +with target: + mod = relax.get_pipeline("zero")(mod) + +###################################################################### +# At this point every TIR function in ``mod`` is **unscheduled** — it has no thread bindings +# and would not run efficiently on a GPU. 
Let's see what functions we have: +for gv, func in mod.functions_items(): + if isinstance(func, tirx.PrimFunc): + print(f" {gv.name_hint}") + +###################################################################### +# Basic Usage: ApplyDefaultSchedule +# --------------------------------- +# ``ApplyDefaultSchedule`` is an ``IRModule`` pass. It iterates over every TIR function in the +# module and tries the given rules **in order**. For each function the first rule whose +# ``apply()`` returns a non-``None`` schedule wins; subsequent rules are skipped. +# After scheduling, the function is marked with ``tirx.is_scheduled`` so it won't be +# scheduled again by a later ``ApplyDefaultSchedule`` call. + +###################################################################### +# Here we use a common subset of rules. The full catalog (including ``LowBatchGEMV``, +# ``Transpose``, ``RMSNorm``) is listed in the next section. + +with target: + scheduled_mod = dl.ApplyDefaultSchedule( + dl.gpu.Matmul(), # GEMM: dense matrix multiplication + dl.gpu.GEMV(), # matrix-vector products + dl.gpu.Reduction(), # simple reductions (sum, max, ...) + dl.gpu.GeneralReduction(), # compound reductions (softmax, layer norm, ...) + dl.gpu.Fallback(), # catch-all for anything unmatched above + )(mod) + +scheduled_mod.show() + +###################################################################### +# Compared with the unscheduled IR, you can now see thread bindings +# (``blockIdx.x``, ``threadIdx.x``, ...) and loop transformations in each TIR function. + +###################################################################### +# Rule Catalog +# ------------ +# DLight ships a set of GPU scheduling rules. Each rule is a subclass of +# ``ScheduleRule`` and implements an ``apply(func, target, tunable)`` method that returns +# a ``Schedule`` if the rule matches, or ``None`` to pass. +# +# The built-in GPU rules, roughly from most specific to most general: +# +# .. 
list-table:: +# :header-rows: 1 +# :widths: 20 40 40 +# +# * - Rule +# - Pattern +# - Typical operators +# * - ``Matmul`` +# - GEMM index pattern ``C[S,I,J] += A[S,I,K] * B[S,J,K]`` +# - ``nn.Linear``, batched matmul +# * - ``GEMV`` +# - Matrix-vector multiply (one dimension is 1) +# - single-batch decode in attention +# * - ``LowBatchGEMV`` +# - Low-batch GEMM scheduled with a GEMV strategy +# - small-batch decode +# * - ``Reduction`` +# - Simple accumulation ``X[...] += Y[...]`` +# - sum, max, argmax +# * - ``GeneralReduction`` +# - Spatial dims followed by reduction dims (``S* R*``) +# - softmax, layer norm, RMS norm +# * - ``Transpose`` +# - Read/write indices are permutations of each other +# - 2-D transpose +# * - ``RMSNorm`` +# - Contains an ``rsqrt`` operation +# - RMS normalization +# * - ``Fallback`` +# - Any function (always matches) +# - generic catch-all +# +# **Rule order matters.** ``ApplyDefaultSchedule`` stops at the first match, so: +# +# - Put **specialized** rules first (``Matmul``, ``GEMV``) — they have strict matching +# conditions but produce high-quality schedules. +# - Put **general** rules later (``GeneralReduction``, ``Fallback``) — they match broadly +# but produce less optimal schedules. +# - If you put ``Fallback`` first, it will "steal" every function and no specialized +# rule will ever run. + +###################################################################### +# Diagnosing Schedule Quality +# --------------------------- +# A common question is: *which rule scheduled which function?* ``ApplyDefaultSchedule`` +# does not log this directly, but you can figure it out by applying rules one at a time. +# +# **Step 1**: Apply each rule individually and record which functions it claims.
+ +from collections import OrderedDict + +rules = OrderedDict( + [ + ("Matmul", dl.gpu.Matmul()), + ("GEMV", dl.gpu.GEMV()), + ("LowBatchGEMV", dl.gpu.LowBatchGEMV()), + ("Reduction", dl.gpu.Reduction()), + ("GeneralReduction", dl.gpu.GeneralReduction()), + ("Transpose", dl.gpu.Transpose()), + ("RMSNorm", dl.gpu.RMSNorm()), + ] +) + +rule_assignment = {} +for rule_name, rule in rules.items(): + with target: + test_mod = dl.ApplyDefaultSchedule(rule)(mod) + for gv, func in test_mod.functions_items(): + if isinstance(func, tirx.PrimFunc) and gv.name_hint not in rule_assignment: + if "tirx.is_scheduled" in func.attrs and func.attrs["tirx.is_scheduled"] == 1: + rule_assignment[gv.name_hint] = rule_name + +###################################################################### +# **Step 2**: Functions not claimed by any specialized rule will fall through to ``Fallback``. + +all_tir_funcs = [ + gv.name_hint for gv, func in mod.functions_items() if isinstance(func, tirx.PrimFunc) +] +fallback_funcs = [name for name in all_tir_funcs if name not in rule_assignment] + +print("Rule assignments:") +for name, rule_name in sorted(rule_assignment.items()): + print(f" {name:40s} -> {rule_name}") +if fallback_funcs: + print("Handled by Fallback (may have suboptimal performance):") + for name in sorted(fallback_funcs): + print(f" {name}") + +###################################################################### +# If an important kernel lands in the Fallback bucket, you have three options: +# +# 1. Write a **custom DLight rule** for it (see below). +# 2. Use **MetaSchedule** to auto-tune that specific function. +# 3. Manually schedule it with the ``tvm.s_tir.Schedule`` API. + +###################################################################### +# DLight vs MetaSchedule +# ---------------------- +# The two systems are complementary, not competing: +# +# .. 
list-table:: +# :header-rows: 1 +# :widths: 20 40 40 +# +# * - +# - DLight +# - MetaSchedule +# * - Mechanism +# - Deterministic rule matching +# - Search-space exploration +# * - Compile time +# - Seconds +# - Minutes to hours +# * - Performance +# - Excellent on known patterns, fair otherwise +# - Near-optimal with sufficient search budget +# * - Best for +# - Default path, rapid iteration, CI +# - Hot-spot tuning in production +# +# A practical workflow: +# +# 1. Run ``ApplyDefaultSchedule`` with the full rule set to cover all functions. +# 2. Profile the compiled model to identify hot-spot kernels. +# 3. Use ``MetaScheduleTuneTIR`` to auto-tune only those kernels. +# +# Note that ``MetaScheduleTuneTIR`` does **not** automatically skip functions already +# scheduled by DLight — it processes every ``PrimFunc`` in the module. In practice this +# is harmless (tuning an already-scheduled function simply re-explores its space), but if +# you want to avoid the extra search cost, filter the module or use ``MetaScheduleTuneIRMod`` +# with ``op_names`` to target specific functions. + +###################################################################### +# Writing a Custom Rule +# --------------------- +# You can extend DLight by writing your own ``ScheduleRule``. The simplest way is +# ``ScheduleRule.from_callable``, which wraps a plain function into a rule **instance**. + +from tvm import s_tir +from tvm.s_tir.dlight.analysis import normalize_prim_func +from tvm.s_tir.dlight.base.schedule_rule import ScheduleRule + + +@ScheduleRule.from_callable("MyTileAndBind") +def my_tile_and_bind(func: tirx.PrimFunc, target: tvm.target.Target, tunable: bool): + """A minimal rule: for single-block injective functions, tile and bind to GPU threads.""" + if not isinstance(func, tirx.PrimFunc): + return None + sch = s_tir.Schedule(func) + # Use normalize_prim_func to get block info with correct spatial/reduction classification. 
+ # This is the same analysis used by built-in DLight rules. + block_infos = normalize_prim_func(sch) + if block_infos is None or len(block_infos) != 1: + return None # only handle single-block functions + info = block_infos[0] + if not info.is_injective(): + return None # skip reductions — dom_kind() uses iter_type, not loop kind + loops = sch.get_loops(info.block_rv) + if len(loops) == 0: + return None + fused = sch.fuse(*loops) + bx, tx = sch.split(fused, factors=[None, 256]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + return sch + + +###################################################################### +# Insert the custom rule into the rule chain. Note that ``from_callable`` returns an +# **instance**, so pass it directly — do not call ``my_tile_and_bind()`` again. + +with target: + custom_mod = dl.ApplyDefaultSchedule( + dl.gpu.Matmul(), + dl.gpu.GeneralReduction(), + my_tile_and_bind, # our custom rule, tried before Fallback + dl.gpu.Fallback(), + )(mod) + +custom_mod.show() + +###################################################################### +# To build a production-quality rule, subclass ``ScheduleRule`` directly and implement +# ``apply()`` with full analysis logic (see ``tvm.s_tir.dlight.gpu.Matmul`` for an example). + +###################################################################### +# Summary +# ------- +# - **DLight** provides fast, deterministic GPU scheduling via rule matching. +# - Rules are tried in order; the first match wins. Put specialized rules before general ones. +# - Use the **single-rule probing** technique to diagnose which rule handles each function. +# - Combine DLight with MetaSchedule: DLight for baseline coverage, MetaSchedule for hot-spot tuning. +# - Extend DLight by writing custom ``ScheduleRule`` implementations. +# +# For DLight's role in the broader optimization pipeline, see :ref:`customize_opt`. 
diff --git a/docs/deep_dive/tensor_ir/tutorials/meta_schedule.py b/docs/deep_dive/tensor_ir/tutorials/meta_schedule.py new file mode 100644 index 000000000000..a263397bbe2a --- /dev/null +++ b/docs/deep_dive/tensor_ir/tutorials/meta_schedule.py @@ -0,0 +1,307 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ruff: noqa: E402 + +""" +.. _meta_schedule_deep_dive: + +MetaSchedule: Search-Based Auto-Tuning +======================================= +MetaSchedule is TVM's search-based auto-tuning framework, located in +``python/tvm/s_tir/meta_schedule/``. It explores different TIR schedules +(loop tiling, vectorization, thread binding, etc.) and measures them on real +hardware to find the fastest implementation for each operator. + +While **DLight** (see :ref:`dlight_gpu_scheduling`) provides rule-based scheduling with zero +search time, MetaSchedule trades compilation time for better performance by searching over +the space of possible schedules. + +.. 
contents:: Table of Contents + :local: + :depth: 1 +""" + +###################################################################### +# Architecture Overview +# --------------------- +# A MetaSchedule tuning session involves the following components: +# +# - **ExtractedTask**: A unique TIR workload extracted from a Relax IRModule, +# with a ``task_name`` and ``weight`` (call frequency in the graph). +# - **TuneContext**: Container holding all resources for a single tuning task +# (module, target, space generator, search strategy, etc.). +# - **SpaceGenerator** (default: ``PostOrderApply``): Generates the design space +# of possible schedules by applying ``ScheduleRule`` instances to each block. +# - **SearchStrategy** (default: ``EvolutionarySearch``): Explores the design +# space using an evolutionary algorithm guided by a cost model. +# - **CostModel** (default: ``XGBModel``): Predicts schedule performance using +# XGBoost, reducing the number of actual hardware measurements needed. +# Alternatives include ``MLPModel`` (neural network) and ``RandomModel`` +# (baseline). +# - **Builder** / **Runner**: Compile and execute candidates on real hardware to +# obtain measured run times. +# - **Database** (default: ``JSONDatabase``): Persistently stores tuning records +# (schedule traces + measured run times) for later retrieval. +# - **TaskScheduler** (default: ``GradientBasedScheduler``): Allocates tuning +# budget across multiple tasks based on their weights and estimated improvement +# potential. +# +# The tuning loop works as follows: +# +# 1. The **TaskScheduler** picks a task to tune. +# 2. The **SpaceGenerator** produces candidate schedules from the design space. +# 3. The **SearchStrategy** selects candidates (guided by the **CostModel**) and +# sends them to the **Builder** and **Runner** for measurement. +# 4. Measured results are committed to the **Database** and used to update the +# **CostModel** for the next iteration. +# 5.
Repeat until the trial budget is exhausted. + +###################################################################### +# Prepare a Model +# --------------- +# We reuse a simple model to demonstrate MetaSchedule APIs. + +import os +import tempfile + +import tvm +from tvm import relax +from tvm.relax.frontend import nn + + +class DemoModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(784, 256) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(256, 10, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + return x + + +input_shape = (1, 784) +mod, params = DemoModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}}) + +device = tvm.cuda(0) +target = tvm.target.Target.from_device(device) + +###################################################################### +# User-Facing Entry Points +# ------------------------ +# MetaSchedule provides several levels of API, from high-level transforms to +# low-level tuning functions. +# +# Transform-Based API (Recommended) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# These are Relax passes that can be composed into a ``Sequential`` pipeline: +# +# - **MetaScheduleTuneIRMod**: Tunes an entire IRModule. Supports ``op_names`` +# for selective operator tuning. +# - **MetaScheduleTuneTIR**: Tunes all TIR functions individually (no +# ``op_names`` filtering). +# - **MetaScheduleApplyDatabase**: Applies the best schedules from the tuning +# database. Only replaces functions that have records; the rest are left +# unchanged. +# +# Here is a typical tune-and-apply pipeline: +# +# .. note:: +# +# To save CI time and avoid flakiness, we skip the tuning process in CI. 
+ +if os.getenv("CI", "") != "true": + with target, tempfile.TemporaryDirectory() as tmp_dir: + tuned_mod = tvm.ir.transform.Sequential( + [ + relax.get_pipeline("zero"), + relax.transform.MetaScheduleTuneTIR( + work_dir=tmp_dir, + max_trials_global=300, + ), + relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir), + ] + )(mod) + + tuned_mod.show() + +###################################################################### +# Inspecting Tunable Tasks +# ------------------------ +# Before tuning, use ``extract_tasks`` to see what MetaSchedule will tune: + +from tvm.s_tir.meta_schedule.relax_integration import extract_tasks + +with target: + legalized_mod = relax.get_pipeline("zero")(mod) + +tasks = extract_tasks(legalized_mod, target) +for i, task in enumerate(tasks): + print(f"Task {i}: {task.task_name} (weight={task.weight})") + +###################################################################### +# Each ``ExtractedTask`` has: +# +# - ``task_name``: Derived from the PrimFunc name (e.g., ``"fused_matmul_add_relu"``). +# - ``weight``: How many ``call_tir`` sites invoke this workload. The task +# scheduler uses weights to allocate more budget to frequently-called operators. +# - ``dispatched``: List of candidate TIR modules for this workload. + +###################################################################### +# Selective Operator Tuning +# ------------------------- +# ``MetaScheduleTuneIRMod`` accepts an ``op_names`` parameter to tune only +# operators whose task name contains any of the given strings: +# +# .. 
code-block:: python +# +# with target: +# mod = tvm.ir.transform.Sequential([ +# relax.transform.MetaScheduleTuneIRMod( +# params={}, +# work_dir="./tuning_logs", +# max_trials_global=300, +# op_names=["matmul"], # Only tune matmul-related operators +# ), +# relax.transform.MetaScheduleApplyDatabase(work_dir="./tuning_logs"), +# ])(mod) +# +# Operators without tuning records are left unscheduled -- you can apply DLight or +# other rule-based schedules to cover them afterward. +# +# .. note:: +# +# ``MetaScheduleTuneTIR`` does not support ``op_names`` filtering. Use +# ``MetaScheduleTuneIRMod`` when you need selective tuning. + +###################################################################### +# Database +# -------- +# When using a fixed ``work_dir``, tuning results are persisted in two +# newline-delimited JSON files: +# +# - ``database_workload.json``: One line per unique workload (structural hash + +# serialized IRModule). +# - ``database_tuning_record.json``: One line per tuning record (workload index + +# schedule trace + measured run times). +# +# Records are appended incrementally as tuning progresses. +# +# Resumption Semantics +# ~~~~~~~~~~~~~~~~~~~~ +# When you re-run tuning with the same ``work_dir``, existing records are loaded +# and used as warm-start seeds for the evolutionary search. The tuner does +# **not** skip already-seen workloads entirely -- it starts from a better initial +# population, so re-runs are faster than starting from scratch but still consume +# trials. +# +# Once tuning is done, subsequent compilations only need +# ``MetaScheduleApplyDatabase``: +# +# .. code-block:: python +# +# with target: +# mod = relax.transform.MetaScheduleApplyDatabase( +# work_dir="./tuning_logs" +# )(mod) +# +# Database Implementations +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# MetaSchedule ships several database backends: +# +# - **JSONDatabase**: Persistent file-based storage (default). Created +# automatically when you pass ``work_dir``. 
+# - **MemoryDatabase**: In-memory, non-persistent. Useful for testing. +# - **UnionDatabase**: Queries all sub-databases and returns the globally best +# record. +# - **OrderedUnionDatabase**: Queries sub-databases in order; returns from the +# first one that has a match. +# - **ScheduleFnDatabase**: Wraps a user-provided scheduling function. + +###################################################################### +# Cross-Model Database Reuse +# -------------------------- +# MetaSchedule identifies workloads by their structural hash. If two models +# contain operators with the same shape, dtype, and computation, they share the +# same hash and can reuse tuning records. +# +# module_equality Options +# ~~~~~~~~~~~~~~~~~~~~~~~ +# - ``"structural"`` (default): Exact structural match. Safe but strict. +# - ``"anchor-block"``: Match based on the dominant compute block, ignoring +# surrounding context. More permissive -- enables sharing across fused operators +# that have the same core computation but different fusion boundaries. +# +# ``OrderedUnionDatabase`` enables a layered lookup strategy: check a local +# database first, then fall back to a shared team database: +# +# .. code-block:: python +# +# from tvm.s_tir.meta_schedule.database import JSONDatabase, OrderedUnionDatabase +# +# local_db = JSONDatabase(work_dir="./my_tuning_logs") +# shared_db = JSONDatabase(work_dir="/shared/tuning_db") +# combined_db = OrderedUnionDatabase(local_db, shared_db) +# +# with target, combined_db: +# mod = relax.transform.MetaScheduleApplyDatabase()(mod) + +###################################################################### +# Key Parameters Reference +# ------------------------ +# +# .. list-table:: +# :header-rows: 1 +# :widths: 25 75 +# +# * - Parameter +# - Description +# * - ``max_trials_global`` +# - Total trial budget shared across all tasks. Set proportional to the +# number of tasks (e.g., 200-500 trials per task for good results). 
+# * - ``max_trials_per_task`` +# - Per-task trial cap. Defaults to ``max_trials_global`` if not set. +# * - ``op_names`` +# - List of strings to filter tasks by name (substring match). +# ``MetaScheduleTuneIRMod`` only. +# * - ``work_dir`` +# - Directory for database files and logs. Use a fixed path to enable +# persistence and resumption. +# * - ``cost_model`` +# - ``"xgb"`` (XGBoost, default), ``"mlp"`` (neural network), or +# ``"random"`` (baseline). Only available via ``tune_relax``. +# * - ``runner`` +# - ``"local"`` (default) or an ``RPCRunner`` instance for remote devices. +# Only available via ``tune_relax``. +# * - ``module_equality`` +# - ``"structural"`` (default) or ``"anchor-block"`` for more permissive +# cross-model matching. Only available via ``tune_relax``. + +###################################################################### +# Summary +# ------- +# - **MetaSchedule** finds high-quality TIR schedules by searching over the +# design space and measuring on real hardware. +# - Use ``MetaScheduleTuneTIR`` for full-module tuning, or +# ``MetaScheduleTuneIRMod`` with ``op_names`` for selective tuning. +# - Tuning records persist in ``work_dir`` and can be reused across runs and +# models with the same operator shapes. +# - Combine with DLight: use DLight for fast baseline coverage, then MetaSchedule +# for hot-spot tuning (see :ref:`dlight_gpu_scheduling`).