diff --git a/python/tvm/script/_parser/core/__init__.py b/python/tvm/script/_parser/core/__init__.py index ae1521006d9b..94d8dab0322d 100644 --- a/python/tvm/script/_parser/core/__init__.py +++ b/python/tvm/script/_parser/core/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """The core parser infra""" -from . import diagnostics, doc, doc_core, utils +from . import diagnostics, dispatch, doc, doc_core, entry, evaluator, parser, utils diff --git a/python/tvm/script/_parser/core/dispatch.py b/python/tvm/script/_parser/core/dispatch.py new file mode 100644 index 000000000000..f803be05de92 --- /dev/null +++ b/python/tvm/script/_parser/core/dispatch.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Parser dispatching infrastructure""" + +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type + +from .doc import AST + +if TYPE_CHECKING: + from .parser import Parser + + +ParseMethod = Callable[["Parser", AST], None] +ParseVTable: Dict[Tuple[str, str], ParseMethod] = {} + +OpMethod = Callable[..., Any] +OpVTable: Dict[Tuple[Type, AST, int], OpMethod] = {} + + +def register(token: str, type_name: str): + """Register a method for a dispatch token and type name. + + Parameters + ---------- + token : str + The token for IR, e.g., T for TIR and R for Relax. + + type_name : str + The type name of AST node, e.g., FunctionDef, With, For. + + Returns + ------- + func : callable + The function to register dispatched method of parsing + corresponding token and AST node type. + """ + + def func(method: ParseMethod): + """Register a method in parser virtual table. + + Parameters + ---------- + method : ParseMethod + The dispatched method to be registered in parser virtual table. + """ + ParseVTable[(token, type_name)] = method + + return func + + +def get( + token: str, + type_name: str, + default: Optional[ParseMethod] = None, +) -> Optional[ParseMethod]: + """Get a registered method for a dispatch token and type name, + or return a default method if no registered methods with this dispatch token and type name. + + Parameters + ---------- + token : str + The token for IR, e.g., T for TIR and R for Relax. + + type_name : str + The type name of AST node, e.g., FunctionDef, With, For. + + default : Optional[ParseMethod] + The default method when no registered methods with this dispatch token and type name. + + Returns + ------- + func : Optional[ParseMethod] + The dispatched method of parsing corresponding token and AST node type. + """ + return ParseVTable.get((token, type_name), default) + + +def register_op(operand_type: Type, op_node_type: AST, operand_index: int): + """Register a method for a operand type, AST operator node and operand index. + + Parameters + ---------- + operand_type : Type + The type of operands, e.g., tir.PrimExpr, tir.IterVar. + + op_node_type : AST + The doc AST operator node type, e.g., doc.Add, doc.Eq. + + operand_index : int + The operand index, i.e., 0 for left operand and 1 for right operand. + + Returns + ------- + func : callable + The function to register dispatched method of parsing + corresponding a operand type, AST operator node and operand index. + """ + + def func(method: OpMethod): + """Register a method in parser operator virtual table. + + Parameters + ---------- + method : ParseMethod + The dispatched method to be registered in parser operator virtual table. + """ + OpVTable[(operand_type, op_node_type, operand_index)] = method + + return func + + +def get_op( + operand_type: Type, + op_node_type: Type, + operand_index: int, + default: Optional[OpMethod] = None, +) -> Optional[OpMethod]: + """Register a method for a operand type, AST operator node and operand index. + + Parameters + ---------- + operand_type : Type + The type of operands, e.g., tir.PrimExpr, tir.IterVar. + + op_node_type : AST + The doc AST operator node type, e.g., doc.Add, doc.Eq. + + operand_index : int + The operand index, i.e., 0 for left operand and 1 for right operand. + + + default : Optional[OpMethod] + The default method when no registered methods with this operand type, + AST operator node and operand index. + + Returns + ------- + func : Optional[OpMethod] + The function to register dispatched method of parsing + corresponding a operand type, AST operator node and operand index. + """ + return OpVTable.get((operand_type, op_node_type, operand_index), default) diff --git a/python/tvm/script/_parser/core/entry.py b/python/tvm/script/_parser/core/entry.py new file mode 100644 index 000000000000..a0974c8fd419 --- /dev/null +++ b/python/tvm/script/_parser/core/entry.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The entry point of TVM parser.""" + +from typing import Any, Dict, Union + +from ...ir_builder import IRBuilder +from . import doc +from .diagnostics import Source +from .parser import Parser + + +def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any: + """Register a method for a operand type, AST operator node and operand index. + + Parameters + ---------- + program : Union[doc.AST, Any, str] + The TVMScript code to parse. + + extra_vars : Dict[str, Any] + The extra variable table for parsing. + + Returns + ------- + func : Any + The parsed TVMScript program. + """ + + source = Source(program) + parser = Parser(source) + with IRBuilder() as builder: + parser.parse(extra_vars=extra_vars) + return builder.get() diff --git a/python/tvm/script/_parser/core/evaluator.py b/python/tvm/script/_parser/core/evaluator.py new file mode 100644 index 000000000000..3a72a3c33106 --- /dev/null +++ b/python/tvm/script/_parser/core/evaluator.py @@ -0,0 +1,509 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""AST Evaluation""" + +import ast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union + +from . import dispatch, doc + +if TYPE_CHECKING: + from .parser import Parser + +DEFAULT_OP: Dict[Type, Callable[..., Any]] = { + doc.Add: lambda a, b: a + b, + doc.Sub: lambda a, b: a - b, + doc.Mult: lambda a, b: a * b, + doc.Div: lambda a, b: a / b, + doc.FloorDiv: lambda a, b: a // b, + doc.Mod: lambda a, b: a % b, + doc.LShift: lambda a, b: a << b, + doc.RShift: lambda a, b: a >> b, + doc.BitOr: lambda a, b: a | b, + doc.BitXor: lambda a, b: a ^ b, + doc.BitAnd: lambda a, b: a & b, + doc.MatMult: lambda a, b: a @ b, + doc.Pow: lambda a, b: a**b, + doc.Eq: lambda a, b: a == b, + doc.NotEq: lambda a, b: a != b, + doc.Lt: lambda a, b: a < b, + doc.LtE: lambda a, b: a <= b, + doc.Gt: lambda a, b: a > b, + doc.GtE: lambda a, b: a >= b, + doc.Is: lambda a, b: a is b, + doc.IsNot: lambda a, b: a is not b, + doc.In: lambda a, b: a in b, + doc.NotIn: lambda a, b: a not in b, + doc.And: lambda a, b: a and b, + doc.Or: lambda a, b: a or b, + doc.Invert: lambda a: ~a, + doc.Not: lambda a: not a, + doc.UAdd: lambda a: +a, + doc.USub: lambda a: -a, +} + + +class ExprEvaluator: + """Expression evaluator for TVMScript parser. + + Parameters + ---------- + parser : Parser + The parser bound with the evaluator. + + value_table : Dict[str, Any] + The value table for expression evaluation. + + new_value_count : int + The count for ntermediate result added during evaluation. + """ + + parser: "Parser" + value_table: Dict[str, Any] + new_value_count: int + + def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None: + super().__init__() + self.parser = parser + self.value_table = value_table + self.new_value_count = 0 + + @staticmethod + def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any: + """Expression evaluation for TVMScript parser. + + Parameters + ---------- + parser : Parser + The parser bound with the evaluator. + + value_table : Dict[str, Any] + The value table for expression evaluation. + + node : doc.AST + The root node of AST tree node of expression to evaluate. + + Returns + ------- + res : Any + The evaluation result. + """ + self = ExprEvaluator(parser, value_table) + result = self._visit(node) # pylint: disable=protected-access + if isinstance(result, doc.Name): + if result.id not in self.value_table: + self.parser.report_error(result, f"Undefined variable: {result.id}") + return self.value_table[result.id] + if isinstance(result, doc.Constant): + return result.value + raise TypeError(f"Unexpected result type: {type(result)}") + + def _add_intermediate_result(self, value: Any) -> doc.Name: + """Add intermediate result during evaluation into value table. + + Parameters + ---------- + value : Any + The intermediate result. + + Returns + ------- + name : doc.Name + The doc AST name node with intermediate name for intermediate result. + """ + name = f"__tvm_tmp_value_{self.new_value_count}" + self.new_value_count += 1 + self.value_table[name] = value + lineno = 0 + col_offset = 0 + return doc.Name( + id=name, + ctx=doc.Load( + lineno=lineno, + col_offset=col_offset, + end_lineno=None, + end_col_offset=None, + ), + lineno=lineno, + col_offset=col_offset, + end_lineno=None, + end_col_offset=None, + ) + + def _visit(self, node: doc.AST) -> Any: + """General doc AST node visiting method for expression evaluation. + + Parameters + ---------- + node : doc.AST + The root node of AST tree node of expression to evaluate. + + Returns + ------- + res : Any + The evaluation result. + """ + if isinstance(node, list): + return [self._visit(n) for n in node] + if isinstance(node, tuple): + return tuple(self._visit(n) for n in node) + assert isinstance(node, doc.AST) + if isinstance(node, doc.Name): + if node.id not in self.value_table: + self.parser.report_error(node, f"Undefined variable: {node.id}") + return node + if isinstance( + node, + ( + doc.Constant, + doc.expr_context, + doc.operator, + doc.boolop, + doc.unaryop, + doc.cmpop, + ), + ): + return node + if not isinstance(node, (doc.expr, doc.slice)): + return node + if isinstance(node, doc.Lambda): + return self._eval_lambda(node) + fields = {} + for field in node.__class__._FIELDS: # pylint: disable=protected-access + attr = getattr(node, field) + if isinstance(attr, (doc.AST, tuple, list)): + fields[field] = self._visit(attr) + else: + fields[field] = attr + try: + if isinstance(node, doc.BoolOp): + value = self._eval_bool_op(fields) + elif isinstance(node, doc.Compare): + value = self._eval_compare(fields) + elif isinstance(node, doc.UnaryOp): + value = self._eval_unary_op(fields) + elif isinstance(node, doc.BinOp): + value = self._eval_bin_op(fields) + elif isinstance(node, doc.Slice): + value = self._eval_slice(fields) + else: + value = self._eval_expr(node.__class__(**fields)) + except Exception as e: # pylint: disable=broad-except,invalid-name + self.parser.report_error(node, str(e)) + return self._add_intermediate_result(value) + + def _eval_lambda(self, node: doc.Lambda) -> Any: + """The doc AST lambda node evaluating method. + + Parameters + ---------- + node : doc.Lambda + The root node of AST tree node of expression to evaluate. + + Returns + ------- + res : Any + The evaluation result. + """ + try: + value = self._eval_expr(node) + except Exception as e: # pylint: disable=broad-except,invalid-name + self.parser.report_error(node, str(e)) + return self._add_intermediate_result(value) + + def _eval_bool_op(self, fields: Dict[str, Any]) -> Any: + """The doc AST boolean operator node evaluating method. + + Parameters + ---------- + fields : Dict[str, Any] + The dictionary of boolean operation information, + e.g., operator types, operand values. + + Returns + ------- + res : Any + The evaluation result. + """ + op = fields["op"] + if not isinstance(op, (doc.And, doc.Or)): + raise TypeError(f"Unexpected operator: {op}") + value = self._eval_expr(fields["values"][0]) + for rhs in fields["values"][1:]: + value = _eval_op(op, values=[value, self._eval_expr(rhs)]) + return value + + def _eval_compare(self, fields: Dict[str, Any]) -> Any: + """The doc AST comparison operation node evaluating method. + + Parameters + ---------- + fields : Dict[str, Any] + The dictionary of comparison operation information, + e.g., operator types, operand values. + + Returns + ------- + res : Any + The evaluation result. + """ + value = self._eval_expr(fields["left"]) + for op, rhs in zip(fields["ops"], fields["comparators"]): + value = _eval_op(op, values=[value, self._eval_expr(rhs)]) + return value + + def _eval_unary_op(self, fields: Dict[str, Any]) -> Any: + """The doc AST unary operation node evaluating method. + + Parameters + ---------- + fields : Dict[str, Any] + The dictionary of unary operation information, + e.g., operator types, operand values. + + Returns + ------- + res : Any + The evaluation result. + """ + value = self._eval_expr(fields["operand"]) + value = _eval_op(fields["op"], values=[value]) + return value + + def _eval_bin_op(self, fields: Dict[str, Any]) -> Any: + """The doc AST binary operation node evaluating method. + + Parameters + ---------- + fields : Dict[str, Any] + The dictionary of binary operation information, + e.g., operator types, operand values. + + Returns + ------- + res : Any + The evaluation result. + """ + return _eval_op( + fields["op"], + values=[ + self._eval_expr(fields["left"]), + self._eval_expr(fields["right"]), + ], + ) + + def _eval_slice(self, fields: Dict[str, Any]) -> slice: + """The doc AST slice node evaluating method. + + Parameters + ---------- + fields : Dict[str, Any] + The dictionary of slice information, + e.g., lower bound, upper bound, step. + + Returns + ------- + res : slice + The evaluation result. + """ + lower, upper, step = fields["lower"], fields["upper"], fields["step"] + + lower = self._eval_expr(lower) if lower is not None else None + upper = self._eval_expr(upper) if upper is not None else None + step = self._eval_expr(step) if step is not None else None + + return slice(lower, upper, step) + + def _eval_expr(self, v: Any) -> Any: + """The doc AST expression node evaluating method. + + Parameters + ---------- + v : Any + The root node of AST tree node of expression to evaluate. + + Returns + ------- + res : Any + The evaluation result. + """ + return _eval_expr(v, self.value_table) + + +def eval_expr( + parser: "Parser", + node: Union[doc.expr, doc.Expression], + dict_globals: Optional[Dict[str, Any]], +) -> Any: + """Expression evaluation for TVMScript parser. + + Parameters + ---------- + parser : Parser + The parser bound with the evaluator. + + node : Union[doc.expr, doc.Expression] + The root node of AST tree node of expression to evaluate. + + dict_globals : Optional[Dict[str, Any]] + The optional global value table for expression evaluation. + + Returns + ------- + res : Any + The evaluation result. + """ + value_table = {} + if dict_globals is not None: + value_table.update(dict_globals) + return ExprEvaluator.eval(parser, value_table, node) + + +def eval_assign( + parser: "Parser", + target: doc.expr, + source: Any, +) -> Dict[str, Any]: + """Expression assignment evaluation for TVMScript parser. + + Parameters + ---------- + parser : Parser + The parser bound with the evaluator. + + target : doc.expr + The root node of AST tree node of assigned expression to evaluate. + + source : Any + The source to be assigned with evaluated expression. + + Returns + ------- + res : Any + The evaluation result. + """ + try: + return _eval_assign(target, source) + except Exception as e: # pylint: disable=broad-except,invalid-name + parser.report_error(target, f"Failed to evaluate assignment: {str(e)}") + raise + + +def _eval_expr( + node: Union[doc.expr, doc.Expression], + dict_globals: Optional[Dict[str, Any]], +) -> Any: + """Expression evaluation implementation for TVMScript parser. + + Parameters + ---------- + node : Union[doc.expr, doc.Expression] + The root node of AST tree node of expression to evaluate. + + dict_globals : Optional[Dict[str, Any]] + The optional global value table for expression evaluation. + + Returns + ------- + res : Any + The evaluation result. + """ + node = doc.from_doc(node) + if isinstance(node, ast.expr): + node = ast.Expression(body=node) + assert isinstance(node, ast.Expression), "Expects an ast.Expression, but gets: " + str(node) + if dict_globals is None: + dict_globals = {} + node = ast.fix_missing_locations(node) + exe = compile(node, filename="", mode="eval") + return eval(exe, dict_globals) # pylint: disable=eval-used + + +def _eval_op( + op: doc.AST, + values: List[Any], +): + """Operation expression evaluation implementation for TVMScript parser. + + Parameters + ---------- + op : doc.AST + The root node of AST tree node of operation expression to evaluate. + + values : List[Any] + The list of values of operands. + + Returns + ------- + res : Any + The evaluation result. + """ + op_type = type(op) # pylint: disable=protected-access + for i, v in enumerate(values): + v_type = getattr(type(v), "_dispatch_type", None) + if v_type is None: + continue + f = dispatch.get_op( + operand_type=v_type, op_node_type=op_type, operand_index=i, default=None + ) + if f is not None: + return f(*values) + return DEFAULT_OP[op_type](*values) + + +def _eval_assign( + target: doc.expr, + source: Any, +) -> Dict[str, Any]: + """Expression assignment evaluation implementation for TVMScript parser. + + Parameters + ---------- + target : doc.expr + The root node of AST tree node of assigned expression to evaluate. + + source : Any + The source to be assigned with evaluated expression. + + Returns + ------- + res : Any + The evaluation result. + """ + target = doc.from_doc(target) + assert isinstance(target, ast.expr) + RHS_VAR_NAME = "__tvm_rhs_var__" # pylint: disable=invalid-name + rhs_var_name = RHS_VAR_NAME + dict_locals = {rhs_var_name: source} + mod = ast.fix_missing_locations( + ast.Module( + body=[ + ast.Assign( + targets=[target], + value=ast.Name( + id=rhs_var_name, + ctx=ast.Load(), + ), + ) + ], + type_ignores=[], + ) + ) + exe = compile(mod, filename="", mode="exec") + exec(exe, {}, dict_locals) # pylint: disable=exec-used + del dict_locals[rhs_var_name] + return dict_locals diff --git a/python/tvm/script/_parser/core/parser.py b/python/tvm/script/_parser/core/parser.py new file mode 100644 index 000000000000..daf95cb3cd1b --- /dev/null +++ b/python/tvm/script/_parser/core/parser.py @@ -0,0 +1,647 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The core parser""" + +from collections import defaultdict +from contextlib import contextmanager +from typing import Any, Callable, Dict, List, Optional, Set, Union +from tvm._ffi.base import TVMError + +from tvm.error import DiagnosticError + +from . import dispatch, doc +from .diagnostics import Diagnostics, Source +from .evaluator import eval_assign, eval_expr + +DEFAULT_VISIT = { + "Interactive", + "Module", + "Expression", + "Pass", +} + + +def _deferred(exit_f: Callable[[], None]): + """Created context with certain exit function. + + Parameters + ---------- + exit_f : Callable[[], None] + The function to call when exiting the context. + + Returns + ------- + res : Any + The created context. + """ + + @contextmanager + def context(): + try: + yield + finally: + exit_f() + + return context() + + +class VarTableFrame: + """The variable table frame. + A frame of variable table stores the variables created in one block or scope. + + Parameters + ---------- + vars : Set[str] + The set of variable names in the variable table frame. + """ + + vars: Set[str] + + def __init__(self): + self.vars = set() + + def add(self, var: str): + """Add a new variable into variable table frame. + + Parameters + ---------- + var : str + The name of new variable. + """ + if var in self.vars: + raise ValueError(f"Variable {var} already defined in current scope") + self.vars.add(var) + + def pop_all(self, fn_pop: Callable[[str], None]): + """Pop out all variable in variable table frame. + + Parameters + ---------- + fn_pop : Callable[[str], None] + The methods to call when popping each variable. + """ + for var in self.vars: + fn_pop(var) + self.vars.clear() + + +class VarTable: + """The variable table. + A variable table stores the all variables when parsing TVMScript. + + Parameters + ---------- + frames : List[VarTableFrame] + The list or stack of variable table frame. + + name2value : Dict[str, List[Any]] + The dictionary for variable table name-based query. + """ + + frames: List[VarTableFrame] + name2value: Dict[str, List[Any]] + + def __init__(self): + self.frames = [] + self.name2value = defaultdict(list) + + def with_frame(self): + """Create a new variable table frame as with statement. + + Returns + ------- + res : Any + The context with new variable table frame. + """ + + def pop_frame(): + frame = self.frames.pop() + frame.pop_all(lambda name: self.name2value[name].pop()) + + self.frames.append(VarTableFrame()) + return _deferred(pop_frame) + + def add(self, var: str, value: Any, allow_shadowing: bool = False): + """Add a new variable to variable table. + + Parameters + ---------- + var : str + The name of variable. + + value : Any + The value of variable. + + allow_shadowing : bool + The options of whether variable shadowing allwed for this variable. + """ + # Skip if the key and value are equal to those in the var_table + if self.name2value[var] and self.name2value[var][-1] == value: + return + if allow_shadowing and var in self.frames[-1].vars: + # Shadowing + self.name2value[var][-1] = value + else: + self.frames[-1].add(var) + self.name2value[var].append(value) + + def get(self) -> Dict[str, Any]: + """Get a variable dictionary of latest variables. + + Returns + ------- + res : Any + The variable dictionary copy of latest variables. + """ + return {key: values[-1] for key, values in self.name2value.items() if values} + + def exist(self, value: Any) -> bool: + """Check if any value exists in variable table. + + Parameters + ---------- + value : Any + The value of variable. + + Returns + ------- + res : bool + The existence of the value. + """ + for v in self.name2value.values(): + if v is value: + return True + return False + + +def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod: + def _wrapper(self: "Parser", node: doc.AST) -> None: + try: + return func(self, node) + except DiagnosticError: + raise + except Exception as e: # pylint: disable=broad-except,invalid-name + self.report_error(node, e) + raise + + return _wrapper + + +def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod: + for token in [self.dispatch_tokens[-1], "default"]: + func = dispatch.get(token=token, type_name=type_name, default=None) + if func is not None: + return _dispatch_wrapper(func) + return _dispatch_wrapper(lambda self, node: self.generic_visit(node)) + + +class Parser(doc.NodeVisitor): + """The TVMScript parser + + Parameters + ---------- + diag : Diagnostics + The diagnostics for error reporting. + + dispatch_tokens : List[str] + The list of dispatching tokens to dispatching parsing method + of different IRs and different doc AST structure. + + var_table : VarTable + The variable table for parsing. + """ + + diag: Diagnostics + dispatch_tokens: List[str] + var_table: VarTable + + def __init__(self, source: Source) -> None: + self.diag = Diagnostics(source) + self.dispatch_tokens = ["default"] + self.var_table = VarTable() + + def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: + """The main parse method for parser. + + Parameters + ---------- + extra_vars : Optional[Dict[str, Any]] + The optional global value table for parsing. + + Returns + ------- + res : Any + The doc AST node visiting result. + """ + if extra_vars is None: + extra_vars = {} + with self.var_table.with_frame(): + for k, v in extra_vars.items(): + self.var_table.add(k, v) + node = self.diag.source.as_ast() + self.visit(node) + + def with_dispatch_token(self, token: str): + """Add a new dispatching token as with statement. + + Parameters + ---------- + token : str + The dispathing token. + + Returns + ------- + res : Any + The context with new dispatching token. + """ + + def pop_token(): + self.dispatch_tokens.pop() + + self.dispatch_tokens.append(token) + return _deferred(pop_token) + + def eval_expr( + self, + node: Union[doc.Expression, doc.expr], + extra_vars: Optional[Dict[str, Any]] = None, + ) -> Any: + """Expression evaluation when parsing. + + Parameters + ---------- + node : Union[doc.expr, doc.Expression] + The root node of AST tree node of expression to evaluate. + + extra_vars : Optional[Dict[str, Any]] + The optional global value table for expression evaluation. + + Returns + ------- + res : Any + The evaluation result. + """ + var_values = self.var_table.get() + if extra_vars is not None: + for k, v in extra_vars.items(): + var_values[k] = v + return eval_expr(self, node, var_values) + + def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]: + """Check whether duplicate lhs exists in assignment. + + Parameters + ---------- + target : doc.expr + The doc AST expr node for lhs. + + Returns + ------- + res : Union[bool, Set[str]] + The result of true if duplicate lhs exists, + or the set of lhs names if no duplicate lhs exists. + """ + if isinstance(target, (doc.Tuple, doc.List)): + vars: Set[str] = set() # pylint: disable=redefined-builtin + for i in target.elts: + res = self._duplicate_lhs_check(i) + if isinstance(res, bool) and res: + return True + assert isinstance(res, set) + if vars & res: + return True + vars = vars.union(res) + return vars + elif isinstance(target, doc.Name): + return {target.id} + else: + self.report_error(target, "Invalid type in assign statement") + raise NotImplementedError + + def eval_assign( + self, + target: doc.expr, + source: Any, + bind_value: Callable[["Parser", doc.expr, str, Any], Any], + allow_shadowing: bool = False, + ) -> Dict[str, Any]: + """Expression assignment evaluation when parsing. + + Parameters + ---------- + target : doc.expr + The root node of AST tree node of assigned expression to evaluate. + + source : Any + The source to be assigned with evaluated expression. + + bind_value : Callable[["Parser", doc.expr, str, Any], Any] + The value binding method when assigning the values to variables. + + allow_shadowing : bool + The options of whether variable shadowing allwed for assignment. + + Returns + ------- + res : Dict[str, Any] + The dirctionary of assignment result. + """ + if self._duplicate_lhs_check(target) is True: + self.report_error(target, "Duplicate vars assigned.") + var_values = eval_assign(self, target, source) + for k, v in var_values.items(): + var = bind_value(self, target, k, v) + self.var_table.add(k, var, allow_shadowing) + return var_values + + def report_error( + self, node: doc.AST, err: Union[Exception, str] + ) -> None: # pylint: disable=no-self-use + """The error reporting when parsing. + + Parameters + ---------- + node : doc.AST + The doc AST node with errors. + + err: Union[Exception, str] + The error to report. + """ + # Only take the last line of the error message + if isinstance(err, TVMError): + msg = list(filter(None, str(err).split("\n")))[-1] + else: + msg = str(err) + self.diag.error(node, msg) + + def visit(self, node: doc.AST) -> None: + """The general visiting method. + + Parameters + ---------- + node : doc.AST + The doc AST node. + + Returns + ------- + res : Any + The visiting result. + """ + if isinstance(node, (list, tuple)): + for item in node: + self.visit(item) + return + if not isinstance(node, doc.AST): + return + name = node.__class__.__name__.split(".")[-1] + if name in DEFAULT_VISIT: + func = self.generic_visit + else: + func = getattr(self, "visit_" + name, None) + if func is None: + raise NotImplementedError(f"Visitor of AST node is not implemented: {name}") + try: + func(node) + except DiagnosticError: + raise + except Exception as e: # pylint: disable=broad-except,invalid-name + self.report_error(node, str(e)) + raise + + def visit_body(self, node: List[doc.stmt]) -> Any: + """The general body visiting method. + + Parameters + ---------- + node : List[doc.stmt] + The list of statements in body. + + Returns + ------- + res : Any + The visiting result. + """ + for stmt in node: + self.visit(stmt) + + def visit_tvm_annotation(self, node: doc.expr) -> Any: + """The general TVM annotation visiting method. + + Parameters + ---------- + node : doc.expr + The doc AST expr node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "tvm_annotation")(self, node) + + def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name + """The general function definition visiting method. + + Parameters + ---------- + node : doc.FunctionDef + The doc AST function definition node. + + Returns + ------- + res : Any + The visiting result. + """ + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + token = decorator.dispatch_token + func = dispatch.get(token=token, type_name="FunctionDef", default=None) + if func is None: + self.report_error(node, "The parser does not understand the decorator") + _dispatch_wrapper(func)(self, node) + + def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name + """The general class definition visiting method. + + Parameters + ---------- + node : doc.ClassDef + The doc AST class definition node. + + Returns + ------- + res : Any + The visiting result. + """ + func = dispatch.get(token="ir", type_name="ClassDef", default=None) + if func is None: + self.report_error(node, "The parser does not understand the decorator") + _dispatch_wrapper(func)(self, node) + + def visit_arguments(self, node: doc.arguments) -> Any: + """The general arguments visiting method. + + Parameters + ---------- + node : doc.arguments + The doc AST arguments node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "arguments")(self, node) + + def visit_For(self, node: doc.For) -> Any: # pylint: disable=invalid-name + """The general for visiting method. + + Parameters + ---------- + node : doc.For + The doc AST for node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "For")(self, node) + + def visit_While(self, node: doc.While) -> Any: # pylint: disable=invalid-name + """The general while visiting method. + + Parameters + ---------- + node : doc.While + The doc AST while node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "While")(self, node) + + def visit_With(self, node: doc.With) -> Any: # pylint: disable=invalid-name + """The general with visiting method. + + Parameters + ---------- + node : doc.With + The doc AST with node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "With")(self, node) + + def visit_Assign(self, node: doc.Assign) -> Any: # pylint: disable=invalid-name + """The general assign visiting method. + + Parameters + ---------- + node : doc.Assign + The doc AST assign node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "Assign")(self, node) + + def visit_Expr(self, node: doc.Expr) -> Any: # pylint: disable=invalid-name + """The general expression visiting method. + + Parameters + ---------- + node : doc.Expr + The doc AST exprssion node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "Expr")(self, node) + + def visit_If(self, node: doc.If) -> Any: # pylint: disable=invalid-name + """The general if visiting method. + + Parameters + ---------- + node : doc.If + The doc AST if node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "If")(self, node) + + def visit_AugAssign(self, node: doc.AugAssign) -> Any: # pylint: disable=invalid-name + """The general augmented assignment visiting method. + + Parameters + ---------- + node : doc.AugAssign + The doc AST augmented assignment node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "AugAssign")(self, node) + + def visit_Assert(self, node: doc.Assert) -> Any: # pylint: disable=invalid-name + """The general assert visiting method. + + Parameters + ---------- + node : doc.Assert + The doc AST assert node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "Assert")(self, node) + + def visit_Return(self, node: doc.Return) -> Any: # pylint: disable=invalid-name + """The general return visiting method. + + Parameters + ---------- + node : doc.Return + The doc AST return node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "Return")(self, node) diff --git a/tests/python/unittest/test_tvmscript_parser_evaluator.py b/tests/python/unittest/test_tvmscript_parser_evaluator.py new file mode 100644 index 000000000000..4d6590306050 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_parser_evaluator.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unittests for tvm.script.parser.evaluator""" +import pytest +import tvm.testing +from tvm.script._parser.core.diagnostics import Source +from tvm.script._parser.core.evaluator import ExprEvaluator + + +def _calc(expr, extra_vars=None): + if extra_vars is None: + extra_vars = {} + source = Source(expr) + mod_ast = source.as_ast() + mod_body_ast = mod_ast.body + expr_stmt_ast = mod_body_ast[0] + expr_ast = expr_stmt_ast.value + return ExprEvaluator.eval(None, extra_vars, expr_ast) + + +def test_evaluator_basic(): + assert _calc("1, 3.14, True, 'str'") == (1, 3.14, True, "str") + + +def test_evaluator_op(): + assert _calc("1 + 2, 1 - 2, 1 * 2, 1 / 2") == (3, -1, 2, 0.5) + + +def test_evaluator_value_table(): + res = _calc("a + b, a - b, a * b, a / b", {"a": 1, "b": 2}) + a, b = 1, 2 + assert res == (a + b, a - b, a * b, a / b) + + +def test_evaluator_func_call(): + def func(a, b): + return a + b, a - b, a * b, a / b + + assert _calc("func(1, 2)", {"func": func}) == func(1, 2) + + +def test_evaluator_slice(): + res = _calc("a, a[1:], a[:5], a[1: 5], a[1: 5: 2]", {"a": [1, 2, 3, 4, 5, 6]}) + a = [1, 2, 3, 4, 5, 6] + assert res == (a, a[1:], a[:5], a[1:5], a[1:5:2]) + + +if __name__ == "__main__": + tvm.testing.main()