From b25bd0b9d9c50863dc7598ab1a7d3db4e5f30e9c Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 2 Jul 2023 21:52:10 -0700 Subject: [PATCH] [Dlight] Enhance Decode-GEMV Rules This PR enhances the Decode-GEMV rule with the following changes: - Normalize the GEMV iter domain to S-R-C via transform-block-layout. This helps with further analysis and scheduling, for example in cases where there was no spatial loop in the original reduction block. - Get rid of the ad hoc iter type analysis, including the logic calling into a TVM packed func `tir.schedule.GetLoopIterType` using `tvm._ffi.get_global_func`. - Split out the logic for two separate cases of scheduling, where the innermost dimension is spatial or reduction. - Introduce `suggest_threads_per_block` to guess the number of threads to be allocated to each thread block. This helps avoid the previous case where dlight allocates 256 threads for a workload whose degree of parallelism is only 128. - Misc improvements. The rest of the changes are split out into separate PRs that are already merged to main. - [x] Pass hints to the arithmetic analyzer that shape variables should be positive (#15210) - [x] Eliminate unnecessary block predicate generation - should be provable via affine analysis (#15193) - [x] Shrink local memory allocation if only one element `X[threadIdx.x]` is used (#15207) --- pyproject.toml | 4 + python/tvm/dlight/base/analysis.py | 3 +- python/tvm/dlight/gpu/__init__.py | 4 +- python/tvm/dlight/gpu/decode_gemv.py | 257 ++++++++++++-------- python/tvm/dlight/gpu/fallback.py | 5 +- python/tvm/dlight/gpu/matmul.py | 4 +- python/tvm/dlight/gpu/utils.py | 87 +++++++ tests/python/dlight/test_gpu_decode_gemv.py | 53 ++-- 8 files changed, 275 insertions(+), 142 deletions(-) create mode 100644 python/tvm/dlight/gpu/utils.py diff --git a/pyproject.toml b/pyproject.toml index 5cca711ddbe6..e984b41b11a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,10 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +[tool.isort] +profile = "black" +src_paths = ["python", "tests/python"] + [tool.black] line-length = 100 diff --git a/python/tvm/dlight/base/analysis.py b/python/tvm/dlight/base/analysis.py index d11e29a8ad59..2607968ef27b 100644 --- a/python/tvm/dlight/base/analysis.py +++ b/python/tvm/dlight/base/analysis.py @@ -17,13 +17,12 @@ """Analysis on TIR blocks, loops and functions.""" from typing import List, Optional, Union -from typing_extensions import Literal - from tvm import tir from tvm._ffi import get_global_func from tvm.target.target import Target from tvm.tir import Schedule from tvm.tir.schedule import BlockRV +from typing_extensions import Literal class IterInfo: diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py index 79090d400b42..934928ffafc9 100644 --- a/python/tvm/dlight/gpu/__init__.py +++ b/python/tvm/dlight/gpu/__init__.py @@ -18,7 +18,7 @@ GPU-generic schedule rules. For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead """ -from .fallback import Fallback from .decode_gemv import DecodeGEMV -from .reduction import Reduction +from .fallback import Fallback from .matmul import Matmul +from .reduction import Reduction diff --git a/python/tvm/dlight/gpu/decode_gemv.py b/python/tvm/dlight/gpu/decode_gemv.py index 18395b8063f2..b9e8b44ef2b5 100644 --- a/python/tvm/dlight/gpu/decode_gemv.py +++ b/python/tvm/dlight/gpu/decode_gemv.py @@ -14,19 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=missing-docstring -"""A fallback schedule rule for GPU operators.""" -# pylint: disable=invalid-name +"""A rule for DecodeGEMV.""" +from typing import List, Optional, Set, Tuple, Union -from typing import List, Optional, Union - -from tvm import tir -from tvm._ffi import get_global_func -from tvm.arith import normalize_to_iter_sum +from tvm import arith, tir from tvm.ir import structural_equal from tvm.target import Target -from ..base import ScheduleRule, normalize_prim_func, try_inline_contiguous_spatial +from ..base import ( + BlockInfo, + ScheduleRule, + normalize_prim_func, + try_inline_contiguous_spatial, +) +from . import utils def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: @@ -47,13 +48,13 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr: dominant_read, read_iters = None, None - tir_vars = set() + tir_vars: Set[tir.Var] = set() for buffer_region in block.reads: tir_vars.clear() - def _collect_tir_var(e): - if isinstance(e, tir.Var): - tir_vars.add(e) + def _collect_tir_var(expr): + if isinstance(expr, tir.Var): + tir_vars.add(expr) for expr in buffer_region.region: assert expr.extent == 1 @@ -68,11 +69,9 @@ def _collect_tir_var(e): class DecodeGEMV(ScheduleRule): - def __init__(self) -> None: - super().__init__() - self.get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType") + """A rule for DecodeGEMV.""" - def apply( # pylint: disable=too-many-locals + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements self, func: tir.PrimFunc, target: Target, @@ -80,15 +79,8 @@ def apply( # pylint: disable=too-many-locals ) -> Union[None, tir.Schedule, List[tir.Schedule]]: if not isinstance(func, tir.PrimFunc): return None - - if target.kind.name == "cuda": - len_tx, len_ty = 16, 16 - else: - len_tx, len_ty = 8, 8 - sch = tir.Schedule(func) block_infos = try_inline_contiguous_spatial(sch, 
normalize_prim_func(sch)) - if block_infos is None or len(block_infos) > 2: return None @@ -97,96 +89,145 @@ def apply( # pylint: disable=too-many-locals block_stmt = sch.get(block) # Step 1. Check reduction block - if not block_info.is_reduction(): + if ( + (not block_info.is_reduction()) + or len(block_stmt.writes) != 1 + or _get_reduction_expr(block_stmt) is None + ): return None - if len(block_stmt.writes) != 1: - return None - if _get_reduction_expr(block_stmt) is None: - return None - - # Step 2. Sort out the spatial and reduction loops - sorted_iter_access = normalize_to_iter_sum( - _detect_dominant_read(block_stmt), - input_iters={i.var: i.dom for i in block_stmt.iter_vars}, + # Step 2. Normalize the block, merge spatial and reduction iters + is_inner_reduction, c_factor = self._normalize( + sch, + block_info, + arith.normalize_to_iter_sum( + _detect_dominant_read(block_stmt), + input_iters={i.var: i.dom for i in block_stmt.iter_vars}, + ), ) - if sorted_iter_access.base != 0: - return None - iter_to_info = {i.var: i for i in block_info.iters} - s_loops, r_loops, c_loops = [], [], [] - for split in sorted_iter_access.args: - block_var = split.source.source - block_var_info = iter_to_info[block_var] - loop_rv = block_var_info.loop_rv - is_inner_reduction = block_var_info.kind == "R" - if split.lower_factor > 1: - c_loop_factor = split.lower_factor - loop_rv, c_loop = sch.split(loop_rv, factors=[None, c_loop_factor]) - c_loops.append(c_loop) - is_loop_c_reduction = is_inner_reduction - if is_inner_reduction: - r_loops.append(loop_rv) - else: - s_loops.append(loop_rv) - - if len(c_loops) > 1: - return None - if len(s_loops) != len([_ for i in block_info.iters if i.kind == "S"]): + if is_inner_reduction is None and c_factor is None: return None - if len(s_loops) == 0 or len(r_loops) == 0: - return None - - sch.reorder(*s_loops, *r_loops, *c_loops) - s = sch.fuse(*s_loops) - r = sch.fuse(*r_loops) - - if is_inner_reduction: - _, tx = sch.split(r, factors=[None, 
len_tx * len_ty]) - rf = sch.rfactor(tx, 0) - s, r, tx = sch.get_loops(rf)[:3] - sch.reorder(s, tx, r) - sch.reverse_compute_at(block, s, preserve_unit_loops=True) - sch.bind(tx, "threadIdx.x") - sch.bind(s, "blockIdx.x") - else: - sch.split(s, factors=[None, len_tx]) - _, ty = sch.split(r, factors=[None, len_ty]) - rf = sch.rfactor(ty, 0) - bx, tx, r, ty = sch.get_loops(rf)[:4] - sch.reorder(bx, tx, ty, r) - sch.reverse_compute_at(block, bx, preserve_unit_loops=True) - sch.bind(tx, "threadIdx.x") - sch.bind(ty, "threadIdx.y") - sch.bind(bx, "blockIdx.x") - - s_loops, r_loops = [], [] - for loop_rv in sch.get_loops(block)[1:]: - iter_type = self.get_loop_iter_type(sch, loop_rv) - if iter_type == "S": - s_loops.append(loop_rv) - elif iter_type == "R": - r_loops.append(loop_rv) - else: - raise RuntimeError("Unknown loop type " + str(iter_type)) - sch.reorder(*s_loops, *r_loops) - s_ctr = sch.fuse(*s_loops) - r_ctr = sch.fuse(*r_loops) - - if c_loops and not is_loop_c_reduction: - s_ctr, inner = sch.split(s_ctr, factors=[None, c_loop_factor]) - sch.reorder(s_ctr, r_ctr, inner) - + # Step 3. Do the scheduling if is_inner_reduction: - sch.bind(r_ctr, "threadIdx.x") - sch.set_scope(rf, 0, "local") - sch.decompose_reduction(rf, sch.get_loops(rf)[2]) + self._sch_inner_reduction(sch, target, block, c_factor) else: - sch.bind(s_ctr, "threadIdx.x") - sch.bind(r_ctr, "threadIdx.y") - sch.set_scope(rf, 0, "local") - sch.decompose_reduction(rf, sch.get_loops(rf)[3]) - + self._sch_inner_spatial(sch, target, block, c_factor) + # Step 4. 
Schedule epilogue if len(block_infos) == 2: sch.set_scope(block, 0, "local") sch.reverse_compute_at(block_infos[1].block_rv, sch.get_loops(block)[0]) - return sch + + def _normalize( + self, + sch: tir.Schedule, + block_info: BlockInfo, + iter_sum: arith.IterSumExpr, + ) -> Tuple[Optional[bool], Optional[int]]: + if iter_sum.base != 0: + return None, None + iter_to_info = {i.var: i for i in block_info.iters} + s_dom, r_dom, c_dom, c_factor = None, None, None, None + for split in iter_sum.args: + var = split.source.source + info = iter_to_info[var] + dom = info.dom + is_inner_reduction = info.kind == "R" + if split.lower_factor > 1: + if c_dom is not None: + return None, None + c_dom = tir.floormod(var, split.lower_factor) + var = tir.floordiv(var, split.lower_factor) + dom = tir.floordiv(dom, split.lower_factor) + if not is_inner_reduction: + c_factor = split.lower_factor + if is_inner_reduction: + if r_dom is None: + r_dom = var + else: + r_dom = r_dom * dom + var + else: + if s_dom is None: + s_dom = var + else: + s_dom = s_dom * dom + var + + assert r_dom is not None + if s_dom is None: + s_dom = tir.const(1, r_dom.dtype) + if c_dom is None: + c_dom = tir.const(1, r_dom.dtype) + sch.transform_block_layout( + block_info.block_rv, + tir.IndexMap( + [i.var for i in block_info.iters], + [s_dom, r_dom, c_dom], + None, + ), + ) + return is_inner_reduction, c_factor + + def _sch_inner_reduction( + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + unroll_spatial_factor: Optional[int], + ): + # pylint: disable=invalid-name + _, r, _ = sch.get_loops(block) + (len_tx,) = utils.suggest_threads_per_block( # pylint: disable=unbalanced-tuple-unpacking + target, [sch.get(r)] + ) + + _, tx = sch.split(r, factors=[None, len_tx]) + # Schedule the RF block + rf = sch.rfactor(tx, 0) + bx, r, tx, _ = sch.get_loops(rf) + sch.reorder(bx, tx, r) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.set_scope(rf, 0, "local") + 
sch.decompose_reduction(rf, r) + # Schedule the write back block + sch.reverse_compute_at(block, bx, preserve_unit_loops=True) + _, tx, *s = sch.get_loops(block) + s = sch.fuse(*s) + sch.reorder(s, tx) + if unroll_spatial_factor: + s, inner = sch.split(s, factors=[None, unroll_spatial_factor]) + sch.reorder(s, tx, inner) + sch.bind(tx, "threadIdx.x") + # pylint: enable=invalid-name + + def _sch_inner_spatial( + self, + sch: tir.Schedule, + _: Target, + block: tir.schedule.BlockRV, + unroll_spatial_factor: Optional[int], + ): + # pylint: disable=invalid-name + s, r, _ = sch.get_loops(block) + len_tx, len_ty = 16, 16 + _, _ = sch.split(s, factors=[None, len_tx]) + _, ty = sch.split(r, factors=[None, len_ty]) + # Schedule the RF block + rf = sch.rfactor(ty, 0) + bx, tx, r, ty, _ = sch.get_loops(rf) + sch.reorder(bx, tx, ty, r) + sch.bind(tx, "threadIdx.x") + sch.bind(ty, "threadIdx.y") + sch.bind(bx, "blockIdx.x") + sch.set_scope(rf, 0, "local") + sch.decompose_reduction(rf, r) + # Schedule the write back block + sch.reverse_compute_at(block, bx, preserve_unit_loops=True) + _, r, *s = sch.get_loops(block) + s = sch.fuse(*s) + sch.reorder(s, r) + if unroll_spatial_factor: + s, inner = sch.split(s, factors=[None, unroll_spatial_factor]) + sch.reorder(s, r, inner) + sch.bind(s, "threadIdx.x") + sch.bind(r, "threadIdx.y") + # pylint: enable=invalid-name diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py index 6b120b16488b..14b74887afb8 100644 --- a/python/tvm/dlight/gpu/fallback.py +++ b/python/tvm/dlight/gpu/fallback.py @@ -21,7 +21,8 @@ from tvm import tir from tvm.target import Target -from ..base import ScheduleRule, analysis, normalize_prim_func, try_inline +from ..base import ScheduleRule, normalize_prim_func, try_inline +from . 
import utils class Fallback(ScheduleRule): @@ -36,7 +37,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring target: Target, _: bool, ) -> tir.Schedule: - max_threads_per_block = analysis.get_max_threads_per_block(target) + max_threads_per_block = utils.max_threads_per_block(target) sch = tir.Schedule(func) block_infos = try_inline(sch, normalize_prim_func(sch)) diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index e66eaa32226b..86d685e53c84 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -16,14 +16,14 @@ # under the License. # pylint: disable=missing-docstring, invalid-name """A GEMM schedule rule for GPU operators.""" -from enum import Enum from dataclasses import dataclass +from enum import Enum from typing import Dict, List, Optional, Set, Tuple from tvm import tir from tvm.ir import Range from tvm.target import Target -from tvm.tir import PrimExpr, Var, IterVar +from tvm.tir import IterVar, PrimExpr, Var from tvm.tir.analysis import undefined_vars from tvm.tir.schedule.schedule import BlockRV diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/gpu/utils.py new file mode 100644 index 000000000000..4fcc76294276 --- /dev/null +++ b/python/tvm/dlight/gpu/utils.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""Utility methods for generic GPU.""" +from typing import List, Optional + +from tvm import tir +from tvm.target import Target + + +def max_threads_per_block(target: Target) -> int: + """Get the maximum number of threads per block for a given target. + + Parameters + ---------- + target : Target + The target to get the maximum number of threads per block for. + + Returns + ------- + max_threads_per_block : int + The maximum number of threads per block for the given target. + """ + for name in ["max_threads_per_block", "max_num_threads"]: + result = target.attrs.get(name, None) + if result is not None: + return result + if target.kind.name == "cuda": + return 1024 + return 256 + + +def suggest_threads_per_block( + target: Target, + loops: List[tir.For], + max_threads_for_dynamic_loop: int = 32, +) -> List[int]: + if target.kind.name == "cuda": + threads = 256 + else: + threads = 64 + results: List[Optional[int]] = [] + dynamic: List[int] = [] + for i, loop in enumerate(loops): + loop_extent = loop.extent + if isinstance(loop_extent, tir.IntImm): + loop_extent = loop_extent.value + extent = 1 + while extent <= loop_extent and extent <= threads: + extent *= 2 + extent //= 2 + assert extent >= 1 + assert threads % extent == 0 + threads //= extent + results.append(extent) + else: + results.append(None) + dynamic.append(i) + + for i in dynamic: + extent = 1 + while extent <= max_threads_for_dynamic_loop and extent <= threads: + extent *= 2 + extent //= 2 + assert extent >= 1 + assert threads % extent == 0 + threads //= extent + results[i] = extent + + if dynamic: + results[dynamic[0]] *= threads + + return results diff --git a/tests/python/dlight/test_gpu_decode_gemv.py b/tests/python/dlight/test_gpu_decode_gemv.py index 46232a461eb5..303b16809e7b 100644 --- a/tests/python/dlight/test_gpu_decode_gemv.py +++ 
b/tests/python/dlight/test_gpu_decode_gemv.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-docstring -import tvm +# pylint: disable=missing-docstring,line-too-long,invalid-name,too-few-public-methods,too-many-locals from tvm import dlight as dl from tvm.ir import assert_structural_equal from tvm.script import ir as I @@ -31,7 +30,6 @@ class Before: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) - # with T.block("root"): B = T.alloc_buffer((4096, 4096), "float16") for i, j in T.grid(4096, 4096): with T.block("decode"): @@ -66,8 +64,8 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") with T.block("matmul_rf_update"): vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1) v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR", [i2_i0_i1_fused, k_0_fused_0, k_1]) - C_rf_local[vk_0_fused_1, 0, 0, v_i2] = C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) // 32]) - for ax1_ax2_ax3_fused in range(1): + C_rf_local[vk_0_fused_1, 0, 0, v_i2] = C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 256 + vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i2, vk_0_fused_0 * 256 + vk_0_fused_1], T.Cast("uint32", ((vk_0_fused_0 * 256 + vk_0_fused_1) * 8 + vk_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 256 + vk_0_fused_1) // 4]) + for 
ax1_ax2_ax3_fused in range(1): # pylint: disable=unused-variable for ax0_fused in T.thread_binding(256, thread="threadIdx.x"): with T.block("matmul"): vk_0_fused_1 = T.axis.reduce(256, ax0_fused) @@ -128,7 +126,7 @@ def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16") vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1) v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16 + i2_i0_i1_fused_1) vk_0_fused_0, vk_1 = T.axis.remap("RR", [k_0_fused_0, k_1]) - C_rf_local[vk_0_fused_1, 0, 0, v_i2] = C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) // 8, v_i2], T.Cast("uint32", (vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) // 32, v_i2]) + C_rf_local[vk_0_fused_1, 0, 0, v_i2] = C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 16 + vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_0_fused_0 * 16 + vk_0_fused_1, v_i2], T.Cast("uint32", ((vk_0_fused_0 * 16 + vk_0_fused_1) * 8 + vk_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 16 + vk_0_fused_1) // 4, v_i2]) for ax1_ax2_ax3_fused in T.thread_binding(16, thread="threadIdx.x"): for ax0_fused in T.thread_binding(16, thread="threadIdx.y"): with T.block("matmul"): @@ -184,23 +182,26 @@ def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16") for i2_1_init in range(8): with T.block("matmul_rf_init"): vk_fused_1 = T.axis.spatial(256, k_fused_1) - v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 + i2_1_init) - C_rf_local[vk_fused_1, 0, 0, v_i2] = T.float16(0) + v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused) + v_i2 = T.axis.spatial(8, i2_1_init) + C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] = T.float16(0) for k_fused_0, i2_1 in T.grid(16, 8): with T.block("matmul_rf_update"): 
vk_fused_1 = T.axis.spatial(256, k_fused_1) - v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 + i2_1) + v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused) vk_fused_0 = T.axis.reduce(16, k_fused_0) - C_rf_local[vk_fused_1, 0, 0, v_i2] = C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 256 + vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i2 // 8, vk_fused_0 * 256 + vk_fused_1], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2 // 32, vk_fused_0 * 256 + vk_fused_1]) + v_i2 = T.axis.spatial(8, i2_1) + C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] = C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] + V[0, 0, vk_fused_0 * 256 + vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i1, vk_fused_0 * 256 + vk_fused_1], T.Cast("uint32", (v_i1 * 8 + v_i2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i1 // 4, vk_fused_0 * 256 + vk_fused_1]) for ax1_ax2_ax3_fused_0 in range(1): for ax0_fused in T.thread_binding(256, thread="threadIdx.x"): for ax1_ax2_ax3_fused_1 in range(8): with T.block("matmul"): vk_fused_1 = T.axis.reduce(256, ax0_fused) - v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1) + v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused) + v_i2 = T.axis.spatial(8, ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1) with T.init(): - C[0, 0, v_i2] = T.float16(0) - C[0, 0, v_i2] = C[0, 0, v_i2] + C_rf_local[vk_fused_1, 0, 0, v_i2] + C[0, 0, v_i1 * 8 + v_i2] = T.float16(0) + C[0, 0, v_i1 * 8 + v_i2] = C[0, 0, v_i1 * 8 + v_i2] + C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] # fmt: on @@ -241,7 +242,6 @@ class After: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", 
scope="local") for i2_0_i0_i1_fused_0 in T.thread_binding(32, thread="blockIdx.x"): for i2_0_i0_i1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): @@ -249,23 +249,26 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") for i2_1_init in range(8): with T.block("matmul_rf_init"): vk_fused_1 = T.axis.spatial(16, k_fused_1) - v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 * 128 + i2_0_i0_i1_fused_1 * 8 + i2_1_init) - C_rf_local[vk_fused_1, 0, 0, v_i2] = T.float16(0) + v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 * 16 + i2_0_i0_i1_fused_1) + v2 = T.axis.spatial(8, i2_1_init) + C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] = T.float16(0) for k_fused_0, i2_1 in T.grid(256, 8): with T.block("matmul_rf_update"): vk_fused_1 = T.axis.spatial(16, k_fused_1) - v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 * 128 + i2_0_i0_i1_fused_1 * 8 + i2_1) + v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 * 16 + i2_0_i0_i1_fused_1) vk_fused_0 = T.axis.reduce(256, k_fused_0) - C_rf_local[vk_fused_1, 0, 0, v_i2] = C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1, v_i2 // 8], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32]) + v2 = T.axis.spatial(8, i2_1) + C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] = C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1, v1], T.Cast("uint32", (v1 * 8 + v2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v1 // 4]) for ax1_ax2_ax3_fused_0 in T.thread_binding(16, thread="threadIdx.x"): for ax0_fused in T.thread_binding(16, thread="threadIdx.y"): for ax1_ax2_ax3_fused_1 in range(8): with T.block("matmul"): vk_fused_1 = T.axis.reduce(16, ax0_fused) - v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 * 128 + 
ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1) + v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 * 16 + (ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1) // 8) + v2 = T.axis.spatial(8, (ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1) % 8) with T.init(): - C[0, 0, v_i2] = T.float16(0) - C[0, 0, v_i2] = C[0, 0, v_i2] + C_rf_local[vk_fused_1, 0, 0, v_i2] + C[0, 0, v1 * 8 + v2] = T.float16(0) + C[0, 0, v1 * 8 + v2] = C[0, 0, v1 * 8 + v2] + C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] # fmt: on @@ -325,8 +328,8 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") with T.block("matmul_rf_update"): vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1) v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR", [i2_i0_i1_fused, k_0_fused_0, k_1]) - C_rf_local[vk_0_fused_1, 0, 0, v_i2] = C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) // 32]) - for ax1_ax2_ax3_fused in range(1): + C_rf_local[vk_0_fused_1, 0, 0, v_i2] = C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 256 + vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i2, vk_0_fused_0 * 256 + vk_0_fused_1], T.Cast("uint32", ((vk_0_fused_0 * 256 + vk_0_fused_1) * 8 + vk_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 256 + vk_0_fused_1) // 4]) + for ax1_ax2_ax3_fused in range(1): # pylint: disable=unused-variable for ax0_fused in T.thread_binding(256, thread="threadIdx.x"): with T.block("matmul"): vk_0_fused_1 = T.axis.reduce(256, ax0_fused) @@ -397,9 +400,7 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") for ax1_0_fused_0, ax1_1 in T.grid(2, 8): with T.block("matmul_rf_update"): 
vax1_0_fused_1, v0, vax1_0_fused_0, vax1_1 = T.axis.remap("SSRR", [ax1_0_fused_1, ax0_fused, ax1_0_fused_0, ax1_1]) - T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0], V[0, 0, vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1], W[v0, (vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 8], S[v0, (vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 32]) - T.writes(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0]) - C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0, vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1]) * T.Cast("float32", (T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0, (vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 32]) + C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0, (vax1_0_fused_0 * 256 + vax1_0_fused_1) * 8 + vax1_1]) * T.Cast("float32", (T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, vax1_0_fused_0 * 256 + vax1_0_fused_1], T.Cast("uint32", ((vax1_0_fused_0 * 256 + vax1_0_fused_1) * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1) // 4]) for ax1_fused in range(1): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("matmul"):