From f5755bd8055a85a91561267f80c92d1b3414f41d Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 2 May 2026 22:09:36 +0900 Subject: [PATCH 1/2] [Relax][Frontend] Add ParameterList and ParameterDict containers Add first-class nn.ParameterList and nn.ParameterDict containers for Relax frontend parameters. The new containers validate explicit nn.Parameter values, participate in parameter/state traversal, export stable dot-separated names, and are handled by nn.Mutator traversal. Add tests for public exports, list/dict behavior, nested traversal, state_dict/load_state_dict, dtype conversion, export parameter names, and mutator naming. --- python/tvm/relax/frontend/nn/__init__.py | 12 +- python/tvm/relax/frontend/nn/core.py | 111 ++++++++- python/tvm/relax/frontend/nn/visitor.py | 62 ++++- .../python/relax/test_frontend_nn_mutator.py | 36 +++ .../test_frontend_nn_parameter_containers.py | 223 ++++++++++++++++++ 5 files changed, 438 insertions(+), 6 deletions(-) create mode 100644 tests/python/relax/test_frontend_nn_parameter_containers.py diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py index 282944af9833..1763ca152f5f 100644 --- a/python/tvm/relax/frontend/nn/__init__.py +++ b/python/tvm/relax/frontend/nn/__init__.py @@ -19,7 +19,17 @@ # pylint: disable=redefined-builtin from . import op, spec -from .core import Effect, Module, ModuleDict, ModuleList, Object, Parameter, Tensor +from .core import ( + Effect, + Module, + ModuleDict, + ModuleList, + Object, + Parameter, + ParameterDict, + ParameterList, + Tensor, +) from .exporter import add_extern from .extern import ExternModule, ObjectModule, SourceModule from .modules import ( diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index f3886e94cbcf..f9a61db050be 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -625,6 +625,59 @@ def to(self, dtype: str | None = None) -> None: # pylint: disable=invalid-name module.to(dtype=dtype) +class ParameterDict(Module): + """Holds parameters in a dict.""" + + def __init__( + self, + params: OrderedDict[str, Parameter] | dict[str, Parameter] | None = None, + ): + self.params: OrderedDict[str, Parameter] = OrderedDict() + if params is not None: + self.update(params) + + def __iter__(self) -> Iterator[str]: + return iter(self.params) + + def __getitem__(self, key: str) -> Parameter: + return self.params[key] + + def __setitem__(self, key: str, param: Parameter) -> None: + self.params[key] = param + + def __len__(self) -> int: + return len(self.params) + + def keys(self) -> Iterator[str]: + return self.params.keys() + + def values(self) -> Iterator[Parameter]: + return self.params.values() + + def items(self) -> Iterator[tuple[str, Parameter]]: + return self.params.items() + + def get(self, key: str, default: Parameter | None = None) -> Parameter | None: + return self.params.get(key, default) + + def update(self, params: dict[str, Parameter]) -> None: + for key, param in params.items(): + self[key] = param + + def clear(self) -> None: + self.params.clear() + + def pop(self, key: str) -> Parameter: + return self.params.pop(key) + + def __contains__(self, key: str) -> bool: + return key in self.params + + def to(self, dtype: str | None = None) -> None: # pylint: disable=invalid-name + for param in self.params.values(): + param.to(dtype=dtype) + + class ModuleList(Module): """Holds submodules in a list.""" @@ -658,6 +711,40 @@ def forward(self, x): # pylint: disable=invalid-name return x +class ParameterList(Module): + """Holds parameters in a list.""" + + def __init__(self, params: list[Parameter] | None = None): + self.params: list[Parameter] = [] + if params is not None: + self.extend(params) + + def __iter__(self) -> Iterator[Parameter]: + return iter(self.params) + + def __getitem__(self, idx: int) -> Parameter: + return self.params[idx] + + def __setitem__(self, idx: int, param: Parameter) -> None: + self.params[idx] = param + + def __len__(self) -> int: + return len(self.params) + + def append(self, param: Parameter) -> None: + """Add a parameter to the end of the ParameterList""" + self.params.append(param) + + def extend(self, params: list[Parameter]) -> None: + """Add parameters to the end of the ParameterList""" + for param in params: + self.append(param) + + def to(self, dtype: str | None = None) -> None: # pylint: disable=invalid-name + for param in self.params: + param.to(dtype=dtype) + + def wrap_nested(expr: rx.Expr, name: str) -> Tensor | Sequence[Tensor]: """Wrap the given relax.Expr, emit it using the current BlockBuilder, and automatically handle nested cases if the expr represents a Tuple. @@ -692,7 +779,17 @@ def wrap_nested(expr: rx.Expr, name: str) -> Tensor | Sequence[Tensor]: def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any], bool]): """Find attributes that satisfy the condition recursively""" - if isinstance(root, ModuleList): + if isinstance(root, ParameterList): + for i, param in enumerate(root): + if condition_yield(param): + yield prefix + f"{i}", param + return + elif isinstance(root, ParameterDict): + for name, param in root.items(): + if condition_yield(param): + yield prefix + name, param + return + elif isinstance(root, ModuleList): for i, subitem in enumerate(root): yield from _attribute_finder(subitem, prefix + f"{i}.", condition_yield) return @@ -703,6 +800,18 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any] for name, item in root.__dict__.items(): if condition_yield(item): yield prefix + name, item + elif isinstance(item, ParameterList): + yield from _attribute_finder( + item, + prefix + name + ".", + condition_yield, + ) + elif isinstance(item, ParameterDict): + yield from _attribute_finder( + item, + prefix + name + ".", + condition_yield, + ) elif isinstance(item, ModuleList): yield from _attribute_finder( item, diff --git a/python/tvm/relax/frontend/nn/visitor.py b/python/tvm/relax/frontend/nn/visitor.py index e3279ceae50f..69583eaae8d3 100644 --- a/python/tvm/relax/frontend/nn/visitor.py +++ b/python/tvm/relax/frontend/nn/visitor.py @@ -116,6 +116,42 @@ def visit_modulelist(self, name: str, node: nn.ModuleList) -> Any: """ return self.visit(name, node) + def visit_parameterdict(self, name: str, node: nn.ParameterDict) -> Any: + """The base visiting method for mutation of nn.ParameterDict nodes. + + Parameters + ---------- + name : str + The name of the current node in parent's attribute. + + node : nn.ParameterDict + The current node of nn.ParameterDict to mutate. + + Returns + ------ + ret_node: Any + The new node to replace current node. + """ + return self.visit(name, node) + + def visit_parameterlist(self, name: str, node: nn.ParameterList) -> Any: + """The base visiting method for mutation of nn.ParameterList nodes. + + Parameters + ---------- + name : str + The name of the current node in parent's attribute. + + node : nn.ParameterList + The current node of nn.ParameterList to mutate. + + Returns + ------ + ret_node: Any + The new node to replace current node. + """ + return self.visit(name, node) + def visit(self, name: str, node: Any) -> Any: """The base dispatching method for visiting of all nodes. @@ -141,9 +177,19 @@ def _get_child_name(parent: str, child: str) -> str: else: return f"{parent}.{child}" - if isinstance(node, nn.ModuleList): + if isinstance(node, nn.ParameterList): + for i in range(len(node)): + node[i] = self.visit_param(_get_child_name(name, str(i)), node[i]) + elif isinstance(node, nn.ParameterDict): + for k, v in node.items(): + node[k] = self.visit_param(_get_child_name(name, k), v) + elif isinstance(node, nn.ModuleList): for i in range(len(node)): - if isinstance(node[i], nn.ModuleDict): + if isinstance(node[i], nn.ParameterDict): + node[i] = self.visit_parameterdict(_get_child_name(name, str(i)), node[i]) + elif isinstance(node[i], nn.ParameterList): + node[i] = self.visit_parameterlist(_get_child_name(name, str(i)), node[i]) + elif isinstance(node[i], nn.ModuleDict): node[i] = self.visit_moduledict(f"{name}.{i}", node[i]) elif isinstance(node[i], nn.ModuleList): node[i] = self.visit_modulelist(f"{name}.{i}", node[i]) @@ -155,7 +201,11 @@ def _get_child_name(parent: str, child: str) -> str: node[i] = self.visit_param(f"{name}.{i}", node[i]) elif isinstance(node, nn.ModuleDict): for k, v in node.items(): - if isinstance(v, nn.ModuleDict): + if isinstance(v, nn.ParameterDict): + node[k] = self.visit_parameterdict(_get_child_name(name, k), v) + elif isinstance(v, nn.ParameterList): + node[k] = self.visit_parameterlist(_get_child_name(name, k), v) + elif isinstance(v, nn.ModuleDict): node[k] = self.visit_moduledict(_get_child_name(name, k), v) elif isinstance(v, nn.ModuleList): node[k] = self.visit_modulelist(_get_child_name(name, k), v) @@ -167,7 +217,11 @@ def _get_child_name(parent: str, child: str) -> str: node[k] = self.visit_param(_get_child_name(name, k), v) else: for key, value in node.__dict__.items(): - if isinstance(value, nn.ModuleDict): + if isinstance(value, nn.ParameterDict): + setattr(node, key, self.visit_parameterdict(_get_child_name(name, key), value)) + elif isinstance(value, nn.ParameterList): + setattr(node, key, self.visit_parameterlist(_get_child_name(name, key), value)) + elif isinstance(value, nn.ModuleDict): setattr(node, key, self.visit_moduledict(_get_child_name(name, key), value)) elif isinstance(value, nn.ModuleList): setattr(node, key, self.visit_modulelist(_get_child_name(name, key), value)) diff --git a/tests/python/relax/test_frontend_nn_mutator.py b/tests/python/relax/test_frontend_nn_mutator.py index 253e24a4eddf..23c8c9cde619 100644 --- a/tests/python/relax/test_frontend_nn_mutator.py +++ b/tests/python/relax/test_frontend_nn_mutator.py @@ -127,6 +127,42 @@ def visit_param(self, name: str, node: nn.Parameter) -> Any: mutator.visit("mod_list", mod_list) +def test_mutator_naming_parameter_containers(): + class Module(nn.Module): + def __init__(self) -> None: + super().__init__() + self.param_list = nn.ParameterList( + [ + nn.Parameter((32, 128), "float64"), + nn.Parameter((32, 128), "float32"), + ] + ) + self.param_dict = nn.ParameterDict( + { + "k0": nn.Parameter((32, 128), "float16"), + "k1": nn.Parameter((32, 128), "float8"), + } + ) + + seen = [] + + class Mutator(nn.Mutator): + def visit_param(self, name: str, node: nn.Parameter) -> Any: + seen.append((name, node.dtype)) + return node + + module = Module() + mutator = Mutator() + mutator.visit("", module) + + assert seen == [ + ("param_list.0", "float64"), + ("param_list.1", "float32"), + ("param_dict.k0", "float16"), + ("param_dict.k1", "float8"), + ] + + def test_mutator_module(): class SubModule1(nn.Module): def __init__(self) -> None: diff --git a/tests/python/relax/test_frontend_nn_parameter_containers.py b/tests/python/relax/test_frontend_nn_parameter_containers.py new file mode 100644 index 000000000000..d07a21405a61 --- /dev/null +++ b/tests/python/relax/test_frontend_nn_parameter_containers.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any + +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm.relax.frontend import nn + + +class ParamContainerModule(nn.Module): + def __init__(self): + self.list_params = nn.ParameterList( + [ + nn.Parameter((4,), "float32"), + nn.Parameter((4,), "float32"), + ] + ) + self.dict_params = nn.ParameterDict( + { + "foo": nn.Parameter((4,), "float32"), + "bar": nn.Parameter((4,), "float32"), + } + ) + + +def test_parameter_list_basic_behavior(): + p0 = nn.Parameter((4,), "float32") + p1 = nn.Parameter((4,), "float32") + params = nn.ParameterList([p0]) + params.append(p1) + + assert len(params) == 2 + assert params[0] is p0 + assert list(params) == [p0, p1] + + p2 = nn.Parameter((4,), "float32") + params[1] = p2 + assert params[1] is p2 + + p3 = nn.Parameter((4,), "float32") + params.extend([p3]) + assert list(params) == [p0, p2, p3] + + +def test_parameter_dict_basic_behavior(): + p0 = nn.Parameter((4,), "float32") + p1 = nn.Parameter((4,), "float32") + params = nn.ParameterDict({"foo": p0}) + params["bar"] = p1 + + assert len(params) == 2 + assert params["foo"] is p0 + assert "bar" in params + assert list(params) == ["foo", "bar"] + assert list(params.keys()) == ["foo", "bar"] + assert list(params.values()) == [p0, p1] + assert list(params.items()) == [("foo", p0), ("bar", p1)] + assert params.get("foo") is p0 + + p2 = nn.Parameter((4,), "float32") + params.update({"baz": p2}) + assert list(params.keys()) == ["foo", "bar", "baz"] + assert params.pop("baz") is p2 + params.clear() + assert len(params) == 0 + + +def test_type_validation(): + with pytest.raises(TypeError): + nn.ParameterList([object()]) + + with pytest.raises(TypeError): + nn.ParameterDict({"bad": object()}) + + with pytest.raises(TypeError): + nn.ParameterDict({1: nn.Parameter((4,), "float32")}) + + with pytest.raises(TypeError): + nn.ParameterList()[0] = object() + + +def test_named_parameters_parameters_and_state_dict(): + m = ParamContainerModule() + + expected = [ + "list_params.0", + "list_params.1", + "dict_params.foo", + "dict_params.bar", + ] + + assert list(m.state_dict().keys()) == expected + assert [name for name, _ in m.named_parameters()] == expected + assert len(list(m.parameters())) == 4 + + +def test_nested_traversal_through_module_dict(): + class Inner(nn.Module): + def __init__(self): + self.params = nn.ParameterList([nn.Parameter((4,), "float32")]) + + class Outer(nn.Module): + def __init__(self): + self.blocks = nn.ModuleDict({"inner": Inner()}) + + m = Outer() + assert list(m.state_dict().keys()) == ["blocks.inner.params.0"] + + +def test_nested_traversal_through_module_list(): + class Inner(nn.Module): + def __init__(self): + self.params = nn.ParameterList([nn.Parameter((4,), "float32")]) + + class Outer(nn.Module): + def __init__(self): + self.blocks = nn.ModuleList([Inner()]) + + m = Outer() + assert list(m.state_dict().keys()) == ["blocks.0.params.0"] + + +def test_to_dtype(): + m = ParamContainerModule() + m.to(dtype="float16") + + assert m.list_params[0].dtype == "float16" + assert m.list_params[1].dtype == "float16" + assert m.dict_params["foo"].dtype == "float16" + assert m.dict_params["bar"].dtype == "float16" + + +def test_load_state_dict(): + m = ParamContainerModule() + p0 = nn.Parameter((4,), "float32") + p0.data = np.full((4,), 1.0, dtype="float32") + p1 = nn.Parameter((4,), "float32") + p1.data = np.full((4,), 2.0, dtype="float32") + p2 = nn.Parameter((4,), "float32") + p2.data = np.full((4,), 3.0, dtype="float32") + p3 = nn.Parameter((4,), "float32") + p3.data = np.full((4,), 4.0, dtype="float32") + state_dict = { + "list_params.0": p0, + "list_params.1": p1, + "dict_params.foo": p2, + "dict_params.bar": p3, + } + + missing_keys, unexpected_keys = m.load_state_dict(state_dict) + + assert missing_keys == [] + assert unexpected_keys == [] + tvm.testing.assert_allclose(m.list_params[0].data.numpy(), np.full((4,), 1.0, "float32")) + tvm.testing.assert_allclose(m.list_params[1].data.numpy(), np.full((4,), 2.0, "float32")) + tvm.testing.assert_allclose( + m.dict_params["foo"].data.numpy(), np.full((4,), 3.0, "float32") + ) + tvm.testing.assert_allclose( + m.dict_params["bar"].data.numpy(), np.full((4,), 4.0, "float32") + ) + + +def test_export_tvm_parameter_names(): + class M(nn.Module): + def __init__(self): + self.biases = nn.ParameterList( + [ + nn.Parameter((4,), "float32"), + nn.Parameter((4,), "float32"), + ] + ) + self.scales = nn.ParameterDict({"main": nn.Parameter((4,), "float32")}) + + def forward(self, x): + return x + self.biases[0] + self.biases[1] + self.scales["main"] + + _, params = M().export_tvm( + spec={"forward": {"x": nn.spec.Tensor((4,), "float32")}}, + debug=False, + ) + assert [name for name, _ in params] == ["biases.0", "biases.1", "scales.main"] + + +def test_mutator_parameter_container_names(): + seen = [] + + class Recorder(nn.Mutator): + def visit_param(self, name: str, node: nn.Parameter) -> Any: + seen.append(name) + return node + + m = ParamContainerModule() + Recorder().visit_module("", m) + + assert seen == [ + "list_params.0", + "list_params.1", + "dict_params.foo", + "dict_params.bar", + ] + + +if __name__ == "__main__": + tvm.testing.main() From ee3e9ea7dda1759540a0fd1239192bba77a82961 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 2 May 2026 22:09:36 +0900 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/tvm/relax/frontend/nn/core.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index f9a61db050be..3725a84d61f8 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -643,6 +643,10 @@ def __getitem__(self, key: str) -> Parameter: return self.params[key] def __setitem__(self, key: str, param: Parameter) -> None: + if not isinstance(key, str): + raise TypeError(f"ParameterDict keys must be strings, but got {type(key).__name__}") + if not isinstance(param, Parameter): + raise TypeError(f"ParameterDict values must be nn.Parameter, but got {type(param).__name__}") self.params[key] = param def __len__(self) -> int: @@ -726,6 +730,8 @@ def __getitem__(self, idx: int) -> Parameter: return self.params[idx] def __setitem__(self, idx: int, param: Parameter) -> None: + if not isinstance(param, Parameter): + raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}") self.params[idx] = param def __len__(self) -> int: @@ -733,6 +739,8 @@ def __len__(self) -> int: def append(self, param: Parameter) -> None: """Add a parameter to the end of the ParameterList""" + if not isinstance(param, Parameter): + raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}") self.params.append(param) def extend(self, params: list[Parameter]) -> None: