Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/tirx/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ TVM_DLL Pass UnrollLoop();
TVM_DLL Pass RemoveNoOp();

/*!
* \brief Run arithmetic simplifications on the statements and expressions.
* \brief Run statement-level arithmetic simplifications on the TIR PrimFunc.
*
* \return The pass.
*/
TVM_DLL Pass Simplify();
TVM_DLL Pass StmtSimplify();

/*!
* \brief Convert an IRModule to be SSA form.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/s_tir/backend/adreno/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
s_tir.transform.LowerAutoCopy(),
s_tir.transform.UnifyThreadBinding(),
s_tir.transform.LowerMatchBuffer(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
s_tir.transform.InjectPermutedLayout(),
s_tir.transform.AnnotateIrregularLoop(),
s_tir.transform.InjectSoftwarePipeline(),
Expand All @@ -68,7 +68,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
s_tir.transform.HoistIfThenElse(),
tirx.transform.UnrollLoop(),
s_tir.transform.RenormalizeSplitPattern(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
tirx.transform.RemoveNoOp(),
s_tir.transform.RewriteUnsafeSelect(),
]
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/s_tir/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
s_tir.transform.LowerAutoCopy(),
s_tir.transform.UnifyThreadBinding(),
s_tir.transform.LowerMatchBuffer(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
s_tir.transform.InjectPermutedLayout(),
s_tir.transform.AnnotateIrregularLoop(),
s_tir.transform.InjectSoftwarePipeline(),
Expand All @@ -68,7 +68,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
s_tir.transform.HoistIfThenElse(),
tirx.transform.UnrollLoop(),
s_tir.transform.RenormalizeSplitPattern(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
tirx.transform.RemoveNoOp(),
s_tir.transform.RewriteUnsafeSelect(),
]
Expand Down Expand Up @@ -137,10 +137,10 @@ def finalize_host_passes(): # pylint: disable=unused-argument
def finalize_device_passes(): # pylint: disable=unused-argument
"""The default finalization passes for TIR backend."""
device_pass_list = [
tir.transform.LowerWarpMemory(),
tir.transform.Simplify(),
tir.transform.LowerCustomDatatypes(),
tir.transform.LowerIntrin(),
tirx.transform.LowerWarpMemory(),
tirx.transform.StmtSimplify(),
tirx.transform.LowerCustomDatatypes(),
tirx.transform.LowerIntrin(),
]
return tvm.ir.transform.Sequential(device_pass_list)

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2017,7 +2017,7 @@ class object that inherits from `Exception`.
.. code-block:: python

class TestRemoveIf(tvm.testing.CompareBeforeAfter):
transform = tvm.tirx.transform.Simplify()
transform = tvm.tirx.transform.StmtSimplify()

def before(A: T.Buffer(1, "int32")):
if True:
Expand Down
16 changes: 8 additions & 8 deletions python/tvm/tirx/compilation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
passes = [
tirx.transform.LowerInitBlock(),
tvm.s_tir.transform.UnifyThreadBinding(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
tirx.transform.FlattenBuffer(),
tirx.transform.BF16ComputeLegalize(),
tirx.transform.NarrowDataType(32),
tirx.transform.VectorizeLoop(not bool(config.get("tir.disable_vectorize", False))),
tirx.transform.UnrollLoop(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
]
if not bool(config.get("tir.disable_cse_tir", False)):
passes.append(tirx.transform.CommonSubexprElim())
Expand Down Expand Up @@ -73,14 +73,14 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
passes = [
tirx.transform.LowerTIRx(),
tvm.s_tir.transform.UnifyThreadBinding(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
tirx.transform.LowerTIRxOpaque(),
tirx.transform.FlattenBuffer(),
tirx.transform.BF16ComputeLegalize(),
tirx.transform.NarrowDataType(32),
tirx.transform.VectorizeLoop(not bool(config.get("tir.disable_vectorize", False))),
tirx.transform.UnrollLoop(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
]
if not bool(config.get("tir.disable_cse_tir", False)):
passes.append(tirx.transform.CommonSubexprElim())
Expand Down Expand Up @@ -115,11 +115,11 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
tirx.transform.trn.TrnNaiveAllocator(),
tirx.transform.LowerTIRx(),
tvm.s_tir.transform.DecorateDeviceScope(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
tirx.transform.LowerTIRxOpaque(),
tvm.s_tir.transform.LoopPartition(),
tvm.s_tir.transform.HoistIfThenElse(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
tirx.transform.RemoveNoOp(),
tirx.transform.AnnotateEntryFunc(),
tirx.transform.AnnotateDeviceRegions(),
Expand All @@ -146,7 +146,7 @@ def finalize_device_passes(): # pylint: disable=unused-argument
"""The default finalization passes for TIR backend."""
device_pass_list = [
tirx.transform.LowerWarpMemory(),
tirx.transform.Simplify(),
tirx.transform.StmtSimplify(),
tirx.transform.LowerCustomDatatypes(),
tirx.transform.LowerIntrin(),
]
Expand All @@ -161,7 +161,7 @@ def finalize_device_passes_tirx(): # pylint: disable=unused-argument

def finalize_device_passes_trn(): # pylint: disable=unused-argument
"""The default finalization passes for TRN backend."""
device_pass_list = [tirx.transform.Simplify()]
device_pass_list = [tirx.transform.StmtSimplify()]
return tvm.ir.transform.Sequential(device_pass_list)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def impl():
import tvm

mod = tvm.IRModule({"main": impl})
mod = tvm.tirx.transform.Simplify()(mod)
mod = tvm.tirx.transform.StmtSimplify()(mod)
return mod["main"]
else:
# fmt: off
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/tirx/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,20 +210,20 @@ def CommonSubexprElim():
return _ffi_api.CommonSubexprElim() # type: ignore


@_ffi.register_object("tirx.transform.SimplifyConfig")
class SimplifyConfig(_ffi.Object):
"""Config for simplify pass"""
@_ffi.register_object("tirx.transform.StmtSimplifyConfig")
class StmtSimplifyConfig(_ffi.Object):
"""Config for stmt simplify pass"""


def Simplify():
"""Run arithmetic simplifications on the statements and expressions.
def StmtSimplify():
"""Run statement-level arithmetic simplifications on the TIR PrimFunc.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.Simplify() # type: ignore
return _ffi_api.StmtSimplify() # type: ignore


def ConvertSSA():
Expand Down
224 changes: 0 additions & 224 deletions src/arith/narrow_predicate_expression.cc

This file was deleted.

Loading
Loading