diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index d97bd647cda5..3fc310c47e34 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -22,6 +22,7 @@ set -o pipefail # install libraries for python package on ubuntu pip3 install --upgrade \ + "Pygments>=2.4.0" \ attrs \ cloudpickle \ cython \ diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index b94d50dbf20d..06537e2cdc4d 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """IRModule that holds the functions and type definitions.""" +from typing import Optional + from tvm._ffi.base import string_types import tvm._ffi @@ -276,6 +278,19 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: self, tir_prefix, show_meta ) # type: ignore + def show(self, style: Optional[str] = None) -> None: + """ + A sugar for print highlighted TVM script. + Parameters + ---------- + style : str, optional + Pygments styles extended by "light" (default) and "dark", by default "light" + """ + from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel + + # Use deferred import to avoid circular import while keeping cprint under tvm/script + cprint(self, style=style) + def get_attr(self, attr_key): """Get the IRModule attribute. diff --git a/python/tvm/script/highlight.py b/python/tvm/script/highlight.py new file mode 100644 index 000000000000..03476ba60cd2 --- /dev/null +++ b/python/tvm/script/highlight.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Highlight printed TVM script. +""" + +from typing import Union, Optional +import warnings +import sys + +from tvm.ir import IRModule +from tvm.tir import PrimFunc + + +def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) -> None: + """ + Print highlighted TVM script string with Pygments + Parameters + ---------- + printable : Union[IRModule, PrimFunc] + The TVM script to be printed + style : str, optional + Printing style, auto-detected if None. + Notes + ----- + The style parameter follows the Pygments style names or Style objects. Three + built-in styles are extended: "light", "dark" and "ansi". By default, "light" + will be used for notebook environment and terminal style will be "ansi" for + better style consistency. As an fallback when the optional Pygment library is + not installed, plain text will be printed with a one-time warning to suggest + installing the Pygment library. Other Pygment styles can be found in + https://pygments.org/styles/ + """ + + try: + # pylint: disable=import-outside-toplevel + import pygments + from pygments import highlight + from pygments.lexers.python import Python3Lexer + from pygments.formatters import Terminal256Formatter + from pygments.style import Style + from pygments.token import Keyword, Name, Comment, String, Number, Operator + from packaging import version + + if version.parse(pygments.__version__) < version.parse("2.4.0"): + raise ImportError("Required Pygments version >= 2.4.0 but got " + pygments.__version__) + except ImportError as err: + with warnings.catch_warnings(): + warnings.simplefilter("once", UserWarning) + install_cmd = sys.executable + ' -m pip install "Pygments>=2.4.0" --upgrade --user' + warnings.warn( + str(err) + + "\n" + + "To print highlighted TVM script, please install Pygments:\n" + + install_cmd, + category=UserWarning, + ) + print(printable.script()) + else: + + class JupyterLight(Style): + """A Jupyter-Notebook-like Pygments style configuration (aka. "dark")""" + + styles = { + Keyword: "bold #008000", + Keyword.Type: "nobold #008000", + Name.Function: "#0000FF", + Name.Class: "bold #0000FF", + Name.Decorator: "#AA22FF", + String: "#BA2121", + Number: "#008000", + Operator: "bold #AA22FF", + Operator.Word: "bold #008000", + Comment: "italic #007979", + } + + class VSCDark(Style): + """A VSCode-Dark-like Pygments style configuration (aka. "dark")""" + + styles = { + Keyword: "bold #c586c0", + Keyword.Type: "#82aaff", + Keyword.Namespace: "#4ec9b0", + Name.Class: "bold #569cd6", + Name.Function: "bold #dcdcaa", + Name.Decorator: "italic #fe4ef3", + String: "#ce9178", + Number: "#b5cea8", + Operator: "#bbbbbb", + Operator.Word: "#569cd6", + Comment: "italic #6a9956", + } + + class AnsiTerminalDefault(Style): + """The default style for terminal display with ANSI colors (aka. "ansi")""" + + styles = { + Keyword: "bold ansigreen", + Keyword.Type: "nobold ansigreen", + Name.Class: "bold ansiblue", + Name.Function: "bold ansiblue", + Name.Decorator: "italic ansibrightmagenta", + String: "ansiyellow", + Number: "ansibrightgreen", + Operator: "bold ansimagenta", + Operator.Word: "bold ansigreen", + Comment: "italic ansibrightblack", + } + + if style is None: + # choose style automatically according to the environment: + if "ipykernel" in sys.modules: # in notebook env. + style = JupyterLight + else: # in a terminal or something. + style = AnsiTerminalDefault + elif style == "light": + style = JupyterLight + elif style == "dark": + style = VSCDark + elif style == "ansi": + style = AnsiTerminalDefault + + print(highlight(printable.script(), Python3Lexer(), Terminal256Formatter(style=style))) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index a921c5b9fc40..f06376147b9a 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -195,6 +195,19 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: self, tir_prefix, show_meta ) # type: ignore + def show(self, style: Optional[str] = None) -> None: + """ + A sugar for print highlighted TVM script. + Parameters + ---------- + style : str, optional + Pygments styles extended by "light" (default) and "dark", by default "light" + """ + from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel + + # Use deferred import to avoid circular import while keeping cprint under tvm/script + cprint(self, style=style) + @tvm._ffi.register_object("tir.TensorIntrin") class TensorIntrin(Object): diff --git a/tests/python/unittest/test_tvmscript_printer_highlight.py b/tests/python/unittest/test_tvmscript_printer_highlight.py new file mode 100644 index 000000000000..cc3469a2ceea --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_highlight.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm.script import tir as T + + +def test_highlight_script(): + @tvm.script.ir_module + class Module: + @T.prim_func + def main( # type: ignore + a: T.handle, + b: T.handle, + c: T.handle, + ) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [16, 128, 128]) + B = T.match_buffer(b, [16, 128, 128]) + C = T.match_buffer(c, [16, 128, 128]) + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("matmul"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + with T.init(): + C[vn, vi, vj] = 0.0 # type: ignore + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + Module.show() + Module["main"].show() + Module["main"].show(style="light") + Module["main"].show(style="dark") + Module["main"].show(style="ansi")