diff --git a/CMakeLists.txt b/CMakeLists.txt index 2a269fdc2768..edb74b9e24f2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,7 +98,6 @@ tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions, STATIC, DYNA tvm_option(SUMMARIZE "Print CMake option summary after configuring" OFF) tvm_option(USE_CLML "Build with CLML Codegen support" OFF) tvm_option(USE_CLML_GRAPH_EXECUTOR "Build with CLML graph runtime" OFF) -tvm_option(USE_MSC "Enable Multi-System Compiler" OFF) tvm_option(USE_NVSHMEM "Build with NVSHMEM support" OFF) # Python package options @@ -449,7 +448,6 @@ include(cmake/modules/contrib/Sort.cmake) include(cmake/modules/contrib/CoreML.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/NNAPI.cmake) -include(cmake/modules/contrib/MSC.cmake) include(cmake/modules/contrib/vllm.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) diff --git a/cmake/config.cmake b/cmake/config.cmake index 856754c3ec24..dfbe0d217893 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -212,9 +212,6 @@ set(USE_SORT ON) set(USE_TENSORRT_CODEGEN OFF) set(USE_TENSORRT_RUNTIME OFF) -# Whether to use the Multi-System Compiler -set(USE_MSC OFF) - #Whether to use CLML codegen set(USE_CLML OFF) # USE_CLML_GRAPH_EXECUTOR - CLML SDK PATH or ON or OFF diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index b023bea4696b..c544ced3cacf 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -107,7 +107,6 @@ function(add_lib_info src_file) TVM_INFO_USE_CLML="${USE_CLML}" TVM_INFO_USE_CLML_GRAPH_EXECUTOR="${USE_CLML_GRAPH_EXECUTOR}" TVM_INFO_USE_TVM_CLML_VERSION="${CLML_VERSION_MAJOR}" - TVM_INFO_USE_MSC="${USE_MSC}" TVM_INFO_USE_CCACHE="${USE_CCACHE}" TVM_INFO_USE_NVSHMEM="${USE_NVSHMEM}" TVM_INFO_USE_NNAPI_CODEGEN="${USE_NNAPI_CODEGEN}" diff --git a/cmake/modules/contrib/MSC.cmake b/cmake/modules/contrib/MSC.cmake deleted file mode 100644 index 5779ea52175b..000000000000 --- a/cmake/modules/contrib/MSC.cmake +++ /dev/null @@ -1,31 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -if(USE_MSC) - tvm_file_glob(GLOB_RECURSE MSC_CORE_SOURCE "src/contrib/msc/*.cc") - list(APPEND COMPILER_SRCS ${MSC_CORE_SOURCE}) - - tvm_file_glob(GLOB_RECURSE MSC_RUNTIME_SOURCE "src/runtime/contrib/msc/*.cc") - set_source_files_properties(${MSC_RUNTIME_SOURCE} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") - list(APPEND RUNTIME_SRCS ${MSC_RUNTIME_SOURCE}) - - if(USE_TENSORRT_RUNTIME) - add_definitions("-DTENSORRT_ROOT_DIR=\"${TENSORRT_ROOT_DIR}\"") - endif() - - message(STATUS "Build with MSC support...") -endif() diff --git a/pyproject.toml b/pyproject.toml index e4241426d145..062f357bfa8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -230,7 +230,6 @@ unfixable = [] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["E402", "F401", "F403", "F405"] -"python/tvm/contrib/msc/**" = ["UP"] [tool.ruff.lint.isort] known-first-party = ["tvm"] diff --git a/python/tvm/contrib/msc/__init__.py b/python/tvm/contrib/msc/__init__.py deleted file mode 100644 index a2813b4a2dca..000000000000 --- a/python/tvm/contrib/msc/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc""" diff --git a/python/tvm/contrib/msc/core/__init__.py b/python/tvm/contrib/msc/core/__init__.py deleted file mode 100644 index 6d1a7c68c86d..000000000000 --- a/python/tvm/contrib/msc/core/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core""" diff --git a/python/tvm/contrib/msc/core/_ffi_api.py b/python/tvm/contrib/msc/core/_ffi_api.py deleted file mode 100644 index ff027a0dec8e..000000000000 --- a/python/tvm/contrib/msc/core/_ffi_api.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core._ffi_api""" - -import tvm_ffi - -tvm_ffi.init_ffi_api("msc.core", __name__) diff --git a/python/tvm/contrib/msc/core/codegen/__init__.py b/python/tvm/contrib/msc/core/codegen/__init__.py deleted file mode 100644 index 78da1b3fdd69..000000000000 --- a/python/tvm/contrib/msc/core/codegen/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.codegen""" - -from .codegen import * -from .sources import * diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py deleted file mode 100644 index fec642dad224..000000000000 --- a/python/tvm/contrib/msc/core/codegen/codegen.py +++ /dev/null @@ -1,214 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E501, RUF005 -"""tvm.contrib.msc.core.codegen.codegen""" - -import os -import subprocess -from typing import Any, Callable, Dict, List, Optional - -import tvm -from tvm import relax -from tvm.contrib.msc.core import transform as msc_transform -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor -from tvm.relax import PyExprVisitor - - -class CodeGen: - """Manager class to generate codes and load model - - Parameters - ---------- - graph: MSCGraph - The reference graph for codegen. - source_getter: Callable - The method to get sources. - codegen_config: dict - The config to generate code. - print_config: dict - The config to print code. - build_folder: MSCDirectory - The codegen folder. - coda_format: str - The code format cpp| python. - """ - - def __init__( - self, - graph: MSCGraph, - source_getter: Callable[[MSCGraph, str, str], str], - codegen_config: Optional[Dict[str, str]] = None, - print_config: Optional[Dict[str, str]] = None, - build_folder: msc_utils.MSCDirectory = None, - code_format: str = "python", - ): - self._graph = graph - self._source_getter = source_getter - self._codegen_config = msc_utils.dump_dict(codegen_config) - self._print_config = msc_utils.dump_dict(print_config) - self._build_folder = build_folder or msc_utils.msc_dir(keep_history=False, cleanup=True) - self._code_format = code_format - - def load( - self, - inputs: Optional[List[Any]] = None, - pre_load: Optional[Callable[[msc_utils.MSCDirectory], Any]] = None, - post_load: Optional[Callable[[Any, msc_utils.MSCDirectory], Any]] = None, - build_model: bool = True, - ) -> Any: - """Generate source and load the model - - Parameters - ------- - inputs: list - The inputs to build the model. - pre_load: Callable - The pre processing method before load. - post_load: Callable - The post processing method after load. - build_model: bool - Whether to build the model. - - Returns - ------- - obj: model object - The model object for the framework. - """ - - sources = self._source_getter(self._graph, self._codegen_config, self._print_config) - inputs = inputs or [] - with self._build_folder as folder: - # pre processing - if pre_load: - pre_load(folder) - for name, source in sources.items(): - folder.add_file(name, source) - if build_model: - if self._code_format == "cpp": - with folder.create_dir("build"): - command = f"cmake ../ && make && mv {self._graph.name} ../" - with open("codegen.log", "w") as log_f: - process = subprocess.Popen( - command, stdout=log_f, stderr=log_f, shell=True - ) - process.wait() - assert process.returncode == 0, ( - f"Failed to build {self._graph.name} under {os.getcwd()}, check codegen.log for detail" - ) - obj = self._graph.name - elif self._code_format == "python": - builder = msc_utils.load_callable(self._graph.name + ".py:" + self._graph.name) - obj = builder(*inputs) - else: - raise NotImplementedError(f"Code format {self._code_format} is not supported") - # post processing - if post_load: - obj = post_load(obj, folder) - else: - obj = None - return obj - - -def to_relax( - graph: MSCGraph, - weights: Optional[Dict[str, tvm.runtime.Tensor]] = None, - codegen_config: Optional[Dict[str, str]] = None, - print_config: Optional[Dict[str, str]] = None, - build_folder: msc_utils.MSCDirectory = None, - plugin: Any = None, - use_alias: bool = True, -) -> tvm.IRModule: - """Change MSCGraph to IRModule. - - Parameters - ---------- - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - weights: dict of - The parameters of the IRModule. - codegen_config: dict - The config for codegen. - print_config: dict - The config for print. - build_folder: MSCDirectory - The folder for saving scripts and datas. - plugin: PluginManager - The plugin manager. - use_alias: bool - Whether to use alias for input. - - Returns - ------- - mod: IRModule - The IRModule of relax. - """ - - @relax.expr_functor.visitor - class NamesGetter(PyExprVisitor): - """Visitor for get attributes in span""" - - def get_names(self, expr: relax.Expr) -> dict: - self._names = {} - if isinstance(expr, relax.Expr): - self.visit_expr(expr) - elif isinstance(expr, relax.BindingBlock): - self.visit_binding_block(expr) - return self._names - - def visit_var_binding_(self, binding: relax.VarBinding) -> None: - super().visit_var_binding_(binding) - self._names[binding.var.name_hint] = binding.var.name_hint - - def _to_var(tensor: MSCTensor): - v_name = tensor.alias if use_alias else graph.find_producer(tensor).name - dims = [ - d if isinstance(d, int) else tvm.tir.Var(d, "int64") for d in tensor.get_shape(True) - ] - return tvm.relax.Var(v_name, tvm.relax.TensorStructInfo(dims, tensor.dtype_name)) - - def _save_weights(folder: msc_utils.MSCDirectory): - if weights: - with open(folder.relpath(graph.name + "_params.bin"), "wb") as f_params: - f_params.write(tvm.runtime.save_param_dict(weights)) - - # pylint: disable=unused-argument - def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule: - passes, var_names = [], NamesGetter().get_names(mod["main"]) - if weights: - passes.append(msc_transform.BindNamedParams("main", weights)) - # The canonicalization of relax variable bindings is not required - # for correctness. It does, however, remove trivial `x = y` - # bindings, preventing test cases from depending on their - # presence. - passes.extend( - [ - msc_transform.SetExprName(var_names=var_names), - tvm.relax.transform.CanonicalizeBindings(), - tvm.relax.transform.ConvertToDataflow(min_size=1), - ] - ) - return tvm.ir.transform.Sequential( - passes, name="tvm.contrib.msc.core.codegen.to_relax_postproc" - )(mod) - - source_getter = tvm.get_global_func("msc.framework.tvm.GetRelaxSources") - codegen = CodeGen(graph, source_getter, codegen_config, print_config, build_folder) - model_args = [_to_var(i) for i in graph.get_inputs()] - if plugin: - model_args = model_args + [plugin] - return codegen.load(model_args, pre_load=_save_weights, post_load=_post_proc) diff --git a/python/tvm/contrib/msc/core/codegen/sources.py b/python/tvm/contrib/msc/core/codegen/sources.py deleted file mode 100644 index 825ec390f895..000000000000 --- a/python/tvm/contrib/msc/core/codegen/sources.py +++ /dev/null @@ -1,219 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.codegen.sources""" - -from typing import Dict - - -def get_base_h_code() -> str: - """Create base header file codes - - Returns - ------- - source: str - The base header source. - """ - - return """#ifndef TVM_CONTRIB_MSC_UTILS_BASE_H_ -#define TVM_CONTRIB_MSC_UTILS_BASE_H_ - -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -class CommonUtils { - public: - template - static bool CompareBuffers(const T* golden, const T* result, size_t size) { - return true; - } -}; - -class FileUtils { - public: - static bool FileExist(const std::string& file); - - template - static bool ReadToBuffer(const std::string& file, T* buffer, size_t size) { - std::ifstream in_file(file, std::ifstream::binary); - if (!in_file.is_open()) { - return false; - } - try { - in_file.read((char*)(&buffer[0]), size * sizeof(T)); - } catch (std::exception const& e) { - in_file.close(); - return false; - } - in_file.close(); - return true; - } -}; - -class DatasetReader { - public: - DatasetReader(const std::string& folder, int max_size = -1); - - void Reset(); - - bool ReadNext(void* buffers[], int num_datas = -1); - - const std::vector GetTensorNames() { return tensor_names_; } - - size_t GetTensorSize(const std::string& name); - - const std::string GetSaveName(const std::string& name); - - private: - std::string folder_; - size_t max_size_; - size_t cur_cnt_; - std::vector tensor_names_; - std::unordered_map save_names_; - std::unordered_map tensor_sizes_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_UTILS_BASE_H_ -""" - - -def get_base_cc_code() -> str: - """Create base cc file codes - - Returns - ------- - source: str - The base cc source. - """ - - return """#include "base.h" - -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -bool FileUtils::FileExist(const std::string& file) { - std::ifstream in_file(file, std::ifstream::binary); - if (in_file.is_open()) { - in_file.close(); - return true; - } - return false; -} - -DatasetReader::DatasetReader(const std::string& folder, int max_size) { - folder_ = folder; - const std::string info_file = folder_ + "/datas_info.txt"; - std::ifstream input(info_file, std::ios::binary); - assert(input.is_open() && ("Failed to open file " + info_file).c_str()); - std::string line; - while (getline(input, line)) { - // define name - int pos = line.find(" "); - assert(pos > 0 && ("Can not find space in line " + line).c_str()); - const auto& name = line.substr(0, pos); - tensor_names_.push_back(name); - const auto& left = line.substr(pos + 1, line.size()); - // define save_name - pos = left.find(" "); - assert(pos > 0 && ("Can not find space in left " + left).c_str()); - save_names_[name] = left.substr(0, pos); - // define size - const auto& byte_size = left.substr(pos + 1, left.size()); - tensor_sizes_[name] = static_cast(std::stoi(byte_size)); - } - size_t file_cnt = 0; - while (true) { - bool all_exists = true; - for (const auto& pair : save_names_) { - const auto& d_file = - folder_ + "/" + pair.second + "/batch_" + std::to_string(file_cnt) + ".bin"; - if (!FileUtils::FileExist(d_file)) { - all_exists = false; - break; - } - } - if (!all_exists) { - break; - } - file_cnt++; - } - max_size_ = max_size > 0 ? static_cast(max_size) : file_cnt; - max_size_ = std::min(max_size_, file_cnt); - Reset(); -} - -void DatasetReader::Reset() { cur_cnt_ = 0; } - -bool DatasetReader::ReadNext(void* buffers[], int num_datas) { - if (cur_cnt_ >= max_size_) { - return false; - } - size_t max_num = num_datas > 0 ? static_cast(num_datas) : tensor_names_.size(); - max_num = std::min(max_num, tensor_names_.size()); - for (size_t i = 0; i < max_num; i++) { - const auto& name = tensor_names_[i]; - const auto& d_file = - folder_ + "/" + GetSaveName(name) + "/batch_" + std::to_string(cur_cnt_) + ".bin"; - if (!FileUtils::ReadToBuffer(d_file, (char*)buffers[i], GetTensorSize(name))) { - return false; - } - } - cur_cnt_++; - return true; -} - -size_t DatasetReader::GetTensorSize(const std::string& name) { - assert(tensor_sizes_.count(name)); - return tensor_sizes_[name]; -} - -const std::string DatasetReader::GetSaveName(const std::string& name) { - assert(save_names_.count(name)); - return save_names_[name]; -} - -} // namespace msc -} // namespace contrib -} // namespace tvm -""" - - -def get_base_sources() -> Dict[str, str]: - """Create base sources for cpp codegen - - Returns - ------- - sources: dict - The base utils sources. - """ - - return {"base.h": get_base_h_code(), "base.cc": get_base_cc_code()} diff --git a/python/tvm/contrib/msc/core/frontend/__init__.py b/python/tvm/contrib/msc/core/frontend/__init__.py deleted file mode 100644 index a5fd7a01a803..000000000000 --- a/python/tvm/contrib/msc/core/frontend/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.frontend""" - -from .translate import * diff --git a/python/tvm/contrib/msc/core/frontend/translate.py b/python/tvm/contrib/msc/core/frontend/translate.py deleted file mode 100644 index 1bceff8818cc..000000000000 --- a/python/tvm/contrib/msc/core/frontend/translate.py +++ /dev/null @@ -1,270 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E741 -"""tvm.contrib.msc.core.frontend.translate""" - -from typing import Dict, List, Optional, Tuple - -import tvm -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import transform as msc_transform -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor -from tvm.relax import PyExprVisitor -from tvm.relax.backend.pattern_registry import get_patterns_with_prefix -from tvm.relax.transform import BindParams - - -def normalize_inputs(inputs: List[tuple]) -> List[tuple]: - """Normalize the inputs info - - Parameters - ---------- - inputs: list of - The inputs info. - - Returns - ------- - inputs: list of - The normalized inputs info. - """ - - recorded_vars = {} - - def _normalize_input(inp): - def _normalize(info): - if not isinstance(info, (tuple, list)): - return info - dims = [] - for dim in info: - if isinstance(dim, int): - dims.append(dim) - elif dim in recorded_vars: - dims.append(recorded_vars[dim]) - elif isinstance(dim, str): - recorded_vars[dim] = tvm.tir.Var(dim, "int64") - dims.append(recorded_vars[dim]) - else: - raise TypeError(f"Unexpected dim {dim} in shape {info}") - return dims - - return [_normalize(i) for i in inp] - - return [_normalize_input(inp) for inp in inputs] - - -def normalize_weights( - t_weights: Dict[MSCTensor, tvm.runtime.Tensor], graph: MSCGraph -) -> Dict[str, tvm.runtime.Tensor]: - """Normalize the weghts. - - Parameters - ---------- - t_weights: dict of - The weights extracted from IRModule. - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - - Returns - ------- - weights: dict of - The normalized weights. - """ - - def _to_data(ref_t, data): - weight_t = graph.find_tensor(ref_t.name) - if weight_t.ndim == 1: - if ref_t.ndim != weight_t.ndim: - return tvm.runtime.tensor(data.numpy().reshape(weight_t.get_shape())) - return data - if ref_t.layout and weight_t.layout: - ref_layout, weight_layout = ref_t.layout.name, weight_t.layout.name - if ref_layout != weight_layout: - assert all(l in ref_layout for l in weight_layout), ( - f"layout mismatch {ref_t} compare to {weight_t}" - ) - permute = [ref_layout.index(l) for l in weight_layout] - return tvm.runtime.tensor(data.numpy().transpose(*permute)) - return data - - weights = {t.name: _to_data(t, d) for t, d in t_weights.items() if graph.has_tensor(t.name)} - # sort the weights by graph weights - graph_weights = {} - for weight in graph.get_weights(): - assert weight.name in weights, "Missing weight " + str(weight) - graph_weights[weight.name] = weights[weight.name] - return graph_weights - - -def from_relax( - mod: tvm.IRModule, - params: Optional[Dict[str, tvm.runtime.Tensor]] = None, - trans_config: Optional[Dict[str, str]] = None, - build_config: Optional[Dict[str, str]] = None, - opt_config: Optional[Dict[str, str]] = None, -) -> Tuple[MSCGraph, Dict[str, tvm.runtime.Tensor]]: - """Change IRModule to MSCGraph. - - Parameters - ---------- - mod: IRModule - The IRModule of relax. - params: dict of - The parameters of the IRModule. - trans_config: dict - The config for transform IRModule. - build_config: dict - The config for build MSCGraph. - opt_config: dict - The config for optimize the relax before translate. - - Returns - ------- - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - weights: dict of - The weights from the IRModule. - """ - - trans_config = msc_utils.copy_dict(trans_config) - build_config = msc_utils.copy_dict(build_config) - opt_config = msc_utils.copy_dict(opt_config) - entry = trans_config.get("entry", "main") - if params: - mod = BindParams("main", params)(mod) - opt_level = opt_config.get("opt_level", 1) - if opt_level > 0: - mod = tvm.transform.Sequential( - [ - tvm.relax.transform.FoldConstant(), - ] - )(mod) - patterns = get_patterns_with_prefix("msc.") - passes = [ - tvm.relax.transform.ExpandTupleArguments(), - msc_transform.SetExprName(), - msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)), - tvm.relax.transform.FuseOpsByPattern( - patterns, bind_constants=False, annotate_codegen=False - ), - ] - mod = tvm.transform.Sequential(passes)(mod) - graph = _ffi_api.BuildFromRelax(mod, entry, msc_utils.dump_dict(build_config)) - t_weights = _ffi_api.GetRelaxWeights(mod, entry) - return graph, normalize_weights(t_weights, graph) - - -@tvm.relax.expr_functor.visitor -class BYOCChecker(PyExprVisitor): - """Checker to check if any non-target ops exist""" - - def check(self, func_names, expr): - self._func_names = func_names - self._non_target_exprs = [] - if isinstance(expr, tvm.relax.Expr): - self.visit_expr(expr) - elif isinstance(expr, tvm.relax.BindingBlock): - self.visit_binding_block(expr) - assert len(self._non_target_exprs) == 0, f"Some exprs not on target {expr}" - - def visit_var_binding_(self, binding) -> None: - super().visit_var_binding_(binding) - if isinstance(binding.value, tvm.relax.Call): - if isinstance(binding.value.op, tvm.relax.GlobalVar): - if binding.value.op.name_hint not in self._func_names: - self._non_target_exprs.append(binding.value) - else: - self._non_target_exprs.append(binding.value) - elif not isinstance(binding.value, tvm.relax.DataflowVar): - self._non_target_exprs.append(binding.value) - - -def byoc_partition( - target: str, - mod: tvm.IRModule, - params: Optional[Dict[str, tvm.runtime.Tensor]] = None, - trans_config: Optional[Dict[str, str]] = None, - build_config: Optional[Dict[str, str]] = None, -) -> Tuple[tvm.IRModule, List[Tuple[MSCGraph, Dict[str, tvm.runtime.Tensor]]]]: - """Partition module to target sub functions. - - Parameters - ---------- - target: str - The target for the BYOC. - mod: IRModule - The IRModule of relax. - trans_config: dict - The config for transform IRModule. - params: dict of - The parameters of the IRModule. - build_config: dict - The config for build MSCGraph. - - Returns - ------- - mod: IRModule - The IRModule of partitioned relax. - graphs_info: list<> - The func list, each element for a sub graph. - """ - - trans_config = msc_utils.copy_dict(trans_config) - build_config = msc_utils.copy_dict(build_config) - build_config["target"] = target - for key in ["input_aliases", "output_aliases"]: - if key in build_config: - build_config.pop(key) - entry = trans_config.get("entry", "main") - if params: - mod = BindParams("main", params)(mod) - - def _partition_mod(mod, as_msc=True): - patterns = get_patterns_with_prefix(target) - passes = [ - tvm.relax.transform.ExpandTupleArguments(), - msc_transform.SetExprName(), - msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)), - tvm.relax.transform.FuseOpsByPattern(patterns, bind_constants=not as_msc), - msc_transform.InlineParams(), - msc_transform.FuseTuple(target), - tvm.relax.transform.MergeCompositeFunctions(), - msc_transform.SetBYOCAttrs(target), - ] - return tvm.transform.Sequential(passes)(mod) - - def _is_target_func(func): - if "Codegen" not in func.attrs: - return False - return func.attrs["Codegen"] == target - - msc_mod = _partition_mod(mod) - func_names = [var.name_hint for var, func in msc_mod.functions.items() if _is_target_func(func)] - - if trans_config.get("as_complete", True): - assert len(func_names) == 1, "More than 1 target func is found: " + str(msc_mod) - BYOCChecker().check(func_names, msc_mod[entry]) - - ref_weights = _ffi_api.GetRelaxWeights(msc_mod, entry) - graphs, weights = [], {} - for name in func_names: - graph_name = msc_mod[name].attrs[_ffi_api.ToAttrKey("unique")] - build_config.update({"graph_name": graph_name, "byoc_entry": name}) - graph = _ffi_api.BuildFromRelax(msc_mod, entry, msc_utils.dump_dict(build_config)) - graphs.append(graph) - weights.update(normalize_weights(ref_weights, graph)) - return _partition_mod(mod, False), graphs, weights diff --git a/python/tvm/contrib/msc/core/gym/__init__.py b/python/tvm/contrib/msc/core/gym/__init__.py deleted file mode 100644 index 7c75dcc56a48..000000000000 --- a/python/tvm/contrib/msc/core/gym/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym""" - -from .environment import * -from .agent import * -from .control import * diff --git a/python/tvm/contrib/msc/core/gym/agent/__init__.py b/python/tvm/contrib/msc/core/gym/agent/__init__.py deleted file mode 100644 index e71ba5d7fbad..000000000000 --- a/python/tvm/contrib/msc/core/gym/agent/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.agent""" - -from .method import * -from .search_agent import * diff --git a/python/tvm/contrib/msc/core/gym/agent/base_agent.py b/python/tvm/contrib/msc/core/gym/agent/base_agent.py deleted file mode 100644 index ae49dd9143b8..000000000000 --- a/python/tvm/contrib/msc/core/gym/agent/base_agent.py +++ /dev/null @@ -1,324 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.base_agent""" - -import copy -import logging -from typing import Any, Dict, List, Optional, Tuple - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.gym.namespace import GYMObject - - -class BaseAgent: - """Basic Agent of MSC.Gym - - Parameters - ---------- - name: str - The name of agent. - workspace: MSCDirectory - The worksapce. - executors: dict - The executors of the agent. - options: dict - The extra options for the agent. - debug_level: int - The debug level. - logger: logging.Logger - The logger - """ - - def __init__( - self, - name: str, - workspace: msc_utils.MSCDirectory, - executors: dict, - options: Optional[dict] = None, - debug_level: int = 0, - logger: Optional[logging.Logger] = None, - ): - self._name = name - self._workspace = workspace - self._executors = self._parse_executors(msc_utils.copy_dict(executors)) - self._options = options or {} - self._debug_level = debug_level - self._logger = logger or msc_utils.get_global_logger() - self._logger.info(msc_utils.msg_block(self.agent_mark("SETUP"), self.setup())) - - def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, dict]]: - """Parse the executors - - Parameters - ---------- - executors_dict: dict - The given executors. - - Returns - ------- - executors_dict: dict - The parsed executors. - """ - - executors = {} - for name, raw_config in executors_dict.items(): - method_type = ( - raw_config.pop("method_type") if "method_type" in raw_config else "default" - ) - method_cls = msc_utils.get_registered_gym_method(GYMObject.AGENT, method_type) - assert method_cls, f"Can not find method cls for {GYMObject.AGENT}:{method_type}" - assert "method" in raw_config, "method should be given to find agent method" - method_name, method = raw_config.pop("method"), None - if hasattr(method_cls, method_name): - method = getattr(method_cls, method_name) - if not method: - method = msc_utils.get_registered_func(method_name) - assert method, "Can not find method " + str(method_name) - executors[name] = (method_name, method, copy.deepcopy(raw_config)) - return executors - - def setup(self) -> dict: - """Setup the agent - - Returns - ------- - info: dict - The setup info. - """ - - self._knowledge = {"observations": [], "actions": [], "rewards": []} - return { - "name": self._name, - "workspace": self._workspace, - "executors": {k: f"{v[0]}({v[2]})" for k, v in self._executors.items()}, - "options": self._options, - "debug_level": self._debug_level, - } - - def init(self, max_task: int, baseline: Dict[str, Any]): - """Init the agent - - Parameters - ---------- - max_task: int - The max task for agent. - baseline: dict - The baseline of environment. - """ - - self._max_task = max_task - self._baseline = baseline - - def reset(self): - """Reset the agent""" - - self._knowledge = {"observations": [], "actions": [], "rewards": []} - - def choose_action(self, task_id: int, observation: Any, action_space: List[dict]) -> List[dict]: - """Choose action based on observation - - Parameters - ---------- - task_id: int - The current task id. - observation: - The current observation. - action_space: list - The possible action space - - Returns - ------- - actions: list - The actions for next task. - """ - - actions = self._choose_action(task_id, observation, action_space) - if task_id == len(self._knowledge["observations"]): - self._knowledge["observations"].append(observation) - self._knowledge["actions"].append(actions) - elif task_id == len(self._knowledge["observations"]) - 1: - self._knowledge["actions"][-1].extend(actions) - else: - raise TypeError( - "Step id should be either {0} or {0}-1, get {1}".format( - len(self._knowledge["observations"]), task_id - ) - ) - return actions - - def _choose_action( - self, task_id: int, observation: Any, action_space: List[dict] - ) -> List[dict]: - """Choose action based on observation - - Parameters - ---------- - task_id: int - The current task id. - observation: - The current observation. - action_space: list - The possible action space - - Returns - ------- - actions: list - The actions for next task. - """ - - raise NotImplementedError("_choose_action is not implemented in BaseAgent") - - def store(self, task_id: int, rewards: List[dict]) -> int: - """Store rewards - - Parameters - ---------- - task_id: int - The current task id. - rewards: list - The rewards for each action - - Returns - ------- - next_task: int - The next task id. - """ - - if task_id == len(self._knowledge["rewards"]): - self._knowledge["rewards"].append(rewards) - elif task_id == len(self._knowledge["rewards"]) - 1: - self._knowledge["rewards"][-1].extend(rewards) - else: - raise TypeError( - "Step id should be either {0} or {0}-1, get {1}".format( - len(self._knowledge["rewards"]), task_id - ) - ) - return self._store(task_id) - - def _store(self, task_id: int): - """Store rewards - - Parameters - ---------- - task_id: int - The current task id. - - Returns - ------- - next_task: int - The next task id. - """ - - return task_id + 1 - - def learn(self): - """Learn from knowledge - - Returns - ------- - actions: list - The learned actions. - rewards: list - The learned rewards. - """ - - self._logger.debug(msc_utils.msg_block(self.agent_mark("KNOWLEDEG"), self._knowledge)) - return self._learn() - - def _learn(self): - """Learn from knowledge - - Returns - ------- - actions: list - The learned actions. - rewards: list - The learned rewards. - """ - - raise NotImplementedError("_learn is not implemented in BaseAgent") - - def destory(self): - """Destory the agent""" - - return None - - def _execute(self, name: str, *args, **kwargs) -> Any: - """Run executor - - Parameters - ---------- - name: str - The executor name. - args: list - The arguments for execute. - kwargs: dict - The key word arguments for execute. - - Returns - ------- - res: - The execute result. - """ - - assert name in self._executors, ( - f"Can not find {name} in executors: {self._executors.keys()}" - ) - _, method, config = self._executors[name] - kwargs.update({k: v for k, v in config.items() if k not in kwargs}) - return method(self, *args, **kwargs) - - def _evaluate(self, reward: dict) -> float: - """Evaluate a reward with baseline - - Parameters - ---------- - reward: dict - The reward for. - - Returns - ------- - score: float - The score of the reward. - """ - - return self._execute("evaluate", self._baseline, reward) - - def agent_mark(self, msg: Any) -> str: - """Mark the message with agent info - - Parameters - ------- - msg: str - The message - - Returns - ------- - msg: str - The message with mark. - """ - - return f"AGENT({self.role_type()}) {msg}" - - @classmethod - def role(cls): - return GYMObject.AGENT - - @classmethod - def role_type(cls): - return "base" diff --git a/python/tvm/contrib/msc/core/gym/agent/method.py b/python/tvm/contrib/msc/core/gym/agent/method.py deleted file mode 100644 index f0c8b8a8917b..000000000000 --- a/python/tvm/contrib/msc/core/gym/agent/method.py +++ /dev/null @@ -1,84 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.core.gym.agent.method""" - -from typing import Any - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.gym.namespace import GYMObject - - -@msc_utils.register_gym_method -class AgentMethod: - """Default prune method""" - - @classmethod - def evaluate_by_loss(cls, agent: Any, baseline: dict, reward: dict) -> float: - """Evaluate the raw loss - - Parameters - ---------- - agent: BaseAgent - The base agent. - baseline: dict - The baseline. - reward: dict - The reward. - - Returns - ------- - score: float - The score. - """ - - assert "loss" in reward, "loss should be given to evaluate loss" - return 1 / reward["loss"] - - @classmethod - def evaluate_by_thresh(cls, agent: Any, baseline: dict, reward: dict, thresh: float) -> float: - """Evaluate the raw loss - - Parameters - ---------- - agent: BaseAgent - The base agent. - baseline: dict - The baseline. - reward: dict - The reward. - thresh: float - The threshold - - Returns - ------- - score: float - The score. - """ - - assert "reward" in reward, "reward should be given to evaluate threshold" - if reward["reward"] >= thresh: - return thresh - return reward["reward"] - - @classmethod - def role(cls): - return GYMObject.AGENT - - @classmethod - def method_type(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/gym/agent/search_agent.py b/python/tvm/contrib/msc/core/gym/agent/search_agent.py deleted file mode 100644 index 6554b0563392..000000000000 --- a/python/tvm/contrib/msc/core/gym/agent/search_agent.py +++ /dev/null @@ -1,181 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.search_agent""" - -from typing import Any, List - -from tvm.contrib.msc.core import utils as msc_utils - -from .base_agent import BaseAgent - - -class BaseSearchAgent(BaseAgent): - """Base Search Agent of MSC.Gym""" - - def setup(self) -> dict: - """Setup the tool - - Returns - ------- - info: dict - The setup info. - """ - - self._max_search = self._options.get("max_search", -1) - return super().setup() - - @classmethod - def role_type(cls): - return "search.base" - - -@msc_utils.register_gym_object -class GridSearchAgent(BaseSearchAgent): - """GridSearch agent""" - - def _choose_action( - self, task_id: int, observation: Any, action_space: List[dict] - ) -> List[dict]: - """Choose action based on observation - - Parameters - ---------- - task_id: int - The current task id. - observation: - The current observation. - action_space: list - The possible action space - - Returns - ------- - actions: list - The actions for next task. - """ - - return action_space - - def _learn(self): - """Learn from knowledge - - Returns - ------- - actions: list - The learned actions. - rewards: list - The learned rewards. - """ - - best_actions = [None] * len(self._knowledge["actions"]) - best_rewards = [None] * len(self._knowledge["rewards"]) - idx = 0 - for actions, rewards in zip(self._knowledge["actions"], self._knowledge["rewards"]): - best_score = None - for action, reward in zip(actions, rewards): - score = self._evaluate(reward) - if best_score is None or score > best_score: - best_actions[idx] = action - best_rewards[idx] = reward - best_score = score - idx += 1 - return best_actions, best_rewards - - @classmethod - def role_type(cls): - return "search.grid" - - -@msc_utils.register_gym_object -class BinarySearchAgent(BaseSearchAgent): - """BinarySearch agent""" - - def reset(self): - """Reset the agent""" - - self._ranges = [{"start": 0, "end": -1} for _ in range(self._max_task)] - super().reset() - - def _choose_action( - self, task_id: int, observation: Any, action_space: List[dict] - ) -> List[dict]: - """Choose action based on observation - - Parameters - ---------- - task_id: int - The current task id. - observation: - The current observation. - action_space: list - The possible action space - - Returns - ------- - actions: list - The actions for next task. - """ - - if self._ranges[task_id]["end"] == -1: - self._ranges[task_id]["end"] = len(action_space) - return [action_space[self._ranges[task_id]["start"]]] - pos = (self._ranges[task_id]["start"] + self._ranges[task_id]["end"]) / 2 - return [action_space[pos]] - - def _store(self, task_id: int): - """Store rewards - - Parameters - ---------- - task_id: int - The current task id. - - Returns - ------- - next_task: int - The next task id. - """ - - rewards = self._knowledge["rewards"][task_id] - start = self._ranges[task_id]["start"] - end = self._ranges[task_id]["end"] - if len(rewards) > 1: - if self._evaluate(rewards[-1]) > self._evaluate(rewards[-2]): - self._ranges[task_id]["end"] = (start + end) // 2 - else: - self._ranges[task_id]["start"] = (start + end) // 2 - if start - end <= 1: - return task_id + 1 - return task_id - - def _learn(self): - """Learn from knowledge - - Returns - ------- - actions: list - The learned actions. - rewards: list - The learned rewards. - """ - - actions = [a[-1] for a in self._knowledge["actions"]] - rewards = [r[-1] for r in self._knowledge["rewards"]] - return actions, rewards - - @classmethod - def role_type(cls): - return "search.binary" diff --git a/python/tvm/contrib/msc/core/gym/control/__init__.py b/python/tvm/contrib/msc/core/gym/control/__init__.py deleted file mode 100644 index af6fe592b9b8..000000000000 --- a/python/tvm/contrib/msc/core/gym/control/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.control""" - -from .controller import * -from .configer import * diff --git a/python/tvm/contrib/msc/core/gym/control/configer.py b/python/tvm/contrib/msc/core/gym/control/configer.py deleted file mode 100644 index 9a39ff560986..000000000000 --- a/python/tvm/contrib/msc/core/gym/control/configer.py +++ /dev/null @@ -1,95 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.configer""" - -from tvm.contrib.msc.core import utils as msc_utils - - -class BaseConfiger: - """Configer for Gym - - Parameters - ---------- - stage: str - The stage for gym, should be in MSCStage. - """ - - def __init__(self, stage: str): - self._stage = stage - - def update(self, raw_config: dict) -> dict: - """Config the raw config - - Parameters - ---------- - raw_config: dict - The raw config. - - Returns - ------- - config: dict - The update config. - """ - - raise NotImplementedError("update is not implemented in BaseConfiger") - - -@msc_utils.register_gym_configer -class DefaultConfiger(BaseConfiger): - """Default configer for gym""" - - def update(self, raw_config: dict) -> dict: - """Config the raw config - - Parameters - ---------- - raw_config: dict - The raw config. - - Returns - ------- - config: dict - The update config. - """ - - config = msc_utils.copy_dict(raw_config) - assert "env" in config and "agent" in config, "env and agent should be given to run gym" - if "role_type" not in config["env"]: - config["env"]["role_type"] = self._stage + ".default" - if "role_type" not in config["agent"]: - config["agent"]["role_type"] = "search.grid" - if "executors" not in config["env"]: - config["env"]["executors"] = {} - # update executors - env_executors = { - "reward_runner": {"method": "reward_compare_baseline"}, - "create_tasks": {"method": "tasks_tool_extract"}, - } - config["env"]["executors"].update( - {k: v for k, v in env_executors.items() if k not in config["env"]["executors"]} - ) - if "executors" not in config["agent"]: - config["agent"]["executors"] = {} - agent_executors = {"evaluate": {"method": "evaluate_by_loss"}} - config["agent"]["executors"].update( - {k: v for k, v in agent_executors.items() if k not in config["agent"]["executors"]} - ) - return config - - @classmethod - def config_type(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/gym/control/controller.py b/python/tvm/contrib/msc/core/gym/control/controller.py deleted file mode 100644 index d1b1f725fb79..000000000000 --- a/python/tvm/contrib/msc/core/gym/control/controller.py +++ /dev/null @@ -1,107 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.control.controller""" - -from typing import Any, Dict, Optional - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.gym.namespace import GYMAction, GYMObject - -from .service import MainService, NodeService - - -class BaseController: - """Basic controller for optimize search - - Parameters - ---------- - workspace: MSCDirectory - The worksapce. - config: dict - The config for service. - is_main: bool - Whether the node is main node - """ - - def __init__( - self, - workspace: msc_utils.MSCDirectory, - config: Dict[str, Any], - is_main: bool = True, - ): - self._workspace = workspace - service_cls = MainService if is_main else NodeService - self._service = service_cls(self._workspace, **config) - - def run(self) -> dict: - """Run the controller - - Returns - ------- - report: dict - The run report. - """ - - self._service.init() - while not self._service.done: - self._service.reset() - while not self._service.iter_done: - self._service.execute(GYMObject.ENV, GYMAction.GET_STATE) - self._service.execute(GYMObject.AGENT, GYMAction.CHOOSE_ACTION) - self._service.execute(GYMObject.ENV, GYMAction.STEP) - self._service.execute(GYMObject.AGENT, GYMAction.STORE) - self._service.learn() - return self._service.summary() - - -def create_controller(stage: str, config: dict, extra_config: Optional[dict] = None): - """Update the gym config - - Parameters - ---------- - stage: str - The stage for gym, should be in MSCStage. - config: dict - The raw config. - extra_config: dict - The extra config - - Returns - ------- - config: dict - The update config. - """ - - config_type = config.pop("config_type") if "config_type" in config else "default" - configer_cls = msc_utils.get_registered_gym_configer(config_type) - assert configer_cls, "Can not find configer for " + str(config_type) - config = configer_cls(stage).update(config) - if extra_config: - config = msc_utils.update_dict(config, extra_config) - if "control_type" in config: - control_type = config.pop("control_type") - else: - control_type = "default" - controller_cls = msc_utils.get_registered_gym_controller(control_type) - return controller_cls(msc_utils.get_gym_dir(), config) - - -@msc_utils.register_gym_controller -class DefaultController(BaseController): - @classmethod - def control_type(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/gym/control/namespace.py b/python/tvm/contrib/msc/core/gym/control/namespace.py deleted file mode 100644 index 606ab3410f7e..000000000000 --- a/python/tvm/contrib/msc/core/gym/control/namespace.py +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.control.namespace""" - - -class GYMObject: - """Enum all gym objects""" - - BASE = "base" - ENV = "env" - AGENT = "agent" - SERVICE = "service" - - -class GYMAction: - """Enum all gym actions""" - - INIT = "init" - RESET = "reset" - GET_STATE = "get_state" - CHOOSE_ACTION = "choose_action" - STEP = "step" - STORE = "store" - LEARN = "learn" - SUMMARY = "summary" - CLEANUP = "cleanup" diff --git a/python/tvm/contrib/msc/core/gym/control/service.py b/python/tvm/contrib/msc/core/gym/control/service.py deleted file mode 100644 index 6deed1ab18fd..000000000000 --- a/python/tvm/contrib/msc/core/gym/control/service.py +++ /dev/null @@ -1,827 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E501 -"""tvm.contrib.msc.core.gym.control.service""" - -import copy -import json -import queue -import time -from functools import partial, reduce -from multiprocessing import Manager -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.gym.namespace import GYMAction, GYMObject - -from .worker import BaseGymWorker, WorkerFactory - - -def _send_message(msg_queue: queue.Queue, header: str, body: dict, header_type: str = "message"): - """Send the message to queue - - Parameters - ---------- - msg_queue: Queue - The message queue. - header: str - The header of message. - body: dict - The message body. - header_type: str - The header type - """ - - msg_queue.put(json.dumps({header_type: header, "body": body})) - - -def _wait_message( - msg_queue: queue.Queue, - header: str, - checker: Optional[callable] = None, - wait_time: int = 2, - max_retry: int = -1, - header_type: str = "message", -) -> dict: - """Wait until valid message - - Parameters - ---------- - msg_queue: Queue - The message queue. - header: str - The header of message. - checker: callable - The checker for the message. - wait_time: int - The wait time between retry in second. - max_retry: int - The max retry time. - header_type: str - The header type - - Returns - ------- - message: dict - The message body - """ - - def _check_message(message: dict, checker: Optional[callable] = None) -> bool: - """Check the message - - Parameters - ---------- - message: dict - The message. - checker: callable - The checker for the message. - - Returns - ------- - pass: bool - Whether the message pass. - """ - - if "body" not in message: - return False - if checker and not checker(message["body"]): - return False - return True - - try_cnt = 0 - while True: - if try_cnt >= max_retry > 0: - break - info = msg_queue.get() - message = json.loads(info) - if message.get(header_type, "") == header and _check_message(message, checker): - return message["body"] - try_cnt += 1 - msg_queue.put(info) - time.sleep(wait_time) - return None - - -send_request = partial(_send_message, header_type="request_header") -send_response = partial(_send_message, header_type="response_header") -wait_request = partial(_wait_message, header_type="request_header") -wait_response = partial(_wait_message, header_type="response_header") - - -class GatherMode: - """Enum all gather mode""" - - PARALLEL = "parallel" - REDUCE_SUM = "reduce_sum" - REDUCE_MEAN = "reduce_mean" - FIRST = "first" - - -class BaseService: - """Basic service for gym - - Parameters - ---------- - workspace: MSCDirectory - The worksapce. - env: dict - The environment config. - agent: dict - The agent config - tasks: list - The tasks on the node. - world_size: int - The world size. - max_iter: int - The max seatch iter. - record_step: int - The record step. - verbose: str - The verbose level - """ - - def __init__( - self, - workspace: msc_utils.MSCDirectory, - env: Dict[str, Any], - agent: Dict[str, Any], - tasks: Optional[List[str]] = None, - dist_manager: Optional[Manager] = None, - world_size: int = 1, - max_iter: int = 1, - record_step: int = 5, - debug_level: int = 0, - verbose: Optional[str] = None, - ): - self._workspace = workspace - tasks = tasks or [GYMObject.ENV + ":0", GYMObject.AGENT + ":0"] - verbose = verbose or "info" - debug_level = int(verbose.split(":")[1]) if verbose.startswith("debug:") else 0 - self._logger = msc_utils.create_file_logger(verbose, self._workspace.relpath("SERVICE_LOG")) - - def _create_workers(config: dict, obj_type: str) -> List[BaseGymWorker]: - if "debug_level" not in config: - config["debug_level"] = debug_level - if "logger" not in config: - config["logger"] = self._logger - return [ - WorkerFactory.create(t, workspace, config) for t in tasks if t.startswith(obj_type) - ] - - self._env_workers = _create_workers(env, GYMObject.ENV) - self._agent_workers = _create_workers(agent, GYMObject.AGENT) - self._dist_manager = dist_manager - self._world_size = world_size - self._max_iter = max_iter - self._record_step = record_step - self._debug_level = debug_level - self._logger.info(msc_utils.msg_block(self.service_mark("SETUP"), self.setup())) - - def setup(self) -> dict: - """Setup the tool - - Returns - ------- - info: dict - The setup info. - """ - - if self._world_size > 1: - assert self._dist_manager, "dist manager should be given for distributed service" - self._request_queue = self._dist_manager.get_request_queue() - self._response_queue = self._dist_manager.get_response_queue() - self._world_id, self._env_world_ids, self._agent_world_ids = self._connect() - else: - self._request_queue = queue.Queue() - self._response_queue = queue.Queue() - self._world_id = 0 - self._env_world_ids = [w.worker_id for w in self._env_workers] - self._agent_world_ids = [w.worker_id for w in self._agent_workers] - return { - "workspace": self._workspace, - "world_id": self._world_id, - "world_size": self._world_size, - "env_worker_ids": self._get_worker_ids(GYMObject.ENV), - "env_world_ids": self._env_world_ids, - "agent_worker_ids": self._get_worker_ids(GYMObject.AGENT), - "agent_world_ids": self._agent_world_ids, - "max_iter": self._max_iter, - "record_step": self._record_step, - "debug_level": self._debug_level, - } - - def init(self): - self._logger.info("SERVICE Init") - self._iter_id, self._done = 0, False - self._max_task = 0 - self._task_id, self._states = 0, [] - self._iter_done = False - self.execute(GYMObject.ENV, GYMAction.INIT) - self.execute(GYMObject.AGENT, GYMAction.INIT) - - def reset(self): - self._task_id, self._states = 0, [] - self._iter_done = False - self._logger.info("SERVICE Reset %d/%d th iter", self._iter_id, self._max_iter) - self.execute(GYMObject.ENV, GYMAction.RESET) - self.execute(GYMObject.AGENT, GYMAction.RESET) - - def learn(self): - self.execute(GYMObject.AGENT, GYMAction.LEARN) - if self._iter_done: - self._iter_id += 1 - if self._iter_id >= self._max_iter: - self._done = True - - def summary(self): - self._logger.info("SERVICE Summary after %d iters", self._max_iter) - self.execute(GYMObject.ENV, GYMAction.SUMMARY) - plan = self._states[-1]["response"]["plan"] - self.execute(GYMObject.ENV, GYMAction.CLEANUP) - self.execute(GYMObject.AGENT, GYMAction.CLEANUP) - return plan - - def execute(self, obj_type: str, act_type: str): - """Execute the service - - Parameters - ---------- - obj_type: str - The object type, should be one of GYMObject. - act_type: str - The action type, should be one of GYMAction. - """ - - self._states.append( - { - "task_id": self._task_id, - "msg_key": self._to_msg_key(obj_type, act_type), - "response": self._execute(obj_type, act_type), - } - ) - - def _execute(self, obj_type: str, act_type: str) -> dict: - """Execute the service - - Parameters - ---------- - obj_type: str - The object type, should be one of GYMObject. - act_type: str - The action type, should be one of GYMAction. - - Returns - ------- - state: dict - The state after the execute. - """ - - raise NotImplementedError("_execute is not implemented in BaseService") - - def _send_request(self, msg_key: str, body: dict): - """Send request - - Parameters - ---------- - msg_key: str - The header of message. - body: dict - The message body. - """ - - send_request(self._request_queue, msg_key, body) - - def _send_response(self, msg_key: str, body: dict): - """Send request - - Parameters - ---------- - msg_key: str - The header of message. - body: dict - The message body. - """ - - send_response(self._response_queue, msg_key, body) - - def _wait_request( - self, - msg_key: str, - checker: Optional[callable] = None, - wait_time: int = 2, - max_retry: int = -1, - ) -> dict: - """Wait request - - Parameters - ---------- - msg_key: str - The header of message. - checker: callable - The checker for the message. - wait_time: int - The wait time between retry in second. - max_retry: int - The max retry time. - """ - - return wait_request(self._request_queue, msg_key, checker, wait_time, max_retry) - - def _wait_response( - self, - msg_key: str, - checker: Optional[callable] = None, - wait_time: int = 2, - max_retry: int = -1, - ) -> dict: - """Wait response - - Parameters - ---------- - msg_key: str - The header of message. - checker: callable - The checker for the message. - wait_time: int - The wait time between retry in second. - max_retry: int - The max retry time. - """ - - return wait_request(self._response_queue, msg_key, checker, wait_time, max_retry) - - def _process_request(self, msg_key: str) -> dict: - """Process the request according to msg_key - - Parameters - ---------- - msg_key: str - The header of message. - - Returns - ------- - responses: dict - The responses of wrokers. - """ - - obj_type, act_type = self._from_msg_key(msg_key) - workers = {w.worker_id: w for w in self._get_workers(obj_type)} - requests = self._wait_request(msg_key) - if act_type in (GYMAction.INIT, GYMAction.RESET): - mark = f"Iter[{self._iter_id}/{self._max_iter}] {obj_type}.{act_type}" - else: - mark = f"Iter[{self._iter_id}/{self._max_iter}] Task[{self._task_id}/{self._max_task}] {obj_type}.{act_type}" - requests = {int(k): v for k, v in requests.items()} - responses = {} - for w_id, worker in workers.items(): - responses[w_id] = worker.execute(act_type, **requests[w_id]) - info = { - "requests": {workers[w].name: r for w, r in requests.items()}, - "responses": {workers[w].name: r for w, r in responses.items()}, - } - self._logger.info(msc_utils.msg_block(mark, info, symbol="=")) - return responses - - def _process_response(self, msg_key: str, response: dict): - """Update reponse - - Parameters - ---------- - msg_key: str - The header of message. - response: dict - The response. - - Returns - ------- - response: dict - The updated response. - """ - - obj_type, act_type = self._from_msg_key(msg_key) - if obj_type == GYMObject.ENV and act_type == GYMAction.INIT: - self._max_task = response["max_task"] - if obj_type == GYMObject.AGENT and act_type == GYMAction.STORE: - self._task_id = response["next_task"] - if self._task_id >= self._max_task: - self._iter_done = True - return response - - def _to_msg_key(self, obj_type: str, act_type: str) -> str: - """Create message key base on types - - Parameters - ---------- - obj_type: str - The object type, should be one of GYMObject. - act_type: str - The action type, should be one of GYMAction. - - Returns - ------- - key: str - The message key. - """ - - return f"{obj_type}-s-{act_type}" - - def _from_msg_key(self, msg_key: str) -> Tuple[str, str]: - """Get obj_type and act_type from message key - - Parameters - ---------- - msg_key: str - The message key. - - Returns - ------- - obj_type: str - The object type, should be one of GYMObject. - act_type: str - The action type, should be one of GYMAction. - """ - - return msg_key.split("-s-") - - def _get_workers(self, obj_type: str) -> List[BaseGymWorker]: - """Get workers according to obj_type - - Parameters - ---------- - obj_type: str - The object type, should be one of GYMObject. - - Returns - ------- - workers: list - The workers. - """ - - if obj_type == GYMObject.ENV: - return self._env_workers - if obj_type == GYMObject.AGENT: - return self._agent_workers - return [] - - def _get_worker_ids(self, obj_type: str) -> List[int]: - """Get worker ids according to obj_type - - Parameters - ---------- - obj_type: str - The object type, should be one of GYMObject. - - Returns - ------- - worker_ids: list - The worker ids. - """ - - return [w.worker_id for w in self._get_workers(obj_type)] - - def _get_world_ids(self, obj_type: str) -> List[int]: - """Get world ids according to obj_type - - Parameters - obj_type: str - The object type, should be one of GYMObject. - - Returns - ------- - world_ids: list - The world ids. - """ - - if obj_type == GYMObject.ENV: - return self._env_world_ids - if obj_type == GYMObject.AGENT: - return self._agent_world_ids - return [] - - def service_mark(self, msg: Any) -> str: - """Mark the message with service info - - Parameters - ------- - msg: str - The message - - Returns - ------- - msg: str - The message with mark. - """ - - return f"SERIVCE({self.service_type}) {msg}" - - @property - def done(self): - return self._done - - @property - def iter_done(self): - return self._iter_done - - @property - def service_type(self): - return "base" - - -class MainService(BaseService): - """Main service for gym""" - - def _connect(self): - msg_key = self._to_msg_key(GYMObject.SERVICE, GYMAction.SETUP) - env_world_ids = self._get_worker_ids(GYMObject.ENV) - agent_world_ids = self._get_worker_ids(GYMObject.AGENT) - # send world_id and get env/agent ids - barrier = self._world_size - 1 - - def _check_response(body): - return all(k in body for k in ["env_worker_ids", "agent_worker_ids"]) - - for i in range(barrier): - self._send_request(msg_key, {"world_id": i + 1}) - while barrier > 0: - info = self._wait_response(msg_key, _check_response) - if info: - env_world_ids.extend(info["env_world_ids"]) - agent_world_ids.extend(info["agent_world_ids"]) - barrier -= 1 - - self._synchronize_feedback( - msg_key, env_world_ids=env_world_ids, agent_world_ids=agent_world_ids - ) - return 0, env_world_ids, agent_world_ids - - def _execute(self, obj_type: str, act_type: str) -> dict: - """Execute the service - - Parameters - ---------- - obj_type: str - The object type, should be one of GYMObject. - act_type: str - The action type, should be one of GYMAction. - - Returns - ------- - state: dict - The state after the execute. - """ - - world_ids = self._get_worker_ids(obj_type) - tasks = {i: self._create_task(obj_type, act_type, i) for i in world_ids} - msg_key = self._to_msg_key(obj_type, act_type) - response = self._synchronize_request(msg_key, tasks) - response = self._process_response(msg_key, response) - self._synchronize_feedback(msg_key, **response) - return response - - def _synchronize_request( - self, - msg_key: str, - requests: List[dict], - checker: Optional[callable] = None, - wait_time: int = 2, - max_retry: int = -1, - ) -> dict: - """Send requests to workers and gather response - - Parameters - ---------- - msg_key: str - The header of message. - requests: list - The requests - checker: callable - The checker for the response. - wait_time: int - The wait time between retry in second. - max_retry: int - The max retry time. - - Returns - ------- - response: dict - The gathered response. - """ - - responses = {} - barrier = self._world_size - for _ in range(barrier): - self._send_request(msg_key, requests) - responses.update(self._process_request(msg_key)) - barrier -= 1 - while barrier > 0: - info = self._wait_response(msg_key, checker, wait_time, max_retry) - if info: - info = {int(k): v for k, v in info.items()} - responses.update(info) - barrier -= 1 - responses = [responses[i] for i in sorted(responses)] - gathered_response = {} - for key in responses[0]: - if key in ("action", "reward"): - gather_mode = GatherMode.PARALLEL - else: - gather_mode = GatherMode.FIRST - gathered_response[key] = self._gather_values([r[key] for r in responses], gather_mode) - return gathered_response - - def _synchronize_feedback(self, msg_key: str, **feedback: dict): - """Broadcast feedback to workers - - Parameters - ---------- - msg_key: str - The header of message. - feedback: dict - The feedback body - """ - - def _check_feedback(body): - return body.get("feedback_receive", False) - - barrier = self._world_size - 1 - for _ in range(barrier): - self._send_request(msg_key, {"feedback_send": True, **feedback}) - while barrier > 0: - info = self._wait_response(msg_key, _check_feedback) - if info: - barrier -= 1 - - def _create_task(self, obj_type: str, act_type: str, worker_id: int) -> dict: - """Create message key base on types - - Parameters - ---------- - obj_type: str - The object type, should be one of GYMObject. - act_type: str - The action type, should be one of GYMAction. - worker_id: int - The worker id. - - Returns - ------- - config: dict - The config for the worker.execute. - """ - - if not self._states: - config = {} - else: - config = copy.deepcopy(self._states[-1]["response"]) - if obj_type == GYMObject.ENV and act_type == GYMAction.GET_STATE: - config["task_id"] = self._task_id - if obj_type == GYMObject.ENV and act_type == GYMAction.STEP: - config["actions"] = self._map_values(config["actions"], obj_type, worker_id) - config["task_id"] = self._task_id - elif obj_type == GYMObject.AGENT and act_type in (GYMAction.CHOOSE_ACTION, GYMAction.STORE): - config["task_id"] = self._task_id - return config - - def _map_values(self, values: List[Any], obj_type: str, worker_id: int) -> List[Any]: - """Map the values for worker - - Parameters - ---------- - values: list - The global values, - obj_type: str - The object type, should be one of GYMObject. - worker_id: int - The worker id. - - Returns - ------- - values: list - The values for the worker. - """ - - world_ids = self._get_world_ids(obj_type) - tile_size = len(values) // len(world_ids) - if len(values) % len(world_ids) != 0: - tile_size += 1 - worker_idx = world_ids.index(worker_id) - start = worker_idx * tile_size - end = min((worker_idx + 1) * tile_size, len(values)) - return values[start:end] - - def _gather_values(self, values: List[Any], gather_mode: str) -> Any: - """Gather the values - - Parameters - ---------- - values: list - The global values, - gather_mode: str - The gather mode should be in GatherMode. - - Returns - ------- - value: - The gathered value. - """ - - if gather_mode == GatherMode.FIRST or len(values) == 1: - return values[0] - if gather_mode == GatherMode.PARALLEL: - return values - if gather_mode in (GatherMode.REDUCE_MEAN, GatherMode.REDUCE_SUM): - if all(msc_utils.MSCArray.is_array(v) for v in values): - value_sum = np.array([msc_utils.cast_array(v) for v in values]).sum(axis=1) - else: - value_sum = reduce(lambda x, y: x + y, values) - if gather_mode == GatherMode.REDUCE_SUM: - return value_sum - return value_sum / len(values) - raise NotImplementedError("Gather mode {} is not supported") - - @property - def service_type(self): - return "main" - - -class NodeService(BaseService): - """Normal service for gym""" - - def _connect(self): - msg_key = self._to_msg_key(GYMObject.SERVICE, GYMAction.SETUP) - env_worker_ids = self._get_worker_ids(GYMObject.ENV) - agent_worker_ids = self._get_worker_ids(GYMObject.AGENT) - - def _check_request(body): - return "world_id" in body - - info = self._wait_request(msg_key, _check_request) - world_id = info["world_id"] - self._send_response( - msg_key, {"env_worker_ids": env_worker_ids, "agent_worker_ids": agent_worker_ids} - ) - info = self._feedback(msg_key) - return world_id, info["env_world_ids"], info["agent_world_ids"] - - def _feedback(self, msg_key: str) -> dict: - """Send feed back to main service - - Parameters - ---------- - msg_key: str - The header of message. - - Returns - ------- - response: dict - The recived feedback. - """ - - def _check_feedback(body): - return body.get("feedback_send", False) - - response = self._wait_request(msg_key, _check_feedback) - self._send_response(msg_key, {"feedback_receive": True}) - response = self._process_response(msg_key, response) - return response - - def _execute(self, obj_type: str, act_type: str) -> dict: - """Execute the service - - Parameters - ---------- - obj_type: str - The object type, should be one of GYMObject. - act_type: str - The action type, should be one of GYMAction. - - Returns - ------- - state: dict - The state after the execute. - """ - - msg_key = self._to_msg_key(obj_type, act_type) - info = self._process_request(msg_key) - self._send_response(msg_key, info) - return self._feedback(msg_key) - - @property - def service_type(self): - return "node" diff --git a/python/tvm/contrib/msc/core/gym/control/worker.py b/python/tvm/contrib/msc/core/gym/control/worker.py deleted file mode 100644 index c5b9d66251f0..000000000000 --- a/python/tvm/contrib/msc/core/gym/control/worker.py +++ /dev/null @@ -1,221 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.control.worker""" - -from typing import Any - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.gym.namespace import GYMAction, GYMObject - - -class BaseGymWorker: - """Basic worker for gym - - Parameters - ---------- - name: str - The worker name. - workspace: MSCDirectory - The worksapce. - worker_id: int - The worker_id. - worker_cls: class - The worker class. - worker_config: dict - The worker config. - """ - - def __init__( - self, - name: str, - workspace: msc_utils.MSCDirectory, - worker_id: int, - worker_cls: Any, - worker_config: dict, - ): - self._name = name - self._worker_id = worker_id - debug_level = worker_config.get("debug_level", 0) - if "logger" not in worker_config: - verbose = "debug" if debug_level > 0 else "info" - worker_config["logger"] = msc_utils.create_file_logger( - verbose, workspace.relpath(f"{self.obj_type.upper()}.{worker_id}_LOG") - ) - if "workspace" not in worker_config: - worker_config["workspace"] = workspace - worker_config["name"] = name - self._worker_impl = worker_cls(**worker_config) - - def __str__(self): - return f"<{self.obj_type}>: {self._name}({self._worker_id})" - - def execute(self, act_type: str, **kwargs) -> Any: - """Execute the worker - - Parameters - ---------- - act_type: str - The action type, should be one of GYMAction. - kwargs: dict - The kwargs for execute. - - Returns - ------- - response: dict - The execute result. - """ - - raise NotImplementedError("execute is not implemented in " + str(self.__class__)) - - @property - def obj_type(self): - return GYMObject.BASE - - @property - def name(self): - return self._name - - @property - def worker_id(self): - return self._worker_id - - -class EnvGymWorker(BaseGymWorker): - """Env worker for gym""" - - def execute(self, act_type: str, **kwargs) -> Any: - """Execute the worker - - Parameters - ---------- - act_type: str - The action type, should be one of GYMAction. - kwargs: dict - The kwargs for execute. - - Returns - ------- - response: dict - The execute result. - """ - - response = {} - if act_type == GYMAction.INIT: - max_task, baseline = self._worker_impl.init() - response.update({"max_task": max_task, "baseline": baseline}) - elif act_type == GYMAction.RESET: - self._worker_impl.reset() - elif act_type == GYMAction.GET_STATE: - observation, action_space = self._worker_impl.get_state(kwargs["task_id"]) - response.update({"observation": observation, "action_space": action_space}) - elif act_type == GYMAction.STEP: - rewards = self._worker_impl.step(**kwargs) - response.update({"rewards": rewards}) - elif act_type == GYMAction.SUMMARY: - plan = self._worker_impl.summary(**kwargs) - response.update({"plan": plan}) - elif act_type == GYMAction.CLEANUP: - self._worker_impl.destory() - return response - - @property - def obj_type(self): - return GYMObject.ENV - - -class AgentGymWorker(BaseGymWorker): - """Agent worker for gym""" - - def execute(self, act_type: str, **kwargs) -> Any: - """Execute the worker - - Parameters - ---------- - act_type: str - The action type, should be one of GYMAction. - kwargs: dict - The kwargs for execute. - - Returns - ------- - response: dict - The execute result. - """ - - response = {} - if act_type == GYMAction.INIT: - self._worker_impl.init(**kwargs) - elif act_type == GYMAction.RESET: - self._worker_impl.reset() - elif act_type == GYMAction.CHOOSE_ACTION: - actions = self._worker_impl.choose_action(**kwargs) - response.update({"actions": actions}) - elif act_type == GYMAction.STORE: - next_task = self._worker_impl.store(**kwargs) - response.update({"next_task": next_task}) - elif act_type == GYMAction.LEARN: - actions, rewards = self._worker_impl.learn() - response.update({"actions": actions, "rewards": rewards}) - elif act_type == GYMAction.CLEANUP: - self._worker_impl.destory() - return response - - @property - def obj_type(self): - return GYMObject.AGENT - - -class WorkerFactory: - """The Factory for workers""" - - @classmethod - def create(cls, name: str, workspace: msc_utils.MSCDirectory, config: dict) -> BaseGymWorker: - """Create worker - - Parameters - ---------- - name: str - The name of worker, should be in type. - workspace: MSCDirectory - The worksapce. - worker_id: int - The worker_id. - worker_cls: class - The worker class. - worker_config: dict - The worker config. - - Returns - ------- - worker: BaseGymWorker - The create worker. - """ - - def _get_worker_cls(obj: str): - worker_type = config.pop("role_type") if "role_type" in config else "default" - worker_cls = msc_utils.get_registered_gym_object(obj, worker_type) - assert worker_cls, f"Can not find worker class for {obj}:{worker_type}" - return worker_cls - - obj_type, worker_id = name.split(":") - if obj_type == GYMObject.ENV: - worker_cls = _get_worker_cls(obj_type) - return EnvGymWorker(name, workspace, int(worker_id), worker_cls, config) - if obj_type == GYMObject.AGENT: - worker_cls = _get_worker_cls(obj_type) - return AgentGymWorker(name, workspace, int(worker_id), worker_cls, config) - raise TypeError(f"Worker for {obj_type} is not supported") diff --git a/python/tvm/contrib/msc/core/gym/environment/__init__.py b/python/tvm/contrib/msc/core/gym/environment/__init__.py deleted file mode 100644 index 211b02d32f3a..000000000000 --- a/python/tvm/contrib/msc/core/gym/environment/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.environment""" - -from .method import * -from .prune_env import * -from .quantize_env import * diff --git a/python/tvm/contrib/msc/core/gym/environment/base_env.py b/python/tvm/contrib/msc/core/gym/environment/base_env.py deleted file mode 100644 index 24a986bfab72..000000000000 --- a/python/tvm/contrib/msc/core/gym/environment/base_env.py +++ /dev/null @@ -1,428 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.base_env""" - -import copy -import logging -from typing import Any, Dict, List, Optional, Tuple, Union - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.gym.namespace import GYMObject -from tvm.contrib.msc.core.runtime import BaseRunner -from tvm.contrib.msc.core.tools import BaseTool - - -class BaseEnv: - """Basic Environment of MSC.Gym - - Parameters - ---------- - runner: BaseRunner - The runner. - data_loader: - The data_loader - workspace: MSCDirectory - The worksapce. - executors: dict - The executors of the environment. - knowledge: dict - The predefined knowledge. - options: dict - The extra options for the environment. - debug_level: int - The debug level. - logger: logging.Logger - The logger - """ - - def __init__( - self, - name: str, - runner: BaseRunner, - data_loader: Any, - workspace: msc_utils.MSCDirectory, - executors: dict, - knowledge: Optional[dict] = None, - options: Optional[dict] = None, - max_tasks: int = -1, - debug_level: int = 0, - logger: Optional[logging.Logger] = None, - ): - self._name = name - self._runner = runner - self._data_loader = data_loader - self._workspace = workspace - self._knowledge = msc_utils.load_dict(knowledge) - self._executors = self._parse_executors(msc_utils.copy_dict(executors)) - self._options = options or {} - self._max_tasks = max_tasks - self._debug_level = debug_level - self._logger = logger or msc_utils.get_global_logger() - self._logger.info(msc_utils.msg_block(self.env_mark("SETUP"), self.setup())) - - def _parse_executors(self, executors_dict: dict) -> Dict[str, Tuple[callable, dict]]: - """Parse the executors - - Parameters - ---------- - executors_dict: dict - The given executors. - - Returns - ------- - executors_dict: dict - The parsed executors. - """ - - executors = {} - for name, raw_config in executors_dict.items(): - method_type = ( - raw_config.pop("method_type") if "method_type" in raw_config else "default" - ) - method_cls = msc_utils.get_registered_gym_method(GYMObject.ENV, method_type) - assert method_cls, f"Can not find method cls for {GYMObject.ENV}:{method_type}" - assert "method" in raw_config, "method should be given to find enviironment method" - method_name, method = raw_config.pop("method"), None - if hasattr(method_cls, method_name): - method = getattr(method_cls, method_name) - if not method: - method = msc_utils.get_registered_func(method_name) - assert method, "Can not find method " + str(method_name) - executors[name] = (method_name, method, copy.deepcopy(raw_config)) - return executors - - def setup(self) -> dict: - """Setup the environment - - Returns - ------- - info: dict - The setup info. - """ - - self._cache_dir = self._workspace.create_dir("Cache") - self._tool = None - self._tasks = [] - return { - "name": self._name, - "runner": self._runner, - "data_loader": self._data_loader, - "workspace": self._workspace, - "executors": {k: f"{v[0]}({v[2]})" for k, v in self._executors.items()}, - "options": self._options, - "max_tasks": self._max_tasks, - "debug_level": self._debug_level, - } - - def init(self) -> Tuple[int, Dict[str, Any]]: - """Init the agent - - Returns - ------- - max_tasks: int - The max task for agent. - baseline: dict - The baseline of environment. - """ - - self._runner.change_logger(self._logger) - # save cache for tasks - self._runner.save_cache(self._cache_dir) - self._tool = self._init_tool() - # create tasks - self._tasks = self._execute("create_tasks", self._tool) - if self._max_tasks > 0: - self._tasks = self._tasks[: self._max_tasks] - # get baseline - self._tool.disable() - self._runner.build(self._cache_dir, force_build=True, disable_tools=[self._tool.tool_type]) - baseline = self._reward_runner(-1) - self._tool.enable() - tasks_info = {"tasks_num": len(self._tasks), "tasks": self._tasks} - self._logger.info(msc_utils.msg_block(self.env_mark("TASKS"), tasks_info)) - return len(self._tasks), baseline - - def _init_tool(self) -> BaseTool: - """Get the main tool""" - - raise NotImplementedError("_init_tool is not implemented in BaseEnv") - - def reset(self) -> Tuple[List[float], List[dict]]: - """Reset the environment - - Returns - ------- - observation: list - The next observation. - action_space: list - The next action space. - """ - - return None - - def get_state(self, task_id: int) -> Tuple[List[float], List[dict]]: - """Get the state - - Parameters - ---------- - task_id: int - The current task id. - - Returns - ------- - observation: list - The next observation. - action_space: list - The next action space. - """ - - if "observation" in self._executors: - observation = self._execute("observation", task_id) - else: - observation = [task_id] - if "action_space" in self._executors: - action_space = self._execute("action_space", task_id) - else: - action_space = list(range(5)) - return observation, action_space - - def step(self, actions: List[dict], task_id: int) -> Tuple[List[float], List[dict], List[dict]]: - """Step and get rewards - - Parameters - ---------- - actions: list - The current actions. - task_id: int - The current task id. - - Returns - ------- - observation: list - The next observation. - action_space: list - The next action space. - rewards: list - The rewards - """ - - rewards = [] - for idx, action in enumerate(actions): - self._update_tool(action, task_id) - self._runner.build(self._cache_dir, force_build=True) - rewards.append(self._reward_runner(task_id)) - self._logger.info( - "Task[%d/%d] Action[%d/%d] %s -> reward %s", - task_id, - len(self._tasks), - idx, - len(actions), - action, - rewards[-1], - ) - return rewards - - def _update_tool(self, action: dict, task_id: int): - """Update the tool - - Parameters - ---------- - action: dict - The current action. - task_id: int - The current task id. - """ - - raise NotImplementedError("_update_tool is not implemented in BaseEnv") - - def summary(self, actions: List[dict], rewards: List[dict]) -> dict: - """Summary the final plan - - Parameters - ---------- - actions: list - The final actions. - rewards: list - The final rewards. - - Returns - ------- - plan: dict - The final plan. - """ - - self._logger.info("Env Summary with %d actions, %d rewards", len(actions), len(rewards)) - return self._summary(actions, rewards) - - def _summary(self, actions: List[dict], rewards: List[dict]) -> Union[dict, str]: - """Summary the final plan - - Parameters - ---------- - actions: list - The final actions. - rewards: list - The final rewards. - - Returns - ------- - knowledge: dict| str - The learned knowledge or file. - """ - - raise NotImplementedError("_summary is not implemented in BaseEnv") - - def _update_strategy(self, strategy: dict, **kwargs) -> dict: - """Update startegy - - Parameters - ---------- - startegy: dict - The strategy. - kwargs: dict - The kwargs. - - Returns - ------- - strategy: dict - The updated strategy. - """ - - for t_type, method_def in strategy["methods"].items(): - if isinstance(method_def, str): - strategy["methods"][t_type] = {"method_name": method_def, **kwargs} - elif isinstance(method_def, dict): - method_def.update(kwargs) - return strategy - - def _get_strategy(self, action: dict, task_id: int) -> dict: - """Get strategy from task_id - - Parameters - ---------- - action: float - The current action. - task_id: int - The current task id. - - Returns - ------- - strategy: dict - The strategy. - """ - - strategy = msc_utils.copy_dict(self.get_task(task_id)) - return self._update_strategy(strategy, **action) - - def get_task(self, task_id: int) -> dict: - """Get task according to task_id - - Parameters - ---------- - task_id: int - The task id. - - Returns - ------- - task_config: dict - The task config. - """ - - return self._tasks[task_id] - - def destory(self): - """Destory the environment""" - - return None - - def _reward_runner(self, task_id: int) -> dict: - """Reward runner for current task - - Parameters - ---------- - task_id: int - The current task id. - - Returns - ------- - reward: dict - The reward - """ - - if "reward_runner" in self._executors: - return self._execute("reward_runner", self._runner, self._data_loader, task_id) - elif "reward_outputs" in self._executors: - reward = {} - for inputs in self._data_loader(): - outputs = self._runner.run(inputs) - reward = self._execute("reward_outputs", reward, outputs, task_id) - return reward - else: - raise Exception("reward_runner or reward_outputs should be given in executors") - - def _execute(self, name: str, *args, **kwargs) -> Any: - """Run executor - - Parameters - ---------- - name: str - The executor name. - args: list - The arguments for execute. - kwargs: dict - The key word arguments for execute. - - Returns - ------- - res: - The execute result. - """ - - assert name in self._executors, ( - f"Can not find {name} in executors: {self._executors.keys()}" - ) - _, method, config = self._executors[name] - kwargs.update({k: v for k, v in config.items() if k not in kwargs}) - return method(self, *args, **kwargs) - - def env_mark(self, msg: Any) -> str: - """Mark the message with env info - - Parameters - ------- - msg: str - The message - - Returns - ------- - msg: str - The message with mark. - """ - - return f"ENV({self.role_type()}) {msg}" - - @property - def tool(self): - return self._tool - - @classmethod - def role(cls): - return GYMObject.ENV - - @classmethod - def role_type(cls): - return "base" diff --git a/python/tvm/contrib/msc/core/gym/environment/method.py b/python/tvm/contrib/msc/core/gym/environment/method.py deleted file mode 100644 index 92768ca9ac89..000000000000 --- a/python/tvm/contrib/msc/core/gym/environment/method.py +++ /dev/null @@ -1,207 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.core.gym.agent.method""" - -from typing import Any, List, Optional - -import numpy as np - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.gym.namespace import GYMObject -from tvm.contrib.msc.core.runtime import BaseRunner -from tvm.contrib.msc.core.tools import BaseTool - - -@msc_utils.register_gym_method -class EnvMethod: - """Default prune method""" - - @classmethod - def tasks_tool_extract(cls, env: Any, tool: BaseTool, **kwargs) -> List[dict]: - """Extract tasks from tool - - Parameters - ---------- - env: BaseEnv - The evironment. - tool: BaseTool - The main tool - kwargs: dict - The kwargs for create tasks. - - Returns - ------- - tasks: list - The tasks for environment. - """ - - return tool.create_tasks(**kwargs) - - @classmethod - def reward_compare_baseline( - cls, - env: Any, - runner: BaseRunner, - data_loader: callable, - task_id: int, - loss_type: str = "lp_norm", - loss_config: Optional[dict] = None, - ) -> dict: - """Reward runner with baseline - - Parameters - ---------- - env: BaseEnv - The evironment. - runner: BaseRunner - The runner. - data_loader: callable - The data loader. - task_id: int - The task id. - loss_type: str - The loss type - loss_config: dict - The loss config - - Returns - ------- - reward: dict - The reward. - """ - - datas_path = env._workspace.create_dir("Baseline").path - if task_id == -1: - with msc_utils.SimpleDataSaver(datas_path) as saver: - for inputs in data_loader(): - outputs = runner.run(inputs) - saver.save_datas(outputs) - return {"loss": 1} - - loss_config = loss_config or {} - loader, loss = msc_utils.SimpleDataLoader(datas_path), 0 - - def _get_loss(golden, result): - if loss_type == "lp_norm": - power = loss_config.get("power", 2) - return np.mean(np.power(np.abs(golden - result), power)) - raise NotImplementedError(f"loss type {loss_type} is not implemented") - - for idx, inputs in enumerate(data_loader()): - outputs = runner.run(inputs) - baseline = loader[idx] - for name, data in outputs.items(): - loss += _get_loss(baseline[name], msc_utils.cast_array(data)) - return {"loss": loss / len(loader)} - - @classmethod - def action_linear_space( - cls, env: Any, task_id: int, start: float = 0.1, end: float = 0.9, step: float = 0.1 - ) -> List[float]: - """Get linear action space - - Parameters - ---------- - env: BaseEnv - The evironment. - task_id: int - The task id. - start: float - The start value. - end: float - The end value. - step: float - The step value. - - Returns - ------- - actions: list - The actions. - """ - - actions = [start] - while actions[-1] < end: - actions.append(actions[-1] + step) - return actions - - @classmethod - def action_prune_density( - cls, env: Any, task_id: int, start: float = 0.1, end: float = 0.9, step: float = 0.1 - ) -> List[dict]: - """Get linear density - - Parameters - ---------- - env: BaseEnv - The evironment. - task_id: int - The task id. - start: float - The start value. - end: float - The end value. - step: float - The step value. - - Returns - ------- - actions: list - The actions. - """ - - return [{"density": a} for a in cls.action_linear_space(env, task_id, start, end, step)] - - @classmethod - def action_quantize_scale( - cls, env: Any, task_id: int, start: float = 0.1, end: float = 0.9, step: float = 0.1 - ) -> List[dict]: - """Get linear density - - Parameters - ---------- - env: BaseEnv - The evironment. - task_id: int - The task id. - start: float - The start value. - end: float - The end value. - step: float - The step value. - - Returns - ------- - actions: list - The actions. - """ - - task = env.get_task(task_id) - plan = env.tool.plan[task["tensor_ids"][0]] - return [ - {"scale": plan["scale"] * a} - for a in cls.action_linear_space(env, task_id, start, end, step) - ] - - @classmethod - def role(cls): - return GYMObject.ENV - - @classmethod - def method_type(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/gym/environment/prune_env.py b/python/tvm/contrib/msc/core/gym/environment/prune_env.py deleted file mode 100644 index 87b777809a6d..000000000000 --- a/python/tvm/contrib/msc/core/gym/environment/prune_env.py +++ /dev/null @@ -1,97 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: RUF005 -"""tvm.contrib.msc.core.gym.prune_env""" - -from typing import List, Union - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools import BaseTool, ToolType - -from .base_env import BaseEnv - - -@msc_utils.register_gym_object -class PruneEnv(BaseEnv): - """Environment for prune""" - - def _init_tool(self) -> BaseTool: - """Get the main tool""" - - config = self._runner.get_tool_config(ToolType.PRUNER) - self._meta_strategys = msc_utils.copy_dict(config["strategys"]) - self._meta_strategys = [self._update_strategy(s, density=1) for s in self._meta_strategys] - tool = self._runner.get_tool(ToolType.PRUNER) - tool.change_strategys(self._meta_strategys) - return tool - - def _update_tool(self, action: dict, task_id: int): - """Update the tool - - Parameters - ---------- - action: dict - The current action. - task_id: int - The current task id. - """ - - task_strategy = self._get_strategy(action, task_id) - self._apply_strategys(self._meta_strategys + [task_strategy]) - - def _summary(self, actions: List[dict], rewards: List[dict]) -> Union[dict, str]: - """Summary the final plan - - Parameters - ---------- - actions: list - The final actions. - rewards: list - The final rewards. - - Returns - ------- - knowledge: dict| str - The learned knowledge or file. - """ - - strategys = self._meta_strategys + [ - self._get_strategy(act, idx) for idx, act in enumerate(actions) - ] - return self._apply_strategys(strategys) - - def _apply_strategys(self, strategys: List[dict]) -> str: - """Apply the strategys - - Parameters - ---------- - strategys: list - The given strategys - - Returns - ------- - plan_file: str - The plan after strategys applied. - """ - - self._tool.change_strategys(strategys) - self._runner.build(self._cache_dir, force_build=True) - return self._runner.make_plan(self._tool.tool_type(), self._data_loader) - - @classmethod - def role_type(cls): - return msc_utils.MSCStage.PRUNE + ".default" diff --git a/python/tvm/contrib/msc/core/gym/environment/quantize_env.py b/python/tvm/contrib/msc/core/gym/environment/quantize_env.py deleted file mode 100644 index 20b880623c3d..000000000000 --- a/python/tvm/contrib/msc/core/gym/environment/quantize_env.py +++ /dev/null @@ -1,79 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.quantize_env""" - -from typing import List, Union - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools import BaseTool, ToolType - -from .base_env import BaseEnv - - -@msc_utils.register_gym_object -class QuantizeEnv(BaseEnv): - """Environment for quantize""" - - def _init_tool(self) -> BaseTool: - """Get the main tool""" - - self._runner.make_plan(ToolType.QUANTIZER, self._data_loader) - return self._runner.get_tool(ToolType.QUANTIZER) - - def _update_tool(self, action: dict, task_id: int): - """Update the tool - - Parameters - ---------- - action: dict - The current action. - task_id: int - The current task id. - """ - - self._tool.change_strategys([self._get_strategy(action, task_id)]) - - def _summary(self, actions: List[dict], rewards: List[dict]) -> Union[dict, str]: - """Summary the final plan - - Parameters - ---------- - actions: list - The final actions. - rewards: list - The final rewards. - - Returns - ------- - knowledge: dict| str - The learned knowledge or file. - """ - - strategys = self.tool._parse_strategys( - [self._get_strategy(act, idx) for idx, act in enumerate(actions)] - ) - plan = self.tool.plan - for name, info in plan.items(): - if name not in strategys: - continue - info.update(strategys[name].get_executor(msc_utils.MSCStage.QUANTIZE).config) - summary_file = msc_utils.get_cache_dir().relpath("gym_summary.json") - return msc_utils.save_dict(plan, summary_file) - - @classmethod - def role_type(cls): - return msc_utils.MSCStage.QUANTIZE + ".default" diff --git a/python/tvm/contrib/msc/core/gym/namespace.py b/python/tvm/contrib/msc/core/gym/namespace.py deleted file mode 100644 index 10417e9e625a..000000000000 --- a/python/tvm/contrib/msc/core/gym/namespace.py +++ /dev/null @@ -1,40 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.gym.namespace""" - - -class GYMObject: - """Enum all gym objects""" - - BASE = "base" - ENV = "env" - AGENT = "agent" - SERVICE = "service" - - -class GYMAction: - """Enum all gym actions""" - - INIT = "init" - RESET = "reset" - GET_STATE = "get_state" - CHOOSE_ACTION = "choose_action" - STEP = "step" - STORE = "store" - LEARN = "learn" - SUMMARY = "summary" - CLEANUP = "cleanup" diff --git a/python/tvm/contrib/msc/core/ir/__init__.py b/python/tvm/contrib/msc/core/ir/__init__.py deleted file mode 100644 index ce23a2dd8b27..000000000000 --- a/python/tvm/contrib/msc/core/ir/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.ir""" - -from .graph import * diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py deleted file mode 100644 index 7c930fca245a..000000000000 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ /dev/null @@ -1,1102 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.ir.graph""" - -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - -import numpy as np -import tvm_ffi - -import tvm -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import utils as msc_utils -from tvm.runtime import Object - - -@tvm_ffi.register_object("msc.core.MSCTensor") -class MSCTensor(Object): - """Tensor in MSCGraph - - Parameters - ---------- - name: string - The name of the tensor. - dtype: string or np.dtype or DataType - The data type the tensor. - layout: string - The layout of the tensor. - shape: list - The shape of the tensor. - alias: string - The alias of the tensor. - prims: list - The prims of the tensor. - """ - - def __init__( - self, - name: str, - dtype: Union[str, np.dtype, tvm.DataType], - layout: str, - shape: List[int], - alias: Optional[str] = None, - prims: Optional[List[str]] = None, - ): - if not isinstance(dtype, tvm.DataType): - dtype = tvm.DataType(dtype) - self.__init_handle_by_constructor__( - _ffi_api.MSCTensor, name, dtype, layout, shape, alias or "", prims or [] - ) - - def get_shape(self, with_prims: bool = False) -> List[Union[int, str]]: - """Get shape of the tensor - - Parameters - ------- - with_prims: bool - Whether get shape with prims. - - Returns - ------- - shape: list - The shape of tensor. - """ - - if not self.prims or not with_prims: - return [int(i) for i in self.shape] - return [int(p) if p.isdigit() else p for p in self.prims] - - def get_size(self) -> int: - return int(_ffi_api.MSCTensorGetSize(self)) - - def dim_at(self, axis: Union[int, str]) -> int: - if isinstance(axis, int): - return int(self.shape[axis]) - return int(_ffi_api.MSCTensorDimAt(self, axis)) - - def layout_of(self, axis: str) -> int: - return self.layout.index_of(axis) - - def set_alias(self, alias: str): - """Set alis for the tensor - - Parameters - ------- - alias: str - The alias. - """ - - _ffi_api.MSCTensorSetAlias(self, alias) - - def equal(self, other: Object) -> bool: - """A fast method to check if two nodes are same. - - Parameters - ------- - other: MSCTensor - The tensor to be compared. - - Returns - ------- - equal: bool - Whether two tensors are the same. - """ - - if not isinstance(other, MSCTensor): - return False - if self.get_shape(True) != other.get_shape(True): - return False - if self.dtype != other.dtype: - return False - return True - - def to_json(self) -> str: - """Dump the tensor to json. - - Returns - ------- - tensor_json: string - The tensor in json format. - """ - - return _ffi_api.MSCTensorToJson(self) - - def inspect(self) -> dict: - """Extract important info of the tensor. - - Returns - ------- - tensor_des: dict - The tensor description in json format. - """ - - tensor_des = {"name": self.alias, "shape": self.get_shape(True), "dtype": self.dtype_name} - tensor_des["layout"] = self.layout.name if self.layout else "" - return tensor_des - - @classmethod - def from_json(cls, json_str: str, **options) -> object: - """Load the tensor from json. - - Parameters - ---------- - json_str: string - The file_path or json string. - options: dict - The items to be changed. - - Returns - ------- - tensor: MSCTensor - The tensor. - """ - - dict_obj = msc_utils.load_dict(json_str) - dict_obj.update(options) - return _ffi_api.MSCTensorFromJson(msc_utils.dump_dict(dict_obj)) - - def clone(self, **options) -> object: - """Clone the tensor. - - Parameters - ---------- - json_str: string - The file_path or json string. - options: dict - The items to be changed. - - Returns - ------- - new_tensor: MSCTensor - The cloned tensor. - """ - - return MSCTensor.from_json(self.to_json(), **options) - - @property - def dtype_name(self) -> str: - return _ffi_api.MSCTensorDTypeName(self) - - @property - def ndim(self) -> int: - return len(self.shape) - - -@tvm_ffi.register_object("msc.core.BaseJoint") -class BaseJoint(Object): - """Base class of all MSC Nodes.""" - - -@tvm_ffi.register_object("msc.core.MSCJoint") -class MSCJoint(BaseJoint): - """Node in MSCGraph - - Parameters - ---------- - index: int - The index of the node. - name: string - The name of the node. - shared_ref: string - The share reference of the node. - optype: string - The optype of the node. - attrs: dict - The attributes of the node. - inputs: list> - The inputs of the node in format . - outputs: list - The outputs of the node. - weights: dict - The weights of the node. - """ - - def __init__( - self, - index: int, - name: str, - shared_ref: str, - optype: str, - attrs: Dict[str, str], - inputs: List[Tuple[BaseJoint, int]], - outputs: List[MSCTensor], - weights: Dict[str, MSCTensor], - ): - parents = [i[0] for i in inputs] - out_indices = [i[1] for i in inputs] - self.__init_handle_by_constructor__( - _ffi_api.MSCJoint, - index, - name, - shared_ref, - optype, - attrs, - parents, - out_indices, - outputs, - weights, - ) - - def input_at(self, idx: int) -> MSCTensor: - """Get input at idx. - - Parameters - ---------- - idx: int - The index of input. - - Returns - ------- - input: MSCTensor - The input Tensor. - """ - - return _ffi_api.MSCJointInputAt(self, idx) - - def output_at(self, idx: int) -> MSCTensor: - """Get output at idx. - - Parameters - ---------- - idx: int - The index of output. - - Returns - ------- - output: MSCTensor - The output Tensor. - """ - - return _ffi_api.MSCJointOutputAt(self, idx) - - def weight_at(self, wtype: str) -> MSCTensor: - """Get weight from reference. - - Parameters - ---------- - wtype: str - The type of weight. - - Returns - ------- - weight: MSCTensor - The weight Tensor. - """ - - return _ffi_api.MSCJointWeightAt(self, wtype) - - def weight_type(self, name: str) -> str: - """Get the weight type of weight - - Parameters - ---------- - name: str - The name of weight. - - Returns - ------- - wtype: str - The type of weight. - """ - - for w_type, weight in self.get_weights().items(): - if weight.name == name: - return w_type - raise Exception("Can not find weight type for " + name) - - def get_inputs(self) -> List[MSCTensor]: - """Get all the inputs. - - Returns - ------- - inputs: list - The input Tensors. - """ - - return _ffi_api.MSCJointGetInputs(self) - - def get_outputs(self) -> List[MSCTensor]: - """Get all the outputs. - - Returns - ------- - outputs: list - The output Tensors. - """ - - return _ffi_api.MSCJointGetOutputs(self) - - def get_weights(self) -> Dict[str, MSCTensor]: - """Get all the weights. - - Returns - ------- - weights: dict - The weight Tensors. - """ - - src_weights = _ffi_api.MSCJointGetWeights(self) - return {wtype: src_weights[wtype] for wtype in src_weights} - - def get_attrs(self) -> Dict[str, str]: - """Get all the attributes from node - - Returns - ------- - attributes: dict - The attributes of node. - """ - - return _ffi_api.MSCJointGetAttrs(self) - - def get_attr(self, key: str, default: Optional[Any] = None) -> str: - """Get the attribute of key from node - - Parameters - ------- - key: str - The key of the attribute. - default: Any - The default value when key is missing. - - Returns - ------- - attribute: str - The attributes of node. - """ - - return self.get_attrs().get(key, default) - - def has_attr(self, key: str) -> bool: - """Check if key in attributes - - Parameters - ------- - key: str - The key of the attribute. - - Returns - ------- - has_attr: bool - Whether the key in the attributes. - """ - - return bool(_ffi_api.MSCJointHasAttr(self, key)) - - def equal(self, other: BaseJoint) -> bool: - """A fast method to check if two nodes are same. - - Parameters - ------- - other: MSCJoint - The node to be compared. - - Returns - ------- - equal: bool - Whether two nodes are the same. - """ - - if not isinstance(other, MSCJoint): - return False - if len(self.get_inputs()) != len(other.get_inputs()): - return False - if len(self.get_inputs()) != len(other.get_inputs()): - return False - for s_i, o_i in zip(self.get_inputs(), other.get_inputs()): - if not s_i.equal(o_i): - return False - for s_o, o_o in zip(self.get_inputs(), other.get_inputs()): - if not s_o.equal(o_o): - return False - return msc_utils.dict_equal(self.get_attrs(), other.get_attrs()) - - -@tvm_ffi.register_object("msc.core.MSCPrim") -class MSCPrim(BaseJoint): - """Prim in MSCGraph - - Parameters - ---------- - index: int - The index of the prim. - name: string - The name of the prim. - optype: string - The optype of the prim. - attrs: dict - The attributes of the node. - parents: list - The parents of the prim. - """ - - def __init__( - self, index: int, name: str, optype: str, attrs: Dict[str, str], parents: List[BaseJoint] - ): - self.__init_handle_by_constructor__(_ffi_api.MSCPrim, index, name, optype, attrs, parents) - - -@tvm_ffi.register_object("msc.core.WeightJoint") -class WeightJoint(BaseJoint): - """Node in WeightGraph - - Parameters - ---------- - index: int - The index of the node. - name: string - The name of the node. - shared_ref: string - The share reference of the node. - optype: string - The optype of the node. - wtype: string - The weight type of the node. - strategy: string - The prune strategy of the node. - weight: MSCTensor - The weight of the node. - attrs: dict - The attributes of the node. - parents: list - The parents of the node. - friends: list - The friends of the node. - """ - - def __init__( - self, - index: int, - name: str, - shared_ref: str, - optype: str, - wtype: str, - strategy: str, - weight: MSCTensor, - attrs: Dict[str, str], - parents: List[BaseJoint], - friends: List[BaseJoint], - ): - self.__init_handle_by_constructor__( - _ffi_api.WeightJoint, - index, - name, - shared_ref, - optype, - wtype, - strategy, - weight, - attrs, - parents, - friends, - ) - - def set_attr(self, key: str, value: str): - """Set attribute to node - - Parameters - ------- - key: str - The key of the attribute. - value: str - The value. - """ - - _ffi_api.WeightJointSetAttr(self, key, value) - - def get_attrs(self) -> Dict[str, str]: - """Get all the attributes from node - - Returns - ------- - attributes: dict - The attributes of node. - """ - - return _ffi_api.WeightJointGetAttrs(self) - - def get_attr(self, key: str, default: Optional[Any] = None) -> str: - """Get the attribute of key from node - - Parameters - ------- - key: str - The key of the attribute. - default: Any - The default value when key is missing. - - Returns - ------- - attribute: str - The attributes of node. - """ - - return self.get_attrs().get(key, default) - - def has_attr(self, key: str) -> bool: - """Check if key in attributes - - Parameters - ------- - key: str - The key of the attribute. - - Returns - ------- - has_attr: bool - Whether the key in the attributes. - """ - - return bool(_ffi_api.WeightJointHasAttr(self, key)) - - -@tvm_ffi.register_object("msc.core.BaseGraph") -class BaseGraph(Object): - """Base class of all MSC Graphs.""" - - -@tvm_ffi.register_object("msc.core.MSCGraph") -class MSCGraph(BaseGraph): - """The MSCGraph - - Parameters - ---------- - name: string - The name of the graph. - nodes: list - The nodes of the graph. - input_names: list - The input names of the graph. - output_names: list - The output names of the graph. - """ - - def __init__( - self, - name: str, - nodes: List[MSCJoint], - input_names: List[str], - output_names: List[str], - ): - self.__init_handle_by_constructor__( - _ffi_api.MSCGraph, - name, - nodes, - input_names, - output_names, - ) - - def has_node(self, name: str) -> bool: - """Check if node in the graph. - - Parameters - ---------- - name: string - The name of the node. - - Returns - ------- - has_node: bool - Whether the node is in the graph - """ - - return bool(_ffi_api.MSCGraphHasNode(self, name)) - - def find_node(self, name: str) -> MSCJoint: - """Find node by name. - - Parameters - ---------- - name: string - The name of the node. - - Returns - ------- - node: MSCJoint - The found node. - """ - - return _ffi_api.MSCGraphFindNode(self, name) - - def find_prim(self, name: str) -> MSCPrim: - """Find prim by name. - - Parameters - ---------- - name: string - The name of the prim. - - Returns - ------- - prim: MSCPrim - The found prim. - """ - - return _ffi_api.MSCGraphFindPrim(self, name) - - def has_tensor(self, name: str) -> bool: - """Check if tensor in the graph. - - Parameters - ---------- - name: string - The name of the tensor. - - Returns - ------- - has_tensor: bool - Whether the tensor is in the graph - """ - - return bool(_ffi_api.MSCGraphHasTensor(self, name)) - - def find_tensor(self, name: str) -> MSCTensor: - """Find tensor by name. - - Parameters - ---------- - name: string - The name of the tensor. - - Returns - ------- - node: MSCTensor - The found tensor. - """ - - return _ffi_api.MSCGraphFindTensor(self, name) - - def set_tensor_alias(self, tensor: MSCTensor, alias: str): - """Set alis for the tensor - - Parameters - ------- - tensor: MSCTensor - The tensor. - alias: str - The alias. - """ - - _ffi_api.MSCGraphSetTensorAlias(self, tensor, alias) - - def find_producer(self, ref: Union[str, MSCTensor]) -> MSCJoint: - """Find producer by tensor_name or tensor. - - Parameters - ---------- - ref: string or MSCTensor - The name of the tensor or tensor. - - Returns - ------- - node: MSCJoint - The found prducer. - """ - - if isinstance(ref, MSCTensor): - return _ffi_api.MSCGraphFindProducer(self, ref.name) - return _ffi_api.MSCGraphFindProducer(self, ref) - - def find_consumers(self, ref: Union[str, MSCTensor]) -> List[MSCJoint]: - """Find consumers by tensor_name or tensor. - - Parameters - ---------- - ref: string or MSCTensor - The name of the tensor or tensor. - - Returns - ------- - node: list - The found consumers. - """ - - if isinstance(ref, MSCTensor): - return _ffi_api.MSCGraphFindConsumers(self, ref.name) - return _ffi_api.MSCGraphFindConsumers(self, ref) - - def get_nodes(self) -> Iterable[MSCJoint]: - """Get all the nodes in the graph. - - Returns - ------- - nodes: generator - The generator of nodes. - """ - - for n in self.node_names: - yield self.find_node(n) - - def get_prims(self) -> Iterable[MSCPrim]: - """Get all the prims in the graph. - - Returns - ------- - prims: generator - The generator of prims. - """ - - for n in self.prim_names: - yield self.find_prim(n) - - def get_weights(self) -> Iterable[MSCTensor]: - """Get all the weights in the graph. - - Returns - ------- - weights: generator - The generator of weights. - """ - - for node in self.get_nodes(): - yield from node.get_weights().values() - - def input_at(self, idx: int) -> MSCTensor: - """Get input at idx. - - Parameters - ---------- - idx: int - The index of input. - - Returns - ------- - input: MSCTensor - The input Tensor. - """ - - return _ffi_api.MSCGraphInputAt(self, idx) - - def output_at(self, idx: int) -> MSCTensor: - """Get output at idx. - - Parameters - ---------- - idx: int - The index of output. - - Returns - ------- - output: MSCTensor - The output Tensor. - """ - - return _ffi_api.MSCGraphOutputAt(self, idx) - - def get_inputs(self) -> List[MSCTensor]: - """Get all the inputs. - - Returns - ------- - inputs: list - The input Tensors. - """ - - return _ffi_api.MSCGraphGetInputs(self) - - def get_outputs(self) -> List[MSCTensor]: - """Get all the outputs. - - Returns - ------- - outputs: list - The output Tensors. - """ - - return _ffi_api.MSCGraphGetOutputs(self) - - def get_tensors(self) -> List[MSCTensor]: - """Get all the tensors. - - Returns - ------- - tensors: list - The Tensors. - """ - - for node in self.get_nodes(): - yield from node.get_inputs() - yield from node.get_weights().values() - yield from self.get_outputs() - - def to_json(self) -> str: - """Dump the graph to json. - - Returns - ------- - graph_json: string - The graph in json format. - """ - - return _ffi_api.MSCGraphToJson(self) - - def inspect(self) -> dict: - """Extract important info of the graph. - - Returns - ------- - graph_des: dict - The graph description in json format. - """ - - graph_des = { - "inputs": [i.inspect() for i in self.get_inputs()], - "outputs": [o.inspect() for o in self.get_outputs()], - "nodes": {"total": 0}, - } - for node in self.get_nodes(): - graph_des["nodes"].setdefault(node.optype, 0) - graph_des["nodes"]["total"] += 1 - graph_des["nodes"][node.optype] += 1 - prims = {"total": 0} - for prim in self.get_prims(): - prims.setdefault(prim.optype, 0) - prims["total"] += 1 - prims[prim.optype] += 1 - if prims["total"] > 0: - graph_des["prims"] = prims - return graph_des - - @classmethod - def from_json(cls, json_str: str) -> BaseGraph: - """Load the graph from json. - - Parameters - ---------- - json_str: string - The file_path or json string. - - Returns - ------- - graph: MSCgraph - The graph. - """ - - dict_obj = msc_utils.load_dict(json_str) - return _ffi_api.MSCGraphFromJson(msc_utils.dump_dict(dict_obj)) - - def clone(self) -> BaseGraph: - """Clone the graph. - - Returns - ------- - new_graph: MSCGraph - The cloned graph. - """ - - return MSCGraph.from_json(self.to_json()) - - def equal(self, other: BaseGraph) -> bool: - """A fast method to check if two graphs are same. - - Parameters - ------- - other: MSCGraph - The graph to be compared. - - Returns - ------- - equal: bool - Whether two graphs are the same. - """ - - if not isinstance(other, MSCGraph): - return False - if len(self.input_names) != len(other.input_names): - return False - if len(self.output_names) != len(other.output_names): - return False - if len(self.node_names) != len(other.node_names): - return False - for s_i, o_i in zip(self.get_inputs(), other.get_inputs()): - if not s_i.equal(o_i): - return False - for s_o, o_o in zip(self.get_outputs(), other.get_outputs()): - if not s_o.equal(o_o): - return False - for s_n, o_n in zip(self.get_nodes(), other.get_nodes()): - if not s_n.equal(o_n): - return False - return True - - def visualize(self, path: Optional[str] = None) -> str: - """Dump the graph to prototxt format. - - Parameters - ---------- - path: string - The file_path to save prototxt. - - Returns - ------- - graph_proto: string - The graph in prototxt format. - """ - - graph_proto = _ffi_api.MSCGraphToPrototxt(self) - if path: - with open(path, "w") as f: - f.write(graph_proto) - return graph_proto - - -@tvm_ffi.register_object("msc.core.WeightGraph") -class WeightGraph(BaseGraph): - """The WeightGraph - - Parameters - ---------- - name: string - The name of the graph. - nodes: list - The nodes of the graph. - """ - - def __init__( - self, - name: str, - nodes: List[WeightJoint], - ): - self.__init_handle_by_constructor__( - _ffi_api.WeightGraph, - name, - nodes, - ) - - def has_node(self, name: str) -> bool: - """Check if weight node in the graph. - - Parameters - ---------- - name: string - The name of the node. - - Returns - ------- - has_node: bool - Whether the node is in the graph - """ - - return bool(_ffi_api.WeightGraphHasNode(self, name)) - - def find_node(self, name: str) -> WeightJoint: - """Find weight node by name. - - Parameters - ---------- - name: string - The name of the node. - - Returns - ------- - node: MSCJoint - The found node. - """ - - return _ffi_api.WeightGraphFindNode(self, name) - - def get_nodes(self) -> Iterable[WeightJoint]: - """Get all the weight nodes in the graph. - - Returns - ------- - nodes: generator - The generator of nodes. - """ - - for n in self.node_names: - yield self.find_node(n) - - def to_json(self) -> str: - """Dump the graph to json. - - Returns - ------- - graph_json: string - The graph in json format. - """ - - return _ffi_api.WeightGraphToJson(self) - - def inspect(self) -> dict: - """Extract important info of the graph. - - Returns - ------- - graph_des: dict - The graph description in json format. - """ - - graph_des = { - "nodes": {"total": 0}, - } - for node in self.get_nodes(): - graph_des["nodes"]["total"] += 1 - if node.weight_type not in graph_des["nodes"]: - graph_des["nodes"][node.weight_type] = 1 - else: - graph_des["nodes"][node.weight_type] += 1 - return graph_des - - @classmethod - def from_json(cls, json_str: str) -> BaseGraph: - """Load the graph from json. - - Parameters - ---------- - json_str: string - The file_path or json string. - - Returns - ------- - graph: WeightGraph - The graph. - """ - - dict_obj = msc_utils.load_dict(json_str) - return _ffi_api.WeightGraphFromJson(msc_utils.dump_dict(dict_obj)) - - def clone(self) -> BaseGraph: - """Clone the graph. - - Returns - ------- - new_graph: MSCGraph - The cloned graph. - """ - - return MSCGraph.from_json(self.to_json()) - - def visualize(self, path: Optional[str] = None) -> str: - """Dump the graph to prototxt format. - - Parameters - ---------- - path: string - The file_path to save prototxt. - - Returns - ------- - graph_proto: string - The graph in prototxt format. - """ - - graph_proto = _ffi_api.WeightGraphToPrototxt(self) - if path: - with open(path, "w") as f: - f.write(graph_proto) - return graph_proto diff --git a/python/tvm/contrib/msc/core/runtime/__init__.py b/python/tvm/contrib/msc/core/runtime/__init__.py deleted file mode 100644 index 4530b2bc4a3f..000000000000 --- a/python/tvm/contrib/msc/core/runtime/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.runtime""" - -from .runner import * -from .jit import * diff --git a/python/tvm/contrib/msc/core/runtime/hook.py b/python/tvm/contrib/msc/core/runtime/hook.py deleted file mode 100644 index 8c42f2ab7c4d..000000000000 --- a/python/tvm/contrib/msc/core/runtime/hook.py +++ /dev/null @@ -1,194 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument, arguments-differ -"""tvm.contrib.msc.core.runtime.hook""" - -from typing import Any, Dict, List, Tuple, Union - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph - - -class RunnerHook: - """Hook for runner - - Parameters - ---------- - config: dict - The config of the func. - """ - - def __init__(self, config: dict): - self._config = config - - def __str__(self): - return f"{self.name()}({self._config})" - - def apply(self, runner: object, *args, **kwargs) -> Any: - """Apply the hook - - Parameters - ---------- - runner: - The runner context. - args: list - The arguments for run method. - kwargs: dict - The key word arguments for run method. - - Returns - ------- - result: - The results. - """ - - kwargs.update({k: v for k, v in self._config.items() if k not in kwargs}) - return self._apply(runner, *args, **kwargs) - - def _apply(self, runner: object, *args, **kwargs): - """Apply the hook - - Parameters - ---------- - runner: - The runner context. - args: list - The arguments for run method. - kwargs: dict - The key word arguments for run method. - - Returns - ------- - result: - The results. - """ - - raise NotImplementedError("default_func is not supported in " + str(self.__class__)) - - @classmethod - def name(cls): - return "base" - - -class CustomizedHook(RunnerHook): - """Hook for customized func - - Parameters - ---------- - func: callable/str - The function. - config: dict - The config of the func. - """ - - def __init__(self, func: Union[str, callable], config: dict): - super().__init__(config) - self._func = msc_utils.load_callable(func) - - def __str__(self): - return f"{self.name()} {self._func}({self._config})" - - def _apply(self, runner: object, *args, **kwargs): - """Apply the hook - - Parameters - ---------- - runner: - The runner context. - args: list - The arguments for run method. - kwargs: dict - The key word arguments for run method. - - Returns - ------- - result: - The results. - """ - - return self._func(runner, *args, **kwargs) - - @classmethod - def name(cls): - return "customized" - - -@msc_utils.register_runner_hook -class UpdateWeightsHook(RunnerHook): - """Hook for update weights""" - - def _apply( - self, - runner: object, - graphs: List[MSCGraph], - weights: Dict[str, tvm.runtime.Tensor], - weights_path: str, - ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Apply the default funcion - - Parameters - ------- - runner: - The runner context. - graphs: list - The translated graphs - weights: dict - The translated weights. - weights_path: str - The weights path. - - Returns - ------- - graphs: list - The updated graphs - weights: dict - The updated weights. - - """ - - with open(weights_path, "rb") as f: - new_weights = tvm.runtime.load_param_dict(f.read()) - weights.update({k: v for k, v in new_weights.items() if k in weights}) - return graphs, weights - - @classmethod - def name(cls): - return "update_weights" - - -def load_runner_hook(config: dict) -> Any: - """Load a registered hook - - Parameters - ---------- - config: dict - The config of the func. - - Returns - ------- - hook: RunnerHook - The hook - """ - - assert "hook" in config, "hook should be given to load hook" - hook_ref = config["hook"] - hook_config = {k: v for k, v in config.items() if k != "hook"} - hook_cls = msc_utils.get_registered_runner_hook(hook_ref) if isinstance(hook_ref, str) else None - if hook_cls: - return hook_cls(hook_config) - return CustomizedHook(hook_ref, hook_config) diff --git a/python/tvm/contrib/msc/core/runtime/jit.py b/python/tvm/contrib/msc/core/runtime/jit.py deleted file mode 100644 index 968c37c55a63..000000000000 --- a/python/tvm/contrib/msc/core/runtime/jit.py +++ /dev/null @@ -1,368 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.core.runtime.jit_model""" - -import logging -from typing import Any, Dict, List, Optional, Tuple, Union - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - -from .runner import BaseRunner - - -class BaseJIT: - """Base Just-In-Time compile for msc - - Parameters - ---------- - model: - The model to be jit compile. - inputs: list - The input names. - outputs: list - The output names. - device: str - The device to build runnable. - training: bool - Whether compile model to trainable. - hooks: dict - The hooks for runners. - logger: logging.Logger - The logger - """ - - def __init__( - self, - model: Any, - inputs: List[str], - outputs: List[str], - device: str = "cpu", - training: bool = False, - hooks: Optional[dict] = None, - logger: Optional[logging.Logger] = None, - ): - self._model = model - self._jit_model = model - self._inputs = inputs - self._outputs = outputs - self._device = device if self.support_device(device) else "cpu" - self._training, self._trained = training, training - self._hooks = hooks or {} - self._runner_ctxs = {} - self._logger = logger or msc_utils.get_global_logger() - self._logger.info(msc_utils.msg_block(self.jit_mark("SETUP"), self.setup())) - - def setup(self) -> dict: - """Setup the jit - - Returns - ------- - info: dict - The setup info. - """ - - return { - "inputs": self._inputs, - "outputs": self._outputs, - "device": self._device, - "training": self._training, - "hooks": self._hooks, - } - - def run( - self, inputs: Union[List[Any], Dict[str, Any]], ret_type="native" - ) -> Union[List[Any], Dict[str, Any]]: - """Run the jit to get outputs - - Parameters - ------- - inputs: list or dict - The inputs in list or dict. - ret_type: str - The return type list| dict - - Returns - ------- - outputs: dict - The outputs in dict. - """ - - inputs = msc_utils.format_datas(inputs, self._inputs, style="dict") - outputs = self._call_jit(inputs) - if ret_type == "native": - return outputs - return msc_utils.format_datas(outputs, self._outputs, style=ret_type) - - def _call_jit(self, inputs: Dict[str, Any]) -> Any: - """Run the jit model - - Parameters - ---------- - inputs: - The inputs of model. - """ - - raise NotImplementedError("_call_jit is not implemented in " + str(self.__class__)) - - def set_runner(self, runner_name: str, runner: BaseRunner): - """Set runner in runner ctx - - Parameters - ---------- - runner_name: str - The runner name. - runner: BaseRunner - The runner. - """ - - self.get_runner_ctx(runner_name)["runner"] = runner - - def build(self): - """Build the jit model""" - - self._jit_model = self._build(self._model) - - def _build(self, model: Any) -> Any: - """Build the jit model - - Parameters - ---------- - model: - The model. - - Returns - ------- - jit_model: - The jit model. - """ - - raise NotImplementedError("_build is not implemented in " + str(self.__class__)) - - def make_plan(self, tool_type: str, data_loader: Any = None) -> str: - """Execute tool and get plan - - Parameters - ------- - tool_type: str - The tool type, should be in ToolType - data_loader: - The data loader. - - Returns - ------- - plan_file: str - The saved plan file. - """ - - tools = {n: r["runner"].get_tool(tool_type) for n, r in self._runner_ctxs.items()} - - def _finalize_tool( - checker: callable, - post_batch: Optional[callable] = None, - post_iter: Optional[callable] = None, - ): - while any(not checker(t) for t in tools.values()): - assert data_loader, "data_loader should be given to make plan for " + tool_type - for inputs in data_loader(): - outputs = self.run(inputs, ret_type="native") - if post_batch: - for t in tools.values(): - post_batch(t, outputs) - if all(checker(t) for t in tools.values()): - break - if post_iter: - for t in tools.values(): - post_iter(t) - return {n: t.finalize() for n, t in tools.items()} - - if tool_type == ToolType.PRUNER: - plans = _finalize_tool(lambda t: t.pruned) - elif tool_type == ToolType.QUANTIZER: - plans = _finalize_tool(lambda t: t.calibrated, post_iter=lambda t: t.calibrate()) - elif tool_type == ToolType.DISTILLER: - plans = _finalize_tool( - lambda t: t.distilled, - post_batch=lambda t, outputs: t.learn(outputs), - post_iter=lambda t: t.distill(), - ) - elif tool_type == ToolType.TRACKER: - plans = _finalize_tool(lambda t: t.tracked) - else: - plans = {n: t.finalize() for n, t in tools.items()} - plans_info = ", ".join([f"{n}({len(p)})" for n, p in plans.items()]) - self._logger.debug("Made %s plans for %s", plans_info, tool_type) - - def _redirect_run(self, *args, runner_name: str = "worker", **kwargs) -> Any: - """Redirect forward of model - - Parameters - ---------- - args: - The arguments. - runner_name: str - The runner name. - kwargs: - The kwargs. - - Returns - ------- - outputs: - The outputs. - """ - - assert runner_name in self._runner_ctxs, "Failed to create runner " + runner_name - inputs = self._to_msc_inputs(runner_name, *args, **kwargs) - for hook in self._hooks.get("pre_forward", []): - hook(runner_name, inputs) - outputs = self._run_ctx(self.get_runner_ctx(runner_name), inputs) - for hook in self._hooks.get("post_forward", []): - outputs = hook(runner_name, outputs) - return self._from_msc_outputs(runner_name, outputs) - - def _to_msc_inputs(self, runner_name: str, *args, **kwargs) -> List[Tuple[str, Any]]: - """Change inputs to msc format - - Parameters - ---------- - runner_name: str - The runner name. - args: - The arguments. - kwargs: - The kwargs. - - Returns - ------- - inputs: - The msc format inputs. - """ - - raise NotImplementedError("_to_msc_inputs is not implemented in " + str(self.__class__)) - - def _from_msc_outputs(self, runner_name: str, outputs: List[Tuple[str, Any]]) -> Any: - """Change inputs from msc format - - Parameters - ---------- - runner_name: str - The runner name. - outputs: list<(str, tensor)> - The msc format outputs. - - Returns - ------- - outputs: - The framework outputs. - """ - - raise NotImplementedError("_from_msc_outputs is not implemented in " + str(self.__class__)) - - def _run_ctx(self, runner_ctx: dict, inputs: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]: - """Forward by runner context - - Parameters - ---------- - runner_ctx: dict - The runner context - inputs: list<(str, tensor)> - The inputs. - - Returns - ------- - outputs: list<(str, tensor)> - The outputs. - """ - - raise NotImplementedError("_run_ctx is not implemented in " + str(self.__class__)) - - def get_runner_ctx(self, runner_name: str) -> dict: - """Get the runner context - - Parameters - ---------- - runner_name: str - The runner name - - Returns - ------- - runner_cts: dict - The runner context. - """ - - assert runner_name in self._runner_ctxs, "Can not finc runner_context " + str(runner_name) - return self._runner_ctxs[runner_name] - - def train(self): - """Change status to train""" - - if not self._training: - self._training = True - for runner_ctx in self._runner_ctxs.values(): - if "runner" in runner_ctx: - runner_ctx["runner"].train() - - def eval(self): - """Change status to eval""" - - if self._training: - self._training, self._trained = False, True - for runner_ctx in self._runner_ctxs.values(): - if "runner" in runner_ctx: - runner_ctx["runner"].eval() - - def jit_mark(self, msg: str): - """Mark the message with jit info - - Parameters - ------- - msg: str - The message - - Returns - ------- - msg: str - The message with mark. - """ - - return f"JIT({self.framework}) {msg}" - - @property - def trained(self): - return self._trained - - @property - def jit_model(self): - return self._jit_model - - @property - def framework(self): - return MSCFramework.MSC - - @classmethod - def support_device(cls, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. - """ - - return True diff --git a/python/tvm/contrib/msc/core/runtime/runner.py b/python/tvm/contrib/msc/core/runtime/runner.py deleted file mode 100644 index a7b4ae29edd6..000000000000 --- a/python/tvm/contrib/msc/core/runtime/runner.py +++ /dev/null @@ -1,1591 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.core.runtime.runner""" - -import json -import logging -import os -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - -import numpy as np - -import tvm -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.codegen import to_relax -from tvm.contrib.msc.core.frontend import from_relax -from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.core.tools import BaseTool, ToolScope, ToolType, create_tool, remove_tools -from tvm.contrib.msc.core.utils.message import MSCStage -from tvm.contrib.msc.core.utils.namespace import MSCFramework - -from .hook import load_runner_hook - - -class BaseRunner: - """Basic runner of MSC - - Parameters - ---------- - mod: IRModule - The IRModule of relax. - params: dict of - The parameters of the IRModule. - tools_config: list - The config of MSC Tools. - translate_config: dict - The config for translate IRModule to MSCGraph. - codegen_config: dict - The config for build MSCGraph to runnable model. - build_config: dict - The config for build runnable. - device: str - The device to build runnable. - training: bool - Whether compile model to trainable. - stage: str - The stage of runner. - plugin: PluginManager - The plugin manager. - name: str - The name of the runner - debug_level: int - The debug level. - logger: logging.Logger - The logger - """ - - def __init__( - self, - mod: tvm.IRModule, - tools_config: Optional[List[dict]] = None, - translate_config: Optional[Dict[str, str]] = None, - generate_config: Optional[Dict[str, str]] = None, - build_config: Optional[Dict[str, str]] = None, - device: str = "cpu", - training: bool = False, - stage: str = "default", - plugin: Any = None, - name: str = "main", - debug_level: int = 0, - logger: Optional[logging.Logger] = None, - ): - self._mod = mod - if tools_config: - self._tools_type = [t["tool_type"] for t in tools_config] - self._tools_config = { - t["tool_type"]: msc_utils.copy_dict(t["tool_config"]) for t in tools_config - } - else: - self._tools_type, self._tools_config = [], {} - self._translate_config = msc_utils.copy_dict(translate_config) - self._generate_config = msc_utils.copy_dict(generate_config) - self._build_config = msc_utils.copy_dict(build_config) - self._device = device if self.support_device(device) else "cpu" - self._stage = stage - self._plugin = plugin - self._name = name - self._debug_level = debug_level - self._training, self._trained = training, training - self._logger = logger or msc_utils.get_global_logger() - self._logger.info(msc_utils.msg_block(self.runner_mark("SETUP"), self.setup())) - self._tools = self.setup_tools() - - def setup(self) -> dict: - """Setup the runner - - Returns - ------- - info: dict - The setup info. - """ - - if "build_folder" not in self._generate_config: - self._generate_config["build_folder"] = msc_utils.get_build_dir() - self._graphs, self._weights = [], {} - self._model, self._model_info = None, {} - self._runnable = None - if self._plugin: - self._update_codegen({"use_plugin": True}) - return { - "tools": {k: v.get("tool_style", "default") for k, v in self._tools_config.items()}, - "plugin": self._plugin, - "translate_config": self._translate_config, - "generate_config": self._generate_config, - "build_config": self._build_config, - "device": self._device, - "name": self._name, - "debug_level": self._debug_level, - } - - def setup_tools(self) -> Dict[str, BaseTool]: - """Setup tools - - Returns - ------- - tools: dict - The tools. - """ - - tools = {} - if self._tools_type: - self._update_codegen({"use_tools": True, "tools_tag": self._name}) - for t_type in self._tools_type: - tools[t_type] = create_tool( - self.framework, - t_type, - self._name, - training=self._training, - stage=self._stage, - **self._tools_config[t_type], - ) - return tools - - def change_stage(self, stage: str): - """Change the stage of runner and tools""" - - self._stage = stage - for tool in self._tools.values(): - tool.change_stage(stage) - - def change_logger(self, logger: logging.Logger): - """Change the logger of runner and tools""" - - self._logger = logger - for tool in self._tools.values(): - tool.change_logger(logger) - - def build( - self, - cache_dir: msc_utils.MSCDirectory = None, - force_build: bool = False, - disable_tools: Optional[List[str]] = None, - ) -> Any: - """Build the runnable object - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - force_build: bool - Whether to force build the runner. - disable_tools: list - The tool types to be disabled. - - Returns - ------- - runnable: Any - The runnable object. - """ - - if force_build: - self._graphs, self._weights = [], {} - self._model, self._model_info = None, {} - self._runnable = None - if cache_dir and os.path.isfile(cache_dir.relpath("cache_info.json")): - cache_info = msc_utils.load_dict(cache_dir.relpath("cache_info.json")) - else: - cache_info = {} - - # set tools to reset - if disable_tools: - tools = [t for t in self.get_tools() if t.tool_type not in disable_tools] - else: - tools = None - - build_msg = "" - # Load graphs from cache - if not self._graphs and cache_info.get("graphs"): - self._graphs = self._load_graphs(cache_dir, cache_info["graphs"]) - assert "weights" in cache_info, "Missing weights in cache_info" - with open(cache_dir.relpath(cache_info["weights"]), "rb") as f: - self._weights = tvm.runtime.load_param_dict(f.read()) - build_msg += "Load " - - # Translate graphs from module - if not self._graphs: - self._graphs, self._weights = self.translate() - build_msg += "Translate " - build_msg += f"{len(self._graphs)} graphs {len(self._weights)} weights -> " - - # Load model from cache - if not self._model and cache_info.get("model"): - self._graphs, self._weights = self.reset_tools(tools=tools, cache_dir=cache_dir) - self._model = self._load_model(cache_dir, cache_info["model"]) - build_msg += "Load " - - # Generate model - if not self._model: - distiller = self.get_tool(ToolType.DISTILLER) - if distiller and not distiller.distilled: - build_root = self._generate_config["build_folder"] - - def _build_scope_model(scope: str, apply_hooks: bool): - self._update_codegen({"tools_scope": scope}) - self._generate_config["build_folder"] = build_root.create_dir(scope) - return self.generate_model(apply_hooks=apply_hooks) - - # Generate distill model - teacher_model = _build_scope_model(ToolScope.TEACHER, False) - self._graphs, self._weights = self.reset_tools(tools=tools, cache_dir=cache_dir) - student_model = _build_scope_model(ToolScope.STUDENT, True) - self._model = distiller.build_model(teacher_model, student_model) - else: - # Generate normal model - self._graphs, self._weights = self.reset_tools(tools=tools, cache_dir=cache_dir) - self._model = self.generate_model() - build_msg += "Generate " - - # Add tool message - if self._tools: - build_msg += "model with tools " + str(",".join(self._tools.keys())) + " -> " - else: - build_msg += "model without tools -> " - - # Inspect model - self._model_info = self._inspect_model() - if self._debug_level >= 2: - self._logger.debug( - msc_utils.msg_block(self.runner_mark("MODEL_INFO"), self._model_info) - ) - - # Load runnable from cache - if not self._runnable and cache_info.get("runnable"): - self._runnable = self._load_runnable(cache_dir, cache_info["runnable"]) - build_msg += "Load " - - # Build runnable - if not self._runnable: - self._runnable = self.build_runnable() - build_msg += "Build " - build_msg += "runnable({}, {}) on {}".format( - self.framework, "train" if self._training else "eval", self._device - ) - self._logger.info(self.runner_mark(build_msg)) - return self._runnable - - def run( - self, inputs: Union[List[np.ndarray], Dict[str, np.ndarray]], ret_type="dict" - ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]: - """Run the model to get outputs - - Parameters - ------- - inputs: list or dict - The inputs in list or dict. - ret_type: str - The return type list| dict - - Returns - ------- - outputs: dict - The outputs in dict. - """ - - in_names = [i["name"] for i in self.get_inputs()] - inputs = msc_utils.format_datas(inputs, in_names, style="dict") - outputs = self._call_runnable(self._runnable, inputs, self._device) - if ret_type == "native": - return outputs - out_names = [o["name"] for o in self.get_outputs()] - return msc_utils.format_datas(outputs, out_names, style=ret_type) - - def save_cache( - self, - cache_dir: msc_utils.MSCDirectory, - save_model: bool = True, - save_runnable: bool = True, - save_tools: bool = True, - ): - """Save runner to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - save_model: bool - Whether to save model. - save_runnable: bool - Whether to save runnable. - save_tools: bool - Whether to save tools. - """ - - cache_info = {"graphs": self._save_graphs(cache_dir), "weights": "graph_weights.bin"} - with cache_dir: - with open(cache_info["weights"], "wb") as f_params: - f_params.write(tvm.runtime.save_param_dict(self._weights)) - if save_model and cache_info.get("graphs"): - cache_info["model"] = self._save_model(cache_dir) - if save_runnable and cache_info.get("model"): - cache_info["runnable"] = self._save_runnable(cache_dir) - if save_tools: - for t_type, tool in self._tools.items(): - cache_info[t_type] = tool.save_cache(cache_dir) - with open(cache_dir.relpath("cache_info.json"), "w") as f: - f.write(json.dumps(cache_info, indent=2)) - title = self.runner_mark("SAVE_CACHE") - self._logger.debug(msc_utils.msg_block(title, {"folder": cache_dir, "info": cache_info})) - - def translate( - self, apply_hooks: bool = True - ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Translate IRModule to MSCgraphs - - Parameters - ------- - apply_hooks: bool - Whether to apply hooks. - - Returns - ------- - graphs: list - The translated graphs - weights: dict - The translated weights. - """ - - mod = self._mod - if apply_hooks: - for hook in self._translate_config.get("pre_hooks", []): - mod = self._apply_hook("before translate", hook, mod) - graphs, weights = self._translate(mod) - if apply_hooks: - for hook in self._translate_config.get("post_hooks", []): - graphs, weights = self._apply_hook("after translate", hook, graphs, weights) - return graphs, weights - - def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Translate IRModule to MSCgraphs - - Parameters - ------- - mod: tvm.IRModule - The module to be translated. - - Returns - ------- - graphs: list - The translated graphs - weights: dict - The translated weights. - """ - - raise NotImplementedError("_translate is not implemented for " + str(self.__class__)) - - def reset_tools( - self, - graphs: Optional[List[MSCGraph]] = None, - weights: Optional[List[Dict[str, tvm.runtime.Tensor]]] = None, - tools: Optional[List[BaseTool]] = None, - cache_dir: msc_utils.MSCDirectory = None, - ): - """Reset the tools - - Parameters - ------- - graphs: list - The msc graphs. - weights: list> - The weights. - tools: list - The tools. - cache_dir: MSCDirectory - cache path for save/load info. - - Returns - ------- - graphs: list - The msc graphs. - weights: list> - The weights. - """ - - graphs = graphs or self._graphs - weights = weights or self._weights - if tools is None: - tools = list(self.get_tools()) - for tool in tools: - graphs, weights = tool.reset(graphs, weights, cache_dir) - return graphs, weights - - def generate_model(self, apply_hooks: bool = True) -> Any: - """Codegen the model according to framework - - Parameters - ------- - apply_hooks: bool - Whether to apply hooks. - - Returns - ------- - model: Any - The meta model - """ - - graphs, weights = self._graphs, self._weights - if apply_hooks: - for hook in self._generate_config.get("pre_hooks", []): - graphs, weights = self._apply_hook("before generate", hook, graphs, weights) - model = self._generate_model(graphs, weights) - if apply_hooks: - for hook in self._generate_config.get("post_hooks", []): - model = self._apply_hook("after generate", hook, model) - return model - - def _generate_model( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] - ) -> Any: - """Codegen the model according to framework - - Parameters - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - - Returns - ------- - model: Any - The meta model - """ - - raise NotImplementedError("_load is not implemented for " + str(self.__class__)) - - def build_runnable(self, apply_hooks: bool = True) -> Any: - """Build runnable object - - Parameters - ------- - apply_hooks: bool - Whether to apply hooks. - - Returns - ------- - runnable: Any - The runnable - """ - - model = self._model - if apply_hooks: - for hook in self._build_config.get("pre_hooks", []): - model = self._apply_hook("before build", hook, model) - runnable = self._build_runnable(model) - if apply_hooks: - for hook in self._build_config.get("post_hooks", []): - runnable = self._apply_hook("after build", hook, runnable) - return runnable - - def _build_runnable(self, model: Any) -> Any: - """Build runnable object - - Parameters - ------- - model: Any - The meta model. - - Returns - ------- - runnable: Any - The runnable - """ - - raise NotImplementedError("_build_runnable is not implemented for " + str(self.__class__)) - - def export_module(self, folder: msc_utils.MSCDirectory) -> tvm.IRModule: - """Export the module from graphs - - Parameters - ---------- - folder: MSCDirectory - The export folder. - - Returns - ------- - module: IRModule - The exported module - """ - - raise NotImplementedError("export_module is not implemented for " + str(self.__class__)) - - def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: - """Export the runnable - - Parameters - ------- - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The runnable info. - """ - - raise NotImplementedError("export_runnable is not implemented for " + str(self.__class__)) - - def export_graphs(self, folder: msc_utils.MSCDirectory) -> dict: - """Export the graphs - - Parameters - ------- - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The graphs info. - """ - - raise NotImplementedError("export_graphs is not implemented for " + str(self.__class__)) - - def train(self): - """Change status to train""" - - if not self._training: - self._training = True - for tool in self.get_tools(): - tool.train() - self._train() - - def _train(self): - """Change status to train""" - - self._runnable = self.build_runnable() - - def eval(self): - """Change status to eval""" - - if self._training: - self._training, self._trained = False, True - for tool in self.get_tools(): - tool.eval() - self._eval() - - def _eval(self): - """Change status to eval""" - - self._runnable = self.build_runnable() - - def get_tool_config(self, tool_type: str) -> dict: - """Get tool by type - - Parameters - ------- - tool_type: str - The type of the tool prune| quantize| distill... - - Returns - ------- - config: dict - The tool config. - """ - - return self._tools_config.get(tool_type) - - def get_tool(self, tool_type: str) -> BaseTool: - """Get tool by type - - Parameters - ------- - tool_type: str - The type of the tool prune| quantize| distill... - - Returns - ------- - tool: BaseTool - The saved tool. - """ - - return self._tools.get(tool_type) - - def get_tools(self) -> Iterable[BaseTool]: - """Get all saved tools by tag - - Returns - ------- - tools: iterable - The saved tools. - """ - - for t_type in ToolType.all_types(): - tool = self.get_tool(t_type) - if tool: - yield tool - - def make_plan(self, tool_type: str, data_loader: Any = None) -> str: - """Execute tool and get plan - - Parameters - ------- - tool_type: str - The tool type, should be in ToolType - data_loader: - The data loader. - - Returns - ------- - plan_file: str - The saved plan file. - """ - - def _finalize_tool( - checker: callable, - post_batch: Optional[callable] = None, - post_iter: Optional[callable] = None, - ): - tool = self.get_tool(tool_type) - while not checker(tool): - assert data_loader, "data_loader should be given to make plan for " + tool_type - for inputs in data_loader(): - outputs = self.run(inputs, ret_type="native") - if post_batch: - post_batch(tool, outputs) - if checker(tool): - break - if post_iter: - post_iter(tool) - return tool.finalize() - - assert tool_type in self._tools, "Can not find tool " + str(tool_type) - if tool_type == ToolType.PRUNER: - plan = _finalize_tool(lambda t: t.pruned) - elif tool_type == ToolType.QUANTIZER: - plan = _finalize_tool(lambda t: t.calibrated, post_iter=lambda t: t.calibrate()) - elif tool_type == ToolType.DISTILLER: - plan = _finalize_tool( - lambda t: t.distilled, - post_batch=lambda t, outputs: t.learn(outputs), - post_iter=lambda t: t.distill(), - ) - elif tool_type == ToolType.TRACKER: - plan = _finalize_tool(lambda t: t.tracked) - else: - plan = self.get_tool(tool_type).finalize() - self._logger.debug("Made %d plan for %s", len(plan), tool_type) - plan_file = self._tools_config[tool_type]["plan_file"] - if plan: - with open(plan_file, "w") as f: - f.write(json.dumps(plan, indent=2)) - return plan_file - - def _apply_hook(self, desc: str, hook_def: dict, *args, **kwargs) -> Any: - """Load a registered hook - - Parameters - ---------- - desc: str - The description of the hook - hook_def: dict - The function and config of the hook. - args: list - The arguments for run method. - kwargs: dict - The key word arguments for run method. - - Returns - ------- - result: - The result - """ - - hook = load_runner_hook(hook_def) - self._logger.info("Apply %s hook:\n %s", desc, hook) - return hook.apply(self, *args, **kwargs) - - def _update_codegen(self, config: Dict[str, Any]): - """Update the codegen in generate_config - - Parameters - ------- - config: dict - The extra config for codegen. - """ - - if "codegen" not in self._generate_config: - self._generate_config["codegen"] = {} - codegen = self._generate_config["codegen"] - if isinstance(codegen, dict): - codegen.update(config) - elif isinstance(codegen, (list, tuple)): - for c in codegen: - c.update(config) - else: - raise TypeError("Unexpecet codegen config " + str(codegen)) - - def visualize(self, visual_dir: msc_utils.MSCDirectory, export_graph: bool = False): - """Visualize MSCGraphs - - Parameters - ------- - visual_dir: MSCDirectory - Visualize path for saving graph - export_graph: bool - Whether to export the graph - """ - - for graph in self._graphs: - graph.visualize(visual_dir.relpath(graph.name + ".prototxt")) - if export_graph: - with open(visual_dir.relpath(graph.name + "_graph.json"), "w") as f_graph: - f_graph.write(graph.to_json()) - for tool in self._tools.values(): - tool.visualize(visual_dir) - - def get_inputs(self) -> List[Dict[str, str]]: - """Get the inputs of the model - - Returns - ------- - inputs: list - The inputs info. - """ - - return self._model_info["inputs"] - - def get_outputs(self) -> List[Dict[str, str]]: - """Get the outputs of the model - - Returns - ------- - outputs: list - The outputs info. - """ - - return self._model_info["outputs"] - - def get_weights( - self, framework: Optional[str] = None, device: Optional[str] = None - ) -> Iterable[tvm.runtime.Tensor]: - """Get the weights from graphs - - Parameters - ------- - framework: str - The framework for weight. - device: str - The device for weight. - - Returns - ------- - weights: generator - The generator of weight datas. - """ - - device = device or self._device - for graph in self._graphs: - for weight in graph.get_weights(): - data = self._weights[weight.name] - if framework: - data = msc_utils.cast_array(data, framework, device) - yield data - - def get_runtime_params(self) -> Dict[str, tvm.runtime.Tensor]: - """Get the runtime parameters - - Returns - ------- - params: dict - The parameters from runtime. - """ - - return self._get_runtime_params() - - def _get_runtime_params(self) -> Dict[str, tvm.runtime.Tensor]: - """Get the runtime parameters - - Returns - ------- - params: dict - The parameters from runtime. - """ - - raise NotImplementedError( - "_get_runtime_params is not implemented for " + str(self.__class__) - ) - - def destory(self): - """Destory runner""" - - if self._model: - self._model = None - if self._runnable: - self._runnable = None - for tool in self.get_tools(): - tool.destory() - remove_tools(self._name) - - def _load_graphs(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) -> List[MSCGraph]: - """Load MSCGraphs from cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - cache_info: dict - The cache info. - - Returns - ------- - graphs: list - The translated graphs - """ - - raise NotImplementedError("_load_graphs is not implemented for " + str(self.__class__)) - - def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: - """Save MSCgraphs to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - - Returns - ------- - cache_info: dict - The cache info. - """ - - raise NotImplementedError("_save_graphs is not implemented for " + str(self.__class__)) - - def _load_model(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) -> Any: - """Load the model from cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - cache_info: dict - The cache info. - - Returns - ------- - model: Any - The meta model - """ - - raise NotImplementedError("_load_model is not implemented for " + str(self.__class__)) - - def _save_model(self, cache_dir: msc_utils.MSCDirectory) -> dict: - """Save model to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - - Returns - ------- - cache_info: dict - The cache info. - """ - - # disable save model by default - return {} - - def _load_runnable(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) -> Any: - """Load the runnable from cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - cache_info: dict - The cache info. - - Returns - ------- - runnable: Any - The runnable - """ - - raise NotImplementedError("_load_runnable is not implemented for " + str(self.__class__)) - - def _save_runnable(self, cache_dir: msc_utils.MSCDirectory) -> dict: - """Save runnable to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - - Returns - ------- - cache_info: dict - The cache info. - """ - - # disable save runnable by default - return {} - - def _inspect_model(self) -> dict: - """Inspect the model - - Returns - ------- - model_info: dict - The inspected model info - """ - - raise NotImplementedError("_inspect_model is not implemented for " + str(self.__class__)) - - def _call_runnable( - self, runnable: Any, inputs: Dict[str, np.ndarray], device: str - ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]: - """Call the runnable to get outputs - - Parameters - ------- - model: - The runnable model. - inputs: dict - The inputs in dict. - device: str - The device. - - Returns - ------- - outputs: list or dict - The outputs in list or dict. - """ - - raise NotImplementedError("_call_runnable is not implemented for " + str(self.__class__)) - - def runner_mark(self, msg: Any) -> str: - """Mark the message with runner info - - Parameters - ------- - msg: str - The message - - Returns - ------- - msg: str - The message with mark. - """ - - return f"RUNNER[{self._name}]({self.framework} @ {self._stage}) {msg}" - - @property - def stage(self): - return self._stage - - @property - def debug_level(self): - return self._debug_level - - @property - def trained(self): - return self._trained - - @property - def model(self): - return self._model - - @property - def runnable(self): - return self._runnable - - @property - def model_info(self): - return self._model_info - - @property - def device(self): - return self._device - - @property - def codegen_func(self): - raise NotImplementedError("codegen_func is not implemented for " + str(self.__class__)) - - @property - def framework(self): - return MSCFramework.MSC - - @classmethod - def load_native(cls, model: Any, config: dict) -> Tuple[Any, str, bool]: - """Load the native model - - Parameters - ------- - model: - The native model. - config: dict - The config for pipeline. - - Returns - ------- - model: - The loaded native model. - device: str - The device of the model. - training: - Whether the model is for training. - """ - - return model, "cpu", False - - @classmethod - def run_native( - cls, - model: Any, - inputs: Dict[str, np.ndarray], - input_names: List[str], - output_names: List[str], - warm_up: int = 10, - repeat: int = 0, - ) -> Tuple[Dict[str, np.ndarray], float]: - """Run the datas and get outputs - - Parameters - ------- - model: - The nativate model. - inputs: dict - The inputs in dict. - input_names: list - The input names. - output_names: list - The outut names. - warm_up: int - The warm_up num for profile. - repeat: int - The repeat num for profile. - - Returns - ------- - outputs: dict - The outputs in dict. - avg_time: float - The average time. - """ - - raise NotImplementedError("run_native is not implemented for " + str(cls)) - - @classmethod - def dump_nativate( - cls, model: Any, folder: msc_utils.MSCDirectory, dump_config: Optional[dict] = None - ) -> str: - """Dump the nativate model - - Parameters - ------- - model: - The native model. - folder: MSCDirectory - The export folder. - dump_config: dict - The dump config. - - Returns - ------- - export_path: str - The exported path - """ - - raise NotImplementedError("dump_nativate is not implemented for " + str(cls)) - - @classmethod - def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: - """Update the config for parse - - Parameters - ------- - stage: str - The stage to be updated - config: dict - The config for pipeline. - model: - The native model. - - Returns - ------- - config: dict - The updated config. - """ - - if stage not in config: - return config - if stage in (MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE): - run_config = config[stage].get("run_config", {}) - if "translate_config" not in run_config: - run_config["translate_config"] = {} - if "build" not in run_config["translate_config"]: - run_config["translate_config"]["build"] = {} - if "generate_config" not in run_config: - run_config["generate_config"] = {} - run_config["translate_config"]["build"]["input_aliases"] = [ - i[0] for i in config["inputs"] - ] - run_config["translate_config"]["build"]["output_aliases"] = config["outputs"] - config[stage]["run_config"] = run_config - return config - - @classmethod - def support_device(cls, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. - """ - - return True - - -class ModelRunner(BaseRunner): - """Model runner of MSC""" - - def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Translate IRModule to MSCgraphs - - Parameters - ------- - mod: tvm.IRModule - The module to be translated. - - Returns - ------- - graphs: list - The translated graphs - weights: dict - The translated weights. - """ - - graph, weights = from_relax( - mod, - trans_config=self._translate_config.get("transform"), - build_config=self._translate_config.get("build"), - opt_config=self._translate_config.get("optimize"), - ) - return [graph], weights - - def _load_graphs(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) -> List[MSCGraph]: - """Load MSCGraphs from cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - cache_info: dict - The cache info. - - Returns - ------- - graphs: list - The translated graphs - """ - - assert "main" in cache_info, "main should be given in cache_info, get " + str(cache_info) - graph = MSCGraph.from_json(cache_dir.relpath(cache_info["main"]["graph"])) - return [graph] - - def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: - """Save MSCgraphs to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - - Returns - ------- - cache_info: dict - The cache info. - """ - - main_info = {"graph": self._graphs[0].name + "_graph.json"} - with cache_dir: - with open(main_info["graph"], "w") as f_graph: - f_graph.write(self._graphs[0].to_json()) - return {"main": main_info} - - def _generate_model( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] - ) -> Any: - """Codegen the model according to framework - - Parameters - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - - Returns - ------- - model: Any - The runnable model - """ - - return self.codegen_func( - graphs[0], - weights, - codegen_config=self._generate_config.get("codegen"), - print_config=self._generate_config.get("print"), - build_folder=self._generate_config["build_folder"], - plugin=self._plugin, - ) - - def _inspect_model(self) -> dict: - """Inspect the model - - Returns - ------- - model_info: dict - The inspected model info - """ - - return self._graphs[0].inspect() - - def export_module(self, folder: msc_utils.MSCDirectory) -> tvm.IRModule: - """Export the module from graphs - - Parameters - ---------- - folder: MSCDirectory - The export folder. - - Returns - ------- - module: IRModule - The exported module - """ - - build_folder = folder.create_dir("export_build", keep_history=False, cleanup=True) - module = to_relax( - self._graphs[0], self.get_runtime_params(), build_folder=build_folder, use_alias=False - ) - return module - - def export_graphs(self, folder: msc_utils.MSCDirectory) -> dict: - """Export the graphs - - Parameters - ------- - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The graphs info. - """ - - graphs = {"main": folder.relpath(self._graphs[0].name + "_graph.json")} - with open(graphs["main"], "w") as f_graph: - f_graph.write(self._graphs[0].to_json()) - return graphs - - -class BYOCRunner(BaseRunner): - """BYOC runner of MSC""" - - def setup(self) -> dict: - """Setup the runner - - Returns - ------- - info: dict - The setup info. - """ - - self._byoc_mod, self._byoc_graph = None, None - self._executable = None - return super().setup() - - def visualize(self, visual_dir: msc_utils.MSCDirectory, export_graph: bool = False): - """Visualize MSCGraphs - - Parameters - ------- - visual_dir: MSCDirectory - Visualize path for saving graph - export_graph: bool - Whether to export the graph - """ - - super().visualize(visual_dir) - self._byoc_graph.visualize(visual_dir.relpath(self._byoc_graph.name + ".prototxt")) - if export_graph: - with open(visual_dir.relpath(self._byoc_graph.name + "_graph.json"), "w") as f_graph: - f_graph.write(self._byoc_graph.to_json()) - - def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Translate IRModule to MSCgraphs - - Parameters - ------- - mod: tvm.IRModule - The module to be translated. - - Returns - ------- - graphs: list - The translated graphs - weights: dict - The translated weights. - """ - - self._byoc_mod, graphs, weights = self.partition_func( - mod, - trans_config=self._translate_config.get("transform"), - build_config=self._translate_config.get("build"), - ) - self._byoc_graph = _ffi_api.BuildFromRelax( - self._byoc_mod, "main", msc_utils.dump_dict(self._translate_config.get("build")) - ) - return graphs, weights - - def _load_graphs(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) -> List[MSCGraph]: - """Load MSCgraphs from cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - cache_info: dict - The cache info. - - Returns - ------- - graphs: list - The translated graphs - """ - - assert "byoc_mod" in cache_info, "byoc_mod should be given in cache_info, get " + str( - cache_info - ) - assert "byoc_graph" in cache_info, "byoc_graph should be given in cache_info, get " + str( - cache_info - ) - assert "sub_graphs" in cache_info, "sub_graphs should be given in cache_info, get " + str( - cache_info - ) - with open(cache_dir.relpath(cache_info["byoc_mod"])) as f: - self._byoc_mod = tvm.ir.load_json(f.read()) - graphs = [MSCGraph.from_json(cache_dir.relpath(g)) for g in cache_info["sub_graphs"]] - self._byoc_graph = MSCGraph.from_json(cache_dir.relpath(cache_info["byoc_graph"])) - return graphs - - def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: - """Save MSCgraphs to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - - Returns - ------- - cache_info: dict - The cache info. - """ - - sub_graphs = [g.name + "_graph.json" for g in self._graphs] - with cache_dir: - for graph, g_file in zip(self._graphs, sub_graphs): - with open(g_file, "w") as f_graph: - f_graph.write(graph.to_json()) - with open("byoc_graph.json", "w") as f: - f.write(self._byoc_graph.to_json()) - with open("byoc_module.json", "w") as f: - f.write(tvm.ir.save_json(self._byoc_mod)) - return { - "sub_graphs": sub_graphs, - "byoc_graph": "byoc_graph.json", - "byoc_mod": "byoc_module.json", - } - - def _generate_model( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] - ) -> Any: - """Codegen the model according to framework - - Parameters - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - - Returns - ------- - model: tvm.IRModule - The relax module - """ - - extra_option = self._generate_config.get("extra_option", {}) - extra_option["tool_tag"] = "" if self._stage == MSCStage.COMPILE else self._name - return self.codegen_func( - self._byoc_mod, - graphs, - weights, - codegen_configs=self._generate_config.get("codegen"), - print_configs=self._generate_config.get("print"), - extra_options=extra_option, - build_folder=self._generate_config["build_folder"], - output_folder=self._generate_config.get("output_folder", msc_utils.get_output_dir()), - plugin=self._plugin, - ) - - def _build_runnable(self, model: Any) -> Any: - """Build runnable object - - Parameters - ------- - model: Any - The meta model. - - Returns - ------- - runnable: Any - The runnable - """ - - model = tvm.relax.transform.LegalizeOps()(model) - if self._device == "cpu": - target = tvm.target.Target("llvm") - with tvm.transform.PassContext(opt_level=3): - self._executable = tvm.compile(model, target) - runnable = tvm.relax.VirtualMachine(self._executable, tvm.cpu()) - elif self._device.startswith("cuda"): - target = tvm.target.Target("cuda") - with target: - model = tvm.s_tir.transform.DefaultGPUSchedule()(model) - with tvm.transform.PassContext(opt_level=3): - self._executable = tvm.compile(model, target) - runnable = tvm.relax.VirtualMachine(self._executable, tvm.cuda()) - else: - raise NotImplementedError("Unsupported device " + str(self._device)) - return runnable - - def _call_runnable( - self, runnable: tvm.relax.VirtualMachine, inputs: Dict[str, np.ndarray], device: str - ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]: - """Call the runnable to get outputs - - Parameters - ------- - runnable: tvm.relax.VirtualMachine - The virtual machine. - inputs: dict - The inputs in dict. - device: str - The device. - - Returns - ------- - outputs: list - The outputs in list. - """ - - input_names = [i["name"] for i in self.get_inputs()] - tvm_inputs = [ - msc_utils.cast_array(inputs[i], MSCFramework.TVM, device) for i in input_names - ] - return runnable["main"](*tvm_inputs) - - def _inspect_model(self) -> dict: - """Inspect the model - - Returns - ------- - model_info: dict - The inspected model info - """ - - if self._debug_level >= 2: - sub_graphs = {g.name: g.inspect for g in self._graphs} - title = self.runner_mark(f"SUBGRAPHS({len(sub_graphs)})") - self._logger.debug(msc_utils.msg_block(title, sub_graphs)) - return self._byoc_graph.inspect() - - def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: - """Export the runnable - - Parameters - ------- - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The runnable info. - """ - - export_lib = folder.relpath("lib.so") - self._executable.export_library(export_lib) - return { - "lib": export_lib, - "device": self.device, - "model_type": self.framework, - "abstract": self.model_info, - } - - def export_graphs(self, folder: msc_utils.MSCDirectory) -> dict: - """Export the graphs - - Parameters - ------- - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The graphs info. - """ - - graphs = { - "byoc_graph": folder.relpath(self._byoc_graph.name + "_graph.json"), - "sub_graphs": {g.name: folder.relpath(g.name + "_graph.json") for g in self._graphs}, - } - with open(graphs["byoc_graph"], "w") as f: - f.write(self._byoc_graph.to_json()) - for graph in self._graphs: - with open(graphs["sub_graphs"][graph.name], "w") as f: - f.write(graph.to_json()) - return graphs - - @property - def partition_func(self): - raise NotImplementedError("partition_func is not implemented for " + str(self.__class__)) - - @classmethod - def support_device(cls, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. - """ - - if device == "cpu": - return True - if device.startswith("cuda"): - dev_id = int(device.split(":")[1]) if ":" in device else 0 - return tvm.cuda(dev_id).exist - return False diff --git a/python/tvm/contrib/msc/core/tools/__init__.py b/python/tvm/contrib/msc/core/tools/__init__.py deleted file mode 100644 index eb841a415557..000000000000 --- a/python/tvm/contrib/msc/core/tools/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools""" - -from .tool import * -from .execute import * -from .prune import * -from .quantize import * -from .distill import * -from .track import * diff --git a/python/tvm/contrib/msc/core/tools/configer.py b/python/tvm/contrib/msc/core/tools/configer.py deleted file mode 100644 index a1d45b3145a0..000000000000 --- a/python/tvm/contrib/msc/core/tools/configer.py +++ /dev/null @@ -1,110 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.configer""" - -from typing import Optional, Union - -from tvm.contrib.msc.core import utils as msc_utils - -from .tool import ToolType - - -class ToolConfiger: - """Base configer for tool""" - - def config(self, raw_config: Optional[dict] = None) -> dict: - """Get the config - - Parameters - ---------- - raw_config: dict - The raw config. - - Returns - ------- - config: dict - The update config. - """ - - config = {} - if isinstance(raw_config, dict) and "gym_configs" in raw_config: - config["gym_configs"] = [self.config_gym(g) for g in raw_config.pop("gym_configs")] - if raw_config: - config["tool_config"] = self.update_tool(raw_config) - else: - config["tool_config"] = self.config_tool() - config.update(self.config_apply()) - return config - - def config_tool(self) -> dict: - """Get the default config of tool - - Returns - ------- - config: dict - The default config. - """ - - raise NotImplementedError("config_tool is not implemented in ToolConfiger") - - def update_tool(self, raw_config: dict) -> dict: - """Update tool config from raw_config - - Parameters - ---------- - raw_config: dict - The raw config. - - Returns - ------- - config: dict - The update config. - """ - - config = self.config_tool() - return msc_utils.update_dict(config, raw_config) - - def config_gym(self, gym_config: Union[dict, str]) -> dict: - """Config the gym - - Parameters - ---------- - gym_config: dict - The raw config. - - Returns - ------- - gym_config: dict - The update config. - """ - - raise NotImplementedError("config_gym is not implemented in ToolConfiger") - - def config_apply(self) -> dict: - """Get the config for apply - - Returns - ------- - config: dict - The apply config. - """ - - return {} - - @classmethod - def tool_type(cls): - return ToolType.BASE diff --git a/python/tvm/contrib/msc/core/tools/distill/__init__.py b/python/tvm/contrib/msc/core/tools/distill/__init__.py deleted file mode 100644 index 5e29fddfa7b6..000000000000 --- a/python/tvm/contrib/msc/core/tools/distill/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.distill""" - -from .distiller import * -from .method import * -from .configer import * diff --git a/python/tvm/contrib/msc/core/tools/distill/configer.py b/python/tvm/contrib/msc/core/tools/distill/configer.py deleted file mode 100644 index 782a38030804..000000000000 --- a/python/tvm/contrib/msc/core/tools/distill/configer.py +++ /dev/null @@ -1,57 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.distill.configer""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.configer import ToolConfiger -from tvm.contrib.msc.core.tools.tool import ToolType - - -class DistillConfiger(ToolConfiger): - """Configer for distill""" - - @classmethod - def tool_type(cls): - return ToolType.DISTILLER - - -@msc_utils.register_tool_configer -class DefaultDistillConfiger(DistillConfiger): - """Default configer for distill""" - - def config_tool(self) -> dict: - """Get the default config of tool - - Returns - ------- - config: dict - The default config. - """ - - return { - "plan_file": "msc_distiller.json", - "strategys": [ - { - "methods": {"mark": "loss_lp_norm"}, - "marks": ["loss"], - }, - ], - } - - @classmethod - def config_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/distill/distiller.py b/python/tvm/contrib/msc/core/tools/distill/distiller.py deleted file mode 100644 index 7fb6fbd398ca..000000000000 --- a/python/tvm/contrib/msc/core/tools/distill/distiller.py +++ /dev/null @@ -1,266 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.distill.distiller""" - -import os -from typing import Any, Dict, List, Tuple - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.core.tools.tool import BaseTool, ToolStrategy, ToolType - - -class BaseDistiller(BaseTool): - """Base distiller for all""" - - def setup(self) -> dict: - """Setup the tool - - Returns - ------- - info: dict - The setup info. - """ - - self._max_iter = self._options.get("max_iter", 1) - self._save_step = self._options.get("save_step", 50) - if "weights_folder" in self._options: - self._weights_folder = msc_utils.msc_dir(self._options["weights_folder"]) - else: - self._weights_folder = msc_utils.get_weights_dir().create_dir("Distill") - self._weights_path = self._weights_folder.relpath(f"distill_{self._max_iter}.bin") - self._distilled = os.path.isfile(self._weights_path) - return super().setup() - - def _reset( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] - ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Reset the tool - - Parameters - ---------- - graphs: list - The msc graphs. - weights: dict - The weights. - - Returns - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - """ - - self._current_iter, self._total_loss = 0, 0 - if self._distilled: - with open(self._weights_path, "rb") as f: - distilled_weights = tvm.runtime.load_param_dict(f.read()) - weights.update({k: v for k, v in distilled_weights.items() if k in weights}) - msg = f"Update {len(distilled_weights)} distilled weights" - self._logger.info(self.tool_mark(msg)) - return super()._reset(graphs, weights) - - def build_model(self, teacher: Any, student: Any) -> Any: - """Build the model with teacher and student - - Parameters - ---------- - teacher: Any - The teacher model - student: Any - The student model - - Returns - ------- - model: Any - The built model. - """ - - raise NotImplementedError("build_model is not implemented in BaseDistiller") - - def learn(self, loss: Any): - """Learn after forward - - Parameters - ---------- - loss: Any - The loss after forward - """ - - if self.on_debug(3, in_forward=False): - msg = f"Start learn[{self._current_iter}]" - self._logger.debug(self.tool_mark(msg)) - self._total_loss += float(self._learn(loss)) - - def _learn(self, loss: Any): - """Learn after forward - - Parameters - ---------- - loss: Any - The loss after forward - """ - - raise NotImplementedError("_learn is not implemented in BaseDistiller") - - def distill(self) -> Dict[str, Any]: - """Distill the knowledge - - Returns - ------- - weights: dict - The distilled weights. - """ - - weights = self._distill() - if self._current_iter >= self._max_iter or ( - self._current_iter > 0 and self._current_iter % self._save_step == 0 - ): - self._save_weights(weights) - if self._current_iter >= self._max_iter: - self._distilled = True - self._plan = {n: msc_utils.inspect_array(d, False) for n, d in weights.items()} - msg = f"Distill[{self._current_iter}] loss({self._forward_cnt} batch) {self._total_loss}" - self._logger.info(self.tool_mark(msg)) - self._current_iter += 1 - self._total_loss, self._forward_cnt = 0, 0 - return weights - - def _distill(self) -> Dict[str, Any]: - """Distill the knowledge - - Returns - ------- - weights: dict - The distilled weights. - """ - - raise NotImplementedError("_distill is not implemented in BaseDistiller") - - def _save_weights(self, weights: Dict[str, Any]): - """Save the distilled weights - - Parameters - ---------- - weights: dict - The distilled weights. - """ - - weights = {n: tvm.runtime.tensor(msc_utils.cast_array(d)) for n, d in weights.items()} - weights_path = self._weights_folder.relpath(f"distill_{self._current_iter}.bin") - with open(weights_path, "wb") as f_params: - f_params.write(tvm.runtime.save_param_dict(weights)) - if self._debug_level >= 2: - msg = f"Save weights[{self._current_iter}] to {weights_path}" - self._logger.debug(self.tool_mark(msg)) - - def _support_scope(self, scope: str) -> bool: - """Check if the scope si supported - - Parameters - ------- - scope: str - The scope mark, should be null or ToolScope - - Returns - ------- - vaild: bool - Whether to process the tensor. - """ - - return True - - def _process_tensor( - self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[ToolStrategy] - ) -> Any: - """Process tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - if self._distilled: - return tensor - return self._distill_tensor(tensor, name, consumer, scope, strategys) - - def _distill_tensor( - self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[ToolStrategy] - ) -> Any: - """Process tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - if name not in self._plan: - self._plan[name] = {} - plan = {} - for strategy in strategys: - plan.update(strategy(self, tensor, name, consumer, scope)) - self._plan[name][scope] = plan - return tensor - - @property - def distilled(self): - return self._distilled - - @classmethod - def tool_type(cls): - return ToolType.DISTILLER - - @classmethod - def exportable(cls): - return False - - -@msc_utils.register_tool -class DefaultDistiller(BaseDistiller): - @classmethod - def tool_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/distill/method.py b/python/tvm/contrib/msc/core/tools/distill/method.py deleted file mode 100644 index 697a54577507..000000000000 --- a/python/tvm/contrib/msc/core/tools/distill/method.py +++ /dev/null @@ -1,75 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.core.tools.distill.method""" - -from typing import List - -import numpy as np - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.tool import BaseTool, ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -@msc_utils.register_tool_method -class DistillMethod: - """Default distill method""" - - @classmethod - def loss_lp_norm( - cls, - distiller: BaseTool, - t_outputs: List[np.ndarray], - s_outputs: List[np.ndarray], - power: int = 2, - ): - """Calculate loss with mse - - Parameters - ---------- - distiller: BaseDistiller - The distiller - t_outputs: list - The teacher outputs. - s_outputs: list - The student outputs. - power: int - The power factor. - - Returns - ------- - loss: float - The loss. - """ - - loss = 0 - for t_out, s_out in zip(t_outputs, s_outputs): - loss += np.mean(np.power(np.abs(t_out - s_out), power)) - return loss - - @classmethod - def framework(cls): - return MSCFramework.MSC - - @classmethod - def tool_type(cls): - return ToolType.DISTILLER - - @classmethod - def method_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/execute.py b/python/tvm/contrib/msc/core/tools/execute.py deleted file mode 100644 index 959891e89ca7..000000000000 --- a/python/tvm/contrib/msc/core/tools/execute.py +++ /dev/null @@ -1,402 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.execute""" - -from functools import wraps -from typing import Any, Dict, Iterable, List - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.utils.namespace import MSCKey, MSCMap - -from .tool import BaseTool, ToolType - - -def _get_tool_key(tool_type: str) -> str: - """Get the key according to tool_type - - Parameters - ------- - tool_type: str - The type of the tool prune| quantize| distill... - - Returns - ------- - tool_key: str - The tool key. - """ - - if tool_type == ToolType.PRUNER: - return MSCKey.PRUNERS - if tool_type == ToolType.QUANTIZER: - return MSCKey.QUANTIZERS - if tool_type == ToolType.DISTILLER: - return MSCKey.DISTILLERS - if tool_type == ToolType.TRACKER: - return MSCKey.TRACKERS - raise TypeError("Unexpected tool type " + str(tool_type)) - - -def add_tool(tool: BaseTool, tool_type: str, tag: str = "main"): - """Add tool by type and tag - - Parameters - ------- - tool: BaseTool - The tool. - tool_type: str - The type of the tool prune| quantize| distill... - tag: str - The tag of the tool. - """ - - tool_key = _get_tool_key(tool_type) - tools = MSCMap.get(tool_key, {}) - tools[tag] = tool - MSCMap.set(tool_key, tools) - return tool - - -def get_tool_cls(framework: str, tool_type: str, config: dict) -> BaseTool: - """Get the tool class - - Parameters - ------- - framework: str - The framework for implement - tool_type: str - The type of the tool prune| quantize| distill... - config: dict - The config of tool. - """ - - tool_style = config.pop("tool_style") if "tool_style" in config else "default" - tool_cls = msc_utils.get_registered_tool(framework, tool_type, tool_style) - assert tool_cls, f"Can not find tool class for {tool_type}:{tool_style} @ {framework}" - return tool_cls - - -def create_tool(framework: str, tool_type: str, tag: str = "main", **config) -> BaseTool: - """Create tool by type, config and tag - - Parameters - ------- - framework: str - The framework for implement - tool_type: str - The type of the tool prune| quantize| distill... - tag: str - The tag of the tool. - config: dict - The config of tool. - """ - - tool_cls = get_tool_cls(framework, tool_type, config) - return add_tool(tool_cls(tag, **config), tool_type, tag) - - -def get_tool(tool_type: str, tag: str = "main") -> BaseTool: - """Get tool by type and tag - - Parameters - ------- - tool_type: str - The type of the tool prune| quantize| distill... - tag: str - The tag of the tool. - - Returns - ------- - tool: BaseTool - The saved tool. - """ - - tool_key = _get_tool_key(tool_type) - tools = MSCMap.get(tool_key, {}) - return tools.get(tag) - - -def get_tools(tag: str = "main") -> Iterable[BaseTool]: - """Get all saved tools by tag - - Parameters - ------- - tag: str - The tag of the tool. - - Returns - ------- - tools: iterable - The saved tools. - """ - - for t_type in ToolType.all_types(): - tool = get_tool(t_type, tag) - if tool: - yield tool - - -def remove_tool(tool_type: str, tag: str = "main"): - """Remove tool by type and tag - - Parameters - ------- - tool_type: str - The type of the tool prune| quantize| distill... - tag: str - The tag of the tool. - """ - - tool_key = _get_tool_key(tool_type) - tools = MSCMap.get(tool_key, {}) - if tag in tools: - tools.pop(tag) - MSCMap.set(tool_key, tools) - - -def remove_tools(tag: str = "main"): - """Remove all saved tools by tag - - Parameters - ------- - tag: str - The tag of the tool. - - Returns - ------- - tools: iterable - The saved tools. - """ - - for t_type in ToolType.all_types(): - remove_tool(t_type, tag) - - -def process_tensor(tensor: Any, name: str, consumer: str, scope: str, tag: str = "main") -> Any: - """Process tensor with tools - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null - tag: str - The tag of the tool. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - for tool in get_tools(tag): - tensor = tool.process_tensor(tensor, name, consumer, scope) - return tensor - - -@tvm.register_global_func("msc_tool.codegen_tensor") -def codegen_tensor( - tensor_ctx: Dict[str, str], name: str, consumer: str, scope: str, tag: str = "main" -) -> List[str]: - """Codegen processed tensor describe with tools - - Parameters - ------- - tensor_ctx: dict - Tensor describe items. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null - tag: str - The tag of the tool. - - Returns - ------- - processed: list - The tensor describe for processed tensor. - """ - - tensor_ctx = {**dict(tensor_ctx), "processed": []} - tensor_ctx = process_tensor(dict(tensor_ctx), name, consumer, scope, tag) - return tensor_ctx["processed"] - - -def wrap_step(step: str, tag: str = "main") -> callable: - """Wrapper for tool execution - - Parameters - ------- - step: str - The step for tool execution build| forward - tag: str - The tag of the tool. - - Returns - ------- - decorate: callable - The decorate. - """ - - def decorate(func): - @wraps(func) - def wrapper(*args, **kwargs): - for tool in get_tools(tag): - if step == "build": - tool.execute_before_build(*args, **kwargs) - elif step == "forward": - tool.execute_before_forward(*args, **kwargs) - else: - raise TypeError("Unexpected step " + str(step)) - output = func(*args, **kwargs) - for tool in get_tools(tag): - if step == "build": - output = tool.execute_after_build(output) - elif step == "forward": - output = tool.execute_after_forward(output) - else: - raise TypeError("Unexpected step " + str(step)) - return output - - return wrapper - - return decorate - - -def execute_step(step: str, *args, **kwargs): - """Execute tools for a step - - Parameters - ------- - step: str - The step for tool execution build| forward - args: list - The arguments for model build. - kwargs: dict - The key word arguments for model build. - """ - - if step in ("before_build", "before_forward"): - output = None - else: - assert len(args) == 1 and not kwargs, ( - f"after step only accept 1 argument, get args {args}, kwargs {kwargs}" - ) - output = args[0] - tag = kwargs.pop("tag") if "tag" in kwargs else "main" - for tool in get_tools(tag): - if step == "before_build": - tool.execute_before_build(*args, **kwargs) - elif step == "before_forward": - tool.execute_before_forward(*args, **kwargs) - elif step == "after_build": - output = tool.execute_after_build(output) - elif step == "after_forward": - output = tool.execute_after_forward(output) - else: - raise TypeError("Unexpected step " + str(step)) - return output - - -def _execute_step_with_context( - step_ctx: Dict[str, Any], step: str, graph_name: str, tag: str = "main" -) -> Dict[str, Any]: - """Execute step with contect - - Parameters - ------- - step_ctx: dict - The step context. - step: str - The step for tool execution build| forward - graph_name: str - The graph name. - tag: str - The tag of the tool. - - Returns - ------- - step_ctx: dict - The processed step context. - """ - - for tool in get_tools(tag): - if step == "before_build": - tool.execute_before_build(step_ctx, graph_name=graph_name) - elif step == "before_forward": - tool.execute_before_forward(step_ctx, graph_name=graph_name) - elif step == "after_build": - step_ctx = tool.execute_after_build(step_ctx) - elif step == "after_forward": - step_ctx = tool.execute_after_forward(step_ctx) - else: - raise TypeError("Unexpected step " + str(step)) - return step_ctx - - -@tvm.register_global_func("msc_tool.codegen_step") -def codegen_step( - step_ctx: Dict[str, str], step: str, graph_name: str, tag: str = "main" -) -> List[str]: - """Codegen step codes - - Parameters - ------- - step_ctx: dict - The step describe items. - step: str - The step for tool execution build| forward - graph_name: str - The graph name. - tag: str - The tag of the tool. - - Returns - ------- - processed: list - The tensor describe for processed tensor. - """ - - step_ctx = {**dict(step_ctx), "processed": []} - step_ctx = _execute_step_with_context(step_ctx, step, graph_name, tag) - return step_ctx["processed"] - - -@tvm.register_global_func("msc_tool.callback_step") -def callback_step(step_ctx: Dict[str, Any], step: str, graph_name: str = "main", tag: str = "main"): - """Execute tools for a step - - Parameters - ------- - step_ctx: dict - The step context. - step: str - The step for tool execution build| forward - graph_name: str - The graph name. - tag: str - The tag of the tool. - """ - - _execute_step_with_context(step_ctx, step, graph_name, tag) diff --git a/python/tvm/contrib/msc/core/tools/prune/__init__.py b/python/tvm/contrib/msc/core/tools/prune/__init__.py deleted file mode 100644 index 80dbb4a5be86..000000000000 --- a/python/tvm/contrib/msc/core/tools/prune/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.prune""" - -from .pruner import * -from .method import * -from .configer import * diff --git a/python/tvm/contrib/msc/core/tools/prune/configer.py b/python/tvm/contrib/msc/core/tools/prune/configer.py deleted file mode 100644 index 9348f4dfad34..000000000000 --- a/python/tvm/contrib/msc/core/tools/prune/configer.py +++ /dev/null @@ -1,94 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.prune.configer""" - -from typing import Union - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.configer import ToolConfiger -from tvm.contrib.msc.core.tools.tool import ToolType - - -class PruneConfiger(ToolConfiger): - """Configer for prune""" - - def config_gym(self, raw_config: Union[dict, str]) -> dict: - """Config the gym - - Parameters - ---------- - gym_config: dict - The raw config. - - Returns - ------- - gym_config: dict - The update config. - """ - - if isinstance(raw_config, dict): - return raw_config - if raw_config == "default": - return { - "env": { - "executors": { - "action_space": { - "method": "action_prune_density", - "start": 0.2, - "end": 0.8, - "step": 0.1, - } - }, - }, - "agent": {"role_type": "search.grid", "executors": {}}, - } - else: - raise TypeError("Unexpected gym config " + str(raw_config)) - - @classmethod - def tool_type(cls): - return ToolType.PRUNER - - -@msc_utils.register_tool_configer -class DefaultPruneConfiger(PruneConfiger): - """Default configer for prune""" - - def config_tool(self) -> dict: - """Get the default config of tool - - Returns - ------- - config: dict - The default config. - """ - - return { - "plan_file": "msc_pruner.json", - "strategys": [ - { - "methods": { - "weights": {"method_name": "per_channel", "density": 0.8}, - "output": {"method_name": "per_channel", "density": 0.8}, - } - } - ], - } - - @classmethod - def config_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/prune/method.py b/python/tvm/contrib/msc/core/tools/prune/method.py deleted file mode 100644 index 1db764254f49..000000000000 --- a/python/tvm/contrib/msc/core/tools/prune/method.py +++ /dev/null @@ -1,121 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.core.tools.prune.method""" - -from typing import List - -import numpy as np - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.tool import BaseTool, ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -@msc_utils.register_tool_method -class PruneMethod: - """Default prune method""" - - @classmethod - def prune_axis(cls, data: np.ndarray, axis: int, indices: List[int]) -> np.ndarray: - """Delete indices on axis - - Parameters - ---------- - data: np.ndarray - The source data. - axis: int - The axis to prune - indices: list - The indices to be pruned - - Returns - ------- - data: np.ndarray - The pruned data. - """ - - left_datas = [ - d for idx, d in enumerate(np.split(data, data.shape[axis], axis)) if idx in indices - ] - return np.concatenate(left_datas, axis=axis) - - @classmethod - def per_channel( - cls, - pruner: BaseTool, - data: np.ndarray, - name: str, - consumer: str, - in_axis: int, - out_axis: int, - in_indices: List[int], - density: float, - stride: int = 8, - ) -> np.ndarray: - """Prune the data - - Parameters - ---------- - pruner: BasePruner - The pruner - data: np.ndarray - The source data. - name: str - The name of the weight. - consumer: str - The name of the consumer. - in_axis: int - The input axis - out_axis: int - The output axis - in_indices: list - The input indices to be pruned - density: float - The density to prune - stride: int - The prune stride - - Returns - ------- - plan: dict - The plan of the tensor. - """ - - config = {"in_indices": in_indices, "out_indices": []} - if density == 1: - return config - if len(in_indices) > 0: - data = cls.prune_axis(data, in_axis, in_indices) - out_dim = data.shape[out_axis] - left_num = int(((density * out_dim + stride) // stride) * stride) - axis_sum = [np.abs(d).sum() for d in np.split(data, out_dim, out_axis)] - rank = np.argsort(np.array(axis_sum)) - config["out_indices"] = rank[-left_num:].tolist() - return config - - @classmethod - def framework(cls): - return MSCFramework.MSC - - @classmethod - def tool_type(cls): - return ToolType.PRUNER - - @classmethod - def method_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py deleted file mode 100644 index 63c3db381d1a..000000000000 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ /dev/null @@ -1,536 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E501 -"""tvm.contrib.msc.core.tools.prune.pruner""" - -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np - -import tvm -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor, WeightJoint -from tvm.contrib.msc.core.tools.tool import ToolStrategy, ToolType, WeightTool -from tvm.contrib.msc.core.utils.message import MSCStage - -from .method import PruneMethod - - -class BasePruner(WeightTool): - """Base pruner for all""" - - def setup(self) -> dict: - """Setup the tool - - Returns - ------- - info: dict - The setup info. - """ - - if not self._plan: - self.change_stage(MSCStage.PRUNE) - return super().setup() - - def _get_wtypes(self) -> Tuple[Dict[str, List[str]], Dict[str, str]]: - """Get the weight types from options - - Returns - ------- - main_wtypes: dict> - The main weight types. - relation_wtypes: dict - The relation weight types - """ - - if "main_wtypes" in self._options: - main_wtypes = self._options["main_wtypes"] - else: - main_wtypes = { - "constant": ["const"], - "nn.conv2d": ["weight"], - "msc.conv2d_bias": ["weight"], - "msc.linear": ["weight"], - "msc.linear_bias": ["weight"], - } - - if "relation_wtypes" in self._options: - relation_wtypes = self._options["relation_wtypes"] - else: - relation_wtypes = { - "concatenate": "multi_inputs", - "reshape": "reshape", - "add": "passby", - "substract": "passby", - "multiply": "passby", - "divide": "passby", - } - return main_wtypes, relation_wtypes - - def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]: - """Parse the strategy to get valid strategy - - Parameters - ------- - strategy_list: list - The given strategys. - - Returns - ------- - strategys: dict - The parsed strategy. - """ - - if self._stage != MSCStage.PRUNE: - return {} - - def _update_stages(strategy): - if "stages" not in strategy: - strategy["stages"] = [MSCStage.PRUNE] - return strategy - - return super()._parse_strategys([_update_stages(s) for s in strategy_list]) - - def _reset( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] - ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Reset the tool - - Parameters - ---------- - graphs: list - The msc graphs. - weights: dict - The weights. - - Returns - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - """ - - self._unpruned_tensors = {} - self._meta_weights = weights - graphs, weights = super()._reset(graphs, weights) - if self._plan and self._enabled: - return self.prune_graphs(graphs, weights) - return graphs, weights - - def _execute_before_build(self, *args, **kwargs): - """Execute before model build - - Parameters - ---------- - args: list - The arguments for model build. - kwargs: dict - The key word arguments for model build. - """ - - self._unpruned_tensors = {} - super()._execute_before_build(*args, **kwargs) - - def _execute_after_build(self, output: Any) -> Any: - """Execute after model build - - Parameters - ---------- - output: Any - The output reference of the model. - - Returns - ------- - output: Any - The modified output reference. - """ - - assert not self._unpruned_tensors, "Some tensors are not pruned " + str( - self._unpruned_tensors - ) - return super()._execute_after_build(output) - - def _check_tensor(self, name: str, consumer: str) -> bool: - """Check if the tensor should be processed - - Parameters - ------- - name: str - The name of the tensor. - consumer: str - The name of the consumer. - - Returns - ------- - vaild: bool - Whether to process the tensor. - """ - - if not self.has_w_node(name): - return False - strategy = self._get_tensor_strategy(name, consumer) - if not strategy: - return False - return True - - def _process_tensor( - self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[ToolStrategy] - ) -> Any: - """Process tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - if name in self._plan: - return tensor - - self._prune_tensor(name, consumer, strategys) - lazy_pruned = set() - for lazy_name, info in self._unpruned_tensors.items(): - if info["lead_name"] in self._plan: - strategys = self._get_tensor_strategys(lazy_name, info["consumer"]) - self._prune_tensor(lazy_name, info["consumer"], strategys) - t_mark = ".".join([s.get_executor().name for s in strategys]) - self.debug_tensors( - lazy_name, consumer, t_mark, {"lazy": self.find_tensor(lazy_name)} - ) - lazy_pruned.add(lazy_name) - if lazy_pruned: - self._unpruned_tensors = { - k: v for k, v in self._unpruned_tensors.items() if k not in lazy_pruned - } - return tensor - - def _prune_tensor(self, name: str, consumer: str, strategys: List[ToolStrategy]) -> Any: - """Prune tensor - - Parameters - ------- - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null. - strategys: list - The strategys for the tensor. - """ - - assert len(strategys) == 1, "pruner should only has 1 strategy, get " + str(strategys) - strategy = strategys[0] - - def _get_in_indices(w_node: WeightJoint) -> List[int]: - """Get input indices for weight node""" - if not w_node.parents: - return [] - if w_node.name in self._plan and "in_indices" in self._plan[w_node.name]: - return self._plan[w_node.name]["in_indices"] - assert all(p.name in self._plan for p in w_node.parents), ( - "Missing some parents in runtime config " + str(w_node) - ) - if len(w_node.parents) == 1: - return self._plan[w_node.parents[0].name]["out_indices"] - if w_node.parents[0].friends: - return self._plan[w_node.parents[0].friends[0].name]["out_indices"] - raise Exception("Unexpected w_node " + str(w_node)) - - def _prunable(w_node: WeightJoint) -> bool: - """Check if weight node is prunable""" - if strategy.get_config().get("density", 1) == 1: - return False - if w_node.get_attr("weight_strategy") != "main": - return False - if not w_node.children: - return False - childrens = list(w_node.children) - while childrens: - current = childrens.pop(0) - weight_strategy = current.get_attr("weight_strategy") - if weight_strategy == "main": - return True - childrens.extend(list(current.children)) - return False - - w_node = self.find_w_node(name) - in_axis, out_axis = self._get_io_axes(w_node) - if w_node.weight.dim_at(in_axis) == 1: - in_indices = [] - else: - in_indices = _get_in_indices(w_node) - self._plan[w_node.name] = {"in_indices": in_indices} - if w_node.friends and w_node != w_node.friends[0]: - lead_name = w_node.friends[0].name - if lead_name not in self._plan: - self._unpruned_tensors[name] = { - "lead_name": lead_name, - "consumer": consumer, - } - self._plan.pop(w_node.name) - return None - self._plan[w_node.name]["out_indices"] = self._plan[lead_name]["out_indices"] - elif _prunable(w_node): - self._plan[w_node.name] = strategy( - self, - self.get_meta_data(w_node.name), - w_node.name, - consumer, - in_axis=in_axis, - out_axis=out_axis, - in_indices=in_indices, - ) - elif w_node.get_attr("weight_strategy") == "follow": - self._plan[w_node.name]["out_indices"] = [] - elif w_node.get_attr("weight_strategy") == "passby": - self._plan[w_node.name]["out_indices"] = in_indices - else: - self._plan[w_node.name]["out_indices"] = [] - - def prune_graphs( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] - ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Reset the tool - - Parameters - ---------- - graphs: list - The msc graphs. - weights: dict - The weights. - - Returns - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - """ - - def _prune_by_shape(tensor: MSCTensor, shape: List[int]): - return MSCTensor(tensor.name, tensor.dtype, tensor.layout.name, shape, tensor.alias) - - def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: Optional[int] = None): - shape = tensor.get_shape() - if channel_axis is None: - if self.has_w_node(tensor.name): - w_node = self.find_w_node(tensor.name) - _, channel_axis = self._get_io_axes(w_node) - else: - channel_axis = tensor.layout_of("C") - assert channel_axis >= 0, "Can not infer channel_axis for " + str(tensor) - shape[channel_axis] = dim - return _prune_by_shape(tensor, shape) - - pruned_graphs, pruned_weights = [], {} - pruned_cnt = 0 - for graph in graphs: - pruned_tensors = {} - for node in graph.get_nodes(): - for weight in node.get_weights().values(): - w_node, w_name = self.find_w_node(weight.name), weight.name - if w_name not in self._plan: - pruned_weights[w_name] = weights[w_name] - elif w_node.get_attr("pruned_shape", "") != "": - pruned_weights[w_name] = weights[w_name] - pruned_shape = [int(i) for i in w_node.get_attr("pruned_shape").split(",")] - assert pruned_shape == list(pruned_weights[w_name].shape), ( - f"pruned_shape {pruned_shape} mismatch with data shape {pruned_weights[w_name].shape}" - ) - else: - data = msc_utils.cast_array(weights[w_name]) - in_axis, out_axis = self._get_io_axes(self.find_w_node(w_name)) - w_config = self._plan[w_name] - if w_config["in_indices"]: - data = PruneMethod.prune_axis(data, in_axis, w_config["in_indices"]) - if w_config["out_indices"]: - data = PruneMethod.prune_axis(data, out_axis, w_config["out_indices"]) - pruned_tensors[w_name] = _prune_by_shape(weight, data.shape) - pruned_weights[w_name] = tvm.runtime.tensor(data) - w_node.set_attr( - "pruned_shape", - ",".join([str(i) for i in pruned_tensors[w_name].get_shape()]), - ) - pruned_cnt += 1 - if node.optype == "constant": - if node.weight_at("const").name not in pruned_tensors: - continue - ref_tensor = pruned_tensors[node.weight_at("const").name] - pruned_tensors[node.output_at(0).name] = ref_tensor.clone( - name=node.output_at(0).name - ) - elif node.optype in self._main_wtypes: - if node.weight_at("weight").name not in pruned_tensors: - continue - out = node.output_at(0) - if node.optype in ("msc.linear", "msc.linear_bias"): - channel_axis = out.ndim - 1 - else: - channel_axis = out.layout_of("C") - pruned_tensors[out.name] = _prune_by_channel( - out, - pruned_tensors[node.weight_at("weight").name].dim_at("O"), - channel_axis, - ) - elif node.optype in self._relation_wtypes: - for out in node.get_outputs(): - w_node = self.find_w_node(out.name) - if out.name not in self._plan or w_node.get_attr("pruned_shape", "") != "": - continue - pruned_tensors[out.name] = _prune_by_channel( - out, len(self._plan[out.name]["out_indices"]) - ) - w_node.set_attr( - "pruned_shape", - ",".join([str(i) for i in pruned_tensors[out.name].get_shape()]), - ) - elif node.get_inputs(): - ref_input = node.input_at(0) - if ref_input.name not in pruned_tensors or ref_input.layout_of("C") < 0: - continue - for out in node.get_outputs(): - if out.layout_of("C") < 0: - continue - pruned_tensors[out.name] = _prune_by_channel( - out, pruned_tensors[ref_input.name].dim_at("C") - ) - - def _is_pruned(tensor: MSCTensor, graph: MSCGraph) -> bool: - return tensor.get_shape() != graph.find_tensor(tensor.name).get_shape() - - pruned_tensors = {k: v for k, v in pruned_tensors.items() if _is_pruned(v, graph)} - if self.on_debug(3, in_forward=False): - self._logger.debug( - msc_utils.msg_block(self.tool_mark("Pruned Tensors"), pruned_tensors) - ) - - if pruned_tensors: - pruned_graph = _ffi_api.PruneWeights(graph, pruned_tensors) - pruned_graphs.append(pruned_graph) - else: - pruned_graphs.append(graph) - - def _flatten_size(weights): - weight_size = sum([w.numpy().size for w in weights.values()]) - return weight_size / 2**20 - - raw_size = _flatten_size(weights) - # log compress rate - if pruned_cnt > 0: - new_size = _flatten_size(pruned_weights) - msg = f"Prune {pruned_cnt} weights, compress to {new_size * 100 / raw_size:.2f}% ({raw_size:.4f} M->{new_size:.4f} M)" - else: - msg = f"No weights pruned, size {raw_size:.4f} M" - self._logger.info(self.tool_mark(msg)) - return pruned_graphs, pruned_weights - - def get_meta_data(self, name: str) -> np.ndarray: - """Get meta weight as np.ndarray - - Parameters - ---------- - name: str - The name of data. - - Returns - ------- - data: np.ndarray - The data in np.ndarray format. - """ - - if name in self._meta_weights: - return msc_utils.cast_array(self._meta_weights[name]) - raise Exception(f"Can not find data {name} from {len(self._meta_weights)} weights") - - def create_tasks(self, **kwargs) -> List[dict]: - """Create tasks for gym - - Parameters - ---------- - kwargs: dict - The kwargs for create tasks. - - Returns - ------- - tasks: list - The tasks. - """ - - tasks = [] - for w_node in self.get_w_nodes(): - if w_node.get_attr("weight_strategy") != "main": - continue - consumer = self.find_producer(w_node.name).name - executor = self._get_tensor_strategy(w_node.name, consumer).get_executor(MSCStage.PRUNE) - tasks.append( - {"methods": {"tensor": executor.method_def}, "tensor_names": [w_node.name]} - ) - return tasks - - def change_strategys(self, strategy_list: List[dict]): - """Change the strategys - - Parameters - ------- - strategy_list: list - The given strategys. - """ - - self._plan = {} - self.change_stage(MSCStage.PRUNE) - super().change_strategys(strategy_list) - - def finalize(self) -> dict: - """Get the plan""" - - self._plan = {n: c for n, c in self._plan.items() if c["in_indices"] or c["out_indices"]} - return super().finalize() - - @property - def pruned(self): - return len(self._plan) > 0 - - @classmethod - def tool_type(cls): - return ToolType.PRUNER - - @classmethod - def exportable(cls): - return False - - -@msc_utils.register_tool -class DefaultPruner(BasePruner): - @classmethod - def tool_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/quantize/__init__.py b/python/tvm/contrib/msc/core/tools/quantize/__init__.py deleted file mode 100644 index e81be53cb867..000000000000 --- a/python/tvm/contrib/msc/core/tools/quantize/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.quantize""" - -from .quantizer import * -from .method import * -from .configer import * diff --git a/python/tvm/contrib/msc/core/tools/quantize/configer.py b/python/tvm/contrib/msc/core/tools/quantize/configer.py deleted file mode 100644 index a4da7d198ba6..000000000000 --- a/python/tvm/contrib/msc/core/tools/quantize/configer.py +++ /dev/null @@ -1,126 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.quantize.configer""" - -from typing import Union - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.configer import ToolConfiger -from tvm.contrib.msc.core.tools.tool import ToolType - -from .quantizer import QuantizeStage - - -class QuantizeConfiger(ToolConfiger): - """Configer for quantize""" - - def config_gym(self, gym_config: Union[dict, str]) -> dict: - """Config the gym - - Parameters - ---------- - gym_config: dict - The raw config. - - Returns - ------- - gym_config: dict - The update config. - """ - - if isinstance(gym_config, dict): - return gym_config - if gym_config == "default": - return { - "env": { - "executors": { - "action_space": { - "method": "action_quantize_scale", - "start": 0.8, - "end": 1.2, - "step": 0.1, - } - }, - }, - "agent": {"agent_type": "search.grid", "executors": {}}, - } - else: - raise TypeError("Unexpected gym config " + str(gym_config)) - - @classmethod - def tool_type(cls): - return ToolType.QUANTIZER - - -@msc_utils.register_tool_configer -class DefaultQuantizeConfiger(QuantizeConfiger): - """Default configer for quantize""" - - def config_tool(self) -> dict: - """Get the default config of tool - - Returns - ------- - config: dict - The default config. - """ - - op_types = [ - "nn.conv1d", - "msc.conv1d_bias", - "nn.conv2d", - "msc.conv2d_bias", - "nn.conv3d", - "msc.conv3d_bias", - "msc.linear", - "msc.linear_bias", - "nn.avg_pool1d", - "nn.avg_pool2d", - "nn.avg_pool3d", - ] - - return { - "plan_file": "msc_quantizer.json", - "strategys": [ - { - "methods": { - "input": "gather_maxmin", - "output": "gather_maxmin", - "weights": "gather_max_per_channel", - }, - "op_types": op_types, - "stages": [QuantizeStage.GATHER], - }, - { - "methods": {"input": "calibrate_maxmin", "output": "calibrate_maxmin"}, - "op_types": op_types, - "stages": [QuantizeStage.CALIBRATE], - }, - { - "methods": { - "input": "quantize_normal", - "weights": "quantize_normal", - "output": "dequantize_normal", - }, - "op_types": op_types, - }, - ], - } - - @classmethod - def config_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/quantize/method.py b/python/tvm/contrib/msc/core/tools/quantize/method.py deleted file mode 100644 index 5cca88b824ff..000000000000 --- a/python/tvm/contrib/msc/core/tools/quantize/method.py +++ /dev/null @@ -1,477 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -# ruff: noqa: RUF005 -"""tvm.contrib.msc.core.tools.quantize.method""" - -from typing import Any, Union - -import numpy as np - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.tool import BaseTool, ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -@msc_utils.register_tool_method -class QuantizeMethod: - """Default quantize method""" - - @classmethod - def amplify_data( - cls, data: np.array, scale: float, min_val: float, max_val: float, rounding: str = "round" - ) -> np.ndarray: - """Amplify the data - - Parameters - ---------- - data: np.ndarray - The source data. - scale: float - The scale factor - min_val: float - The min. - max_val: float - The max. - rounding: str - The round method - - Returns - ------- - data: np.ndarray - The processed data. - """ - - if rounding == "null": - return np.clip(data * scale, min_val, max_val) - if rounding == "floor": - return np.clip(np.floor(data * scale), min_val, max_val) - if rounding == "ceil": - return np.clip(np.ceil(data * scale), min_val, max_val) - if rounding == "round": - return np.clip(np.round(data * scale), min_val, max_val) - if rounding == "trunc": - return np.clip(np.trunc(data * scale), min_val, max_val) - if rounding == "logic_round": - data = np.clip(data * scale, min_val, max_val) - negative_ceil = np.where( - np.logical_and(data < 0, (data - np.floor(data)) == 0.5), np.ceil(data), 0 - ) - data = np.where(np.logical_and(data < 0, (data - np.floor(data)) == 0.5), 0, data) - data = np.where((data - np.floor(data)) >= 0.5, np.ceil(data), data) - data = np.where((data - np.floor(data)) < 0.5, np.floor(data), data) - return data + negative_ceil - raise TypeError("Unexpected rounding " + str(rounding)) - - @classmethod - def get_scale_tensor( - cls, - data: Any, - scale: float, - axis: int = -1, - epsilon: float = 1.0 / (1 << 24), - expand_dims: bool = True, - ) -> Union[float, np.ndarray]: - """Get the scale tensor - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: array_like - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scale: float - The scale factor - axis: int - The axis. - epsilon: float - The epsilon for get scale. - expand_dims: bool - Whether to expand dims - - Returns - ------- - scale_tensor: np.ndarray - The processed tensor. - """ - - data = msc_utils.cast_array(data) - if isinstance(scale, list): - scale_tensor = np.array(scale).astype(data.dtype) - if expand_dims: - scale_shape = [s if idx == axis else 1 for idx, s in enumerate(data.shape)] - scale_tensor = scale_tensor.reshape(scale_shape) - if scale_tensor.min() <= epsilon: - scale_mask = scale_tensor <= epsilon - scale_tensor[scale_mask] = 0 - elif scale <= epsilon: - scale_tensor = 0 - else: - scale_tensor = scale - return scale_tensor - - @classmethod - def gather_maxmin( - cls, - quantizer: BaseTool, - data: np.ndarray, - name: str, - consumer: str, - plan: dict, - nbits: int = 8, - ) -> dict: - """Gather the data by max/min - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: np.ndarray - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - plan: dict - The pre-calibrated plan. - nbits: int - The number bits for quantize. - - Returns - ------- - plan: dict - The plan of the tensor. - """ - - abs_max_list = plan.get("abs_max_list", []) - abs_max_list.append(float(np.abs(data).max())) - max_list = plan.get("max_list", []) - max_list.append(float(data.max())) - min_list = plan.get("min_list", []) - min_list.append(float(data.min())) - return { - "abs_max_list": abs_max_list, - "max_list": max_list, - "min_list": min_list, - "calibrated": False, - } - - @classmethod - def gather_kl_divergence( - cls, - quantizer: BaseTool, - data: np.ndarray, - name: str, - consumer: str, - plan: dict, - nbits: int = 8, - bins: int = 4096, - ) -> dict: - """Gather the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: np.ndarray - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - plan: dict - The pre-calibrated plan. - nbits: int - The number bits for quantize. - bins: int - The number bins. - - Returns - ------- - plan: dict - The plan of the tensor. - """ - - if not plan or "abs_max" not in plan: - return cls.gather_maxmin(quantizer, name, data, plan, nbits) - hist, edge = np.histogram(data, bins=bins, range=[-plan["abs_max"], plan["abs_max"]]) - hist_list = plan.get("hist_list", []) - return {"hist_list": hist_list + [hist], "edge": edge, **plan} - - @classmethod - def gather_max_per_channel( - cls, - quantizer: BaseTool, - data: np.ndarray, - name: str, - consumer: str, - plan: dict, - nbits: int = 8, - channel: str = "O", - auto_unsign: bool = False, - ) -> dict: - """Gather the data by max_per_channel - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: np.ndarray - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - plan: dict - The pre-calibrated plan. - nbits: int - The number bits for quantize. - channel: str - The channel reference. - auto_unsign: bool - Whether to use auto unsign. - - Returns - ------- - plan: dict - The plan of the tensor. - """ - - weight = quantizer.find_tensor(name) - axis = weight.layout_of(channel) - channel_datas = np.split(data, data.shape[axis], axis) - channel_max = [float(np.abs(d).max()) for d in channel_datas] - sign = data.min() < 0 if auto_unsign else True - valid_range = 2 ** (nbits - int(sign)) - 1 - scale = [valid_range / m for m in channel_max] - return {"scale": scale, "sign": sign, "axis": axis, "calibrated": True} - - @classmethod - def calibrate_maxmin( - cls, - quantizer: BaseTool, - name: str, - consumer: str, - plan: dict, - nbits: int = 8, - auto_unsign: bool = False, - ) -> dict: - """Calibrate the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - name: str - The name of the tensor. - consumer: str - The name of the consumer. - plan: dict - The pre-calibrated plan. - nbits: int - The number bits for quantize. - auto_unsign: bool - Whether to use auto unsign. - - Returns - ------- - plan: dict - The plan of the tensor. - """ - - sign = plan["min"] < 0 if auto_unsign else True - valid_range = 2 ** (nbits - int(sign)) - 1 - abs_max = float(np.array(plan["abs_max_list"]).max()) - return {"scale": valid_range / abs_max, "sign": sign, "calibrated": True} - - @classmethod - def calibrate_kl_divergence( - cls, - quantizer: BaseTool, - name: str, - consumer: str, - plan: dict, - nbits: int = 8, - bins: int = 4096, - auto_unsign: bool = False, - ) -> dict: - """Calibrate the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - name: str - The name of the tensor. - consumer: str - The name of the consumer. - plan: dict - The pre-calibrated plan. - nbits: int - The number bits for quantize. - bins: int - The number bins. - auto_unsign: bool - Whether to use auto unsign. - - Returns - ------- - plan: dict - The plan of the tensor. - """ - - # pylint: disable=import-outside-toplevel - import ctypes - - from tvm.relay import quantize as _quantize - - if plan and "abs_max_list" in plan: - return { - "abs_max": float(np.array(plan["abs_max_list"]).max()), - "max": float(np.array(plan["max_list"]).max()), - "min": float(np.array(plan["min_list"]).min()), - "calibrated": False, - } - - def get_pointer(arr, ctypes_type): - ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes_type)) - return ctypes.cast(ptr, ctypes.c_void_p) - - sign = plan["min"] < 0 if auto_unsign else True - hist = np.array(plan["hist_list"]).sum(axis=0) - hist_ptr = get_pointer(hist.astype(np.int64), ctypes.c_int64) - edge_ptr = get_pointer(plan["edge"].astype(np.float32), ctypes.c_float) - valid_range = 2 ** (nbits - int(sign)) - 1 - scale = _quantize._quantize.FindScaleByKLMinimization(hist_ptr, edge_ptr, bins, valid_range) - return {"scale": valid_range / scale, "sign": sign, "calibrated": True} - - @classmethod - def quantize_normal( - cls, - quantizer: BaseTool, - data: np.ndarray, - name: str, - consumer: str, - scale: float, - nbits: int = 8, - axis: int = -1, - sign: bool = True, - rounding: str = "round", - epsilon: float = 1.0 / (1 << 24), - ) -> np.ndarray: - """Calibrate the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: np.ndarray - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scale: float - The scale factor - nbits: int - The number bits for quantize. - axis: int - The axis. - sign: bool - Whether to use sign. - rounding str - The rounding method. - epsilon: float - The epsilon for get scale. - - Returns - ------- - data: array like - The processed tensor. - """ - - valid_range = 2 ** (nbits - int(sign)) - 1 - min_val = -valid_range if sign else 0 - scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor") - if scale_tensor is None: - scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon) - quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor) - data = cls.amplify_data(data, scale_tensor, min_val, valid_range, rounding) - return data / scale - - @classmethod - def dequantize_normal( - cls, - quantizer: BaseTool, - data: np.ndarray, - name: str, - consumer: str, - scale: float = -1.0, - nbits: int = 8, - axis: int = -1, - sign: bool = True, - rounding: str = "round", - epsilon: float = 1.0 / (1 << 24), - ) -> np.ndarray: - """Calibrate the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: np.ndarray - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scale: float - The scale factor - nbits: int - The number bits for quantize. - axis: int - The axis. - sign: bool - Whether to use sign. - rounding str - The rounding method. - epsilon: float - The epsilon for get scale. - - Returns - ------- - data: array like - The processed tensor. - """ - - return data - - @classmethod - def framework(cls): - return MSCFramework.MSC - - @classmethod - def tool_type(cls): - return ToolType.QUANTIZER - - @classmethod - def method_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py deleted file mode 100644 index 8453eb3a6778..000000000000 --- a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py +++ /dev/null @@ -1,260 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.quantize.quantizer""" - -from typing import Any, Dict, List - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.tool import BaseTool, ToolStrategy, ToolType -from tvm.contrib.msc.core.utils.message import MSCStage - - -class QuantizeStage: - GATHER = "gather" - CALIBRATE = "calibrate" - - -class BaseQuantizer(BaseTool): - """Base quantizer for all""" - - def setup(self) -> dict: - """Setup the tool - - Returns - ------- - info: dict - The setup info. - """ - - if self._plan: - self._calibrated = True - self.change_stage(MSCStage.QUANTIZE) - else: - self._calibrated = False - self._calibrate_plan = {} - self.change_stage(QuantizeStage.GATHER) - return super().setup() - - def calibrate(self) -> dict: - """Calibrate the datas - - Returns - ------- - plan: dict - The calibrated plan. - """ - - new_plan = {} - self.change_stage(QuantizeStage.CALIBRATE) - for tensor_id, plan in self._calibrate_plan.items(): - if plan.get("calibrated", False): - new_plan[tensor_id] = plan - continue - name, consumer = self.from_tensor_id(tensor_id) - strategy = self._get_tensor_strategy(name, consumer) - new_plan[tensor_id] = strategy(self, name, consumer, plan) - if any(not plan.get("calibrated", False) for plan in new_plan.values()): - self._calibrate_plan = new_plan - self.change_stage(QuantizeStage.GATHER) - else: - self._calibrated = True - for name, plan in new_plan.items(): - self._plan[name] = {k: v for k, v in plan.items() if k not in ("calibrated")} - self.change_stage(MSCStage.QUANTIZE) - calib_type = "calibrate" if self._calibrated else "gather" - msg = f"{calib_type} {len(new_plan)} plan after {self._forward_cnt} batch" - self._logger.info(self.tool_mark(msg)) - self._forward_cnt = 0 - return new_plan - - def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]: - """Parse the strategy to get valid strategy - - Parameters - ------- - strategy_list: list - The given strategys - - Returns - ------- - strategys: dict - The parsed strategy. - """ - - def _update_stages(strategy): - if "stages" not in strategy: - strategy["stages"] = [MSCStage.QUANTIZE] - return strategy - - return super()._parse_strategys([_update_stages(s) for s in strategy_list]) - - def _check_tensor(self, name: str, consumer: str) -> bool: - """Check if the tensor should be processed - - Parameters - ------- - name: str - The name of the tensor. - consumer: str - The name of the consumer. - - Returns - ------- - vaild: bool - Whether to process the tensor. - """ - - if self._calibrated: - return self.to_tensor_id(name, consumer) in self._plan - strategys = self._get_tensor_strategys(name, consumer) - if not strategys: - return False - if any(s.get_config().get("nbits", 8) == -1 for s in strategys): - return False - return True - - def _process_tensor( - self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[ToolStrategy] - ) -> Any: - """Process tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - if not self._calibrated: - return self._gather_tensor(tensor, name, consumer, strategys) - return self._quantize_tensor(tensor, name, consumer, strategys) - - def _gather_tensor( - self, tensor: Any, name: str, consumer: str, strategys: List[ToolStrategy] - ) -> Any: - """Gather tensor datas - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - assert len(strategys) == 1, "gather should only has 1 strategy, get " + str(strategys) - tensor_id = self.to_tensor_id(name, consumer) - plan = self._calibrate_plan.get(tensor_id, {}) - if plan.get("calibrated", False): - return tensor - self._calibrate_plan[tensor_id] = strategys[0](self, tensor, name, consumer, plan) - return tensor - - def _quantize_tensor( - self, tensor: Any, name: str, consumer: str, strategys: List[ToolStrategy] - ) -> Any: - """Quantize tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - tensor_id = self.to_tensor_id(name, consumer) - for strategy in strategys: - tensor = strategy(self, tensor, name, consumer, **self._plan[tensor_id]) - return tensor - - def create_tasks(self, **kwargs) -> List[dict]: - """Create tasks for gym - - Parameters - ---------- - kwargs: dict - The kwargs for create tasks. - - Returns - ------- - tasks: list - The tasks. - """ - - tasks, recorded = [], set() - for tensor_id in self._plan: - name, consumer = self.from_tensor_id(tensor_id) - if self.is_weight(name) and not kwargs.get("quantize_weights", False): - continue - if name not in recorded: - executor = self._get_tensor_strategy(name, consumer).get_executor(MSCStage.QUANTIZE) - task = {"methods": {"tensor": executor.method_def}} - if self._cache_processed: - task["tensor_ids"] = [ - self.to_tensor_id(name, c.name) for c in self.find_consumers(name) - ] - recorded.add(name) - else: - task["tensor_ids"] = [tensor_id] - tasks.append(task) - return tasks - - @property - def calibrated(self): - return self._calibrated - - @classmethod - def tool_type(cls): - return ToolType.QUANTIZER - - -@msc_utils.register_tool -class DefaultQuantizer(BaseQuantizer): - @classmethod - def tool_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py deleted file mode 100644 index 0901a1a0dc36..000000000000 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ /dev/null @@ -1,1624 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -# ruff: noqa: RUF012 -"""tvm.contrib.msc.core.tools.base_tool""" - -import copy -import logging -import os -from itertools import product -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - -import numpy as np - -import tvm -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph, MSCJoint, MSCTensor, WeightGraph, WeightJoint -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class ToolType: - """Enum all msc tool types""" - - BASE = "base" - WEIGHT = "weight" - PRUNER = "pruner" - QUANTIZER = "quantizer" - DISTILLER = "distiller" - TRACKER = "tracker" - ALL = [PRUNER, QUANTIZER, DISTILLER, TRACKER] - - @classmethod - def all_types(cls) -> List[str]: - return cls.ALL - - -class ToolScope: - """Enum all msc tool scope""" - - TEACHER = "teacher" - STUDENT = "student" - - -class ToolExecutor: - """Executor for process the tensor - - Parameters - ---------- - name: str - The name. - method: str - The method for execute. - config: dict - The config for execute - """ - - def __init__(self, name: str, method: callable, config: Optional[dict] = None): - self._name = name - self._method = method - self._config = config or {} - - def __str__(self): - return f"{self._name}({self._config})" - - def execute(self, *args, **kwargs) -> Any: - """execute the method - - Parameters - ---------- - args: list - The arguments for run method. - kwargs: dict - The key word arguments for run method. - - Returns - ------- - plan or tensor: - The plan generated by method or processed tensor. - """ - - kwargs.update(self._config) - return self._method(*args, **kwargs) - - def copy( - self, - name: Optional[str] = None, - method: Optional[callable] = None, - config: Optional[dict] = None, - ): - """Copy a executor - - Parameters - ---------- - name: str - The name for new executor. - method: str - The method for new execute. - config: dict - The config for new execute - - Returns - ------- - new_executor: ToolExecutor - The copied executor - """ - - new_config = config or {} - new_config.update({k: v for k, v in self._config.items() if k not in new_config}) - return ToolExecutor(name or self._name, method or self._method, new_config) - - @property - def method_def(self): - return {"method_name": self._name, **self._config} - - @property - def name(self): - return self._name - - @property - def config(self): - return self._config - - -class ToolStrategy: - """Strategy for process tensor - - Parameters - ---------- - name: str - The name. - tensor_type: str - The tensor type. - stage: str - The init stage - meta: dict: - The meta strategy config. - """ - - def __init__(self, name: str, tensor_type: str, stage: str = "default"): - self._name = name - self._tensor_type = tensor_type - self._stage = stage - self._executors = {} - - def __str__(self): - return f"{self._name}({self._tensor_type} @ {self._stage}) " + "; ".join( - [f"{k}:{v}" for k, v in self._executors.items()] - ) - - def inspect(self) -> dict: - """Get inspect of strategy - - Returns - ------- - inspect: dict - The inspect of the strategy. - """ - - return {s: str(e) for s, e in self._executors.items()} - - def __call__(self, *args, **kwargs) -> Any: - return self.apply(*args, **kwargs) - - def apply(self, *args, **kwargs) -> Any: - """Apply the strategy - - Parameters - ---------- - args: list - The arguments for run method. - kwargs: dict - The key word arguments for run method. - - Returns - ------- - plan or tensor: - The plan generated by method or processed tensor. - """ - - return self.get_executor().execute(*args, **kwargs) - - def change_stage(self, stage: str): - """Change the stage of strategy""" - - self._stage = stage - - def add_executor(self, stage: str, executor: ToolExecutor): - """Add a executor to strategy - - Parameters - ---------- - stage: str - The mark of the executor. - executor: ToolExecutor - The executor to process tensor. - """ - - self._executors[stage] = executor - if not self._stage: - self._stage = stage - - def get_executor(self, stage: Optional[str] = None) -> Tuple[callable, dict]: - """Get executor of current stage - - Parameters - ---------- - stage: str - The mark of the executor. - - Returns - ------- - executor: tuple - The method and config to execute strategy - """ - - stage = stage or self._stage - if stage in self._executors: - return self._executors[stage] - return self._executors["default"] - - def get_config(self) -> dict: - """Get the config of current executor""" - - return self.get_executor().config - - def support_stage(self, stage: str) -> bool: - """Check if the strategy support a stage - - Parameters - ---------- - stage: str - The mark of the executor - - Returns - ------- - support: bool - Whether the strategy support the strategy - """ - - return stage in self._executors or "default" in self._executors - - def copy( - self, - name: Optional[str] = None, - tensor_type: Optional[str] = None, - stage: Optional[str] = None, - configs: Optional[Dict[str, dict]] = None, - ): - """Copy a strategy - - Parameters - ---------- - name: str - The name for new strategy - tensor_type: - The tensor type for new strategy - stage: str - The init stage for new strategy - configs: dict - The method config of new executors. - - Returns - ------- - new_strategy: ToolStrategy - The copied strategy - """ - - configs = configs or {} - strategy = ToolStrategy( - name or self._name, tensor_type or self._tensor_type, stage or self._stage - ) - for st_name, executor in self._executors.items(): - new_executor = executor.copy(config=configs.get(st_name, {})) - strategy.add_executor(st_name, new_executor) - return strategy - - -class BaseTool: - """Basic tool of MSC - - Parameters - ---------- - tag: str - The tag of tool. - stage: str - The stage of tool. - plan_file: str - The plan file path. - strategys: list[dict] - The strategys of the tool. - training: bool - Whether the tool is training. - cache_processed: bool - Whether to cache processed tensor. - options: dict - The extra options for the tool - debug_level: int - The debug level. - verbose_step: int - The verbose interval step. - logger: logging.Logger - The logger - """ - - def __init__( - self, - tag: str, - stage: str, - plan_file: str, - strategys: List[dict], - training: bool = False, - cache_processed: bool = True, - options: Optional[dict] = None, - debug_level: int = 0, - verbose_step: int = 50, - logger: Optional[logging.Logger] = None, - ): - self._tag = tag - self._stage = stage - self._plan_file = plan_file - if os.path.isfile(plan_file): - self._plan = msc_utils.load_dict(plan_file) - else: - self._plan = {} - self._meta_strategys, self._strategys = msc_utils.copy_dict(strategys), {} - self._training = training - self._cache_processed = cache_processed - self._options = options or {} - self._debug_level = debug_level - self._verbose_step = verbose_step - self._logger = logger or msc_utils.get_global_logger() - title = self.tool_mark("APPLY_PLAN" if self._plan else "MAKE_PLAN") - self._logger.info(msc_utils.msg_block(title, self.setup())) - - def __str__(self): - msg = ( - f"forward[{self._forward_cnt}] {len(self._graphs)} graphs, {len(self._weights)} weights" - ) - return self.tool_mark(msg) - - def setup(self) -> dict: - """Setup the tool - - Returns - ------- - info: dict - The setup info. - """ - - self._tensor_cache = {} - self._enabled = True - self._graphs, self._weights = [], {} - self._graph_id, self._forward_cnt = 0, 0 - self._processed_tensor = {} - plan_info = self._plan if self._plan and self._debug_level >= 2 else self._plan_file - return { - "style": self.tool_style(), - "cache_processed": self._cache_processed, - "options": self._options, - f"debug_step({self._debug_level})": self._verbose_step, - f"plan({len(self._plan)})": plan_info, - } - - def reset( - self, - graphs: List[MSCGraph], - weights: Dict[str, tvm.runtime.Tensor], - cache_dir: msc_utils.MSCDirectory = None, - ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Reset the tool with graphs and weights - - Parameters - ---------- - graphs: list - The msc graphs. - weights: dict - The weights. - cache_dir: MSCDirectory - cache path for save/load info. - - Returns - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - """ - - self._forward_cnt = 0 - self._tensor_cache = {} - if cache_dir and os.path.isfile(cache_dir.relpath("cache_info.json")): - cache_info = msc_utils.load_dict(cache_dir.relpath("cache_info.json")) - else: - cache_info = {} - if self.tool_type() in cache_info: - self.load_cache(cache_dir, cache_info[self.tool_type()]) - self._graphs, self._weights = self._reset(graphs, weights) - self._strategys = self._parse_strategys(self._meta_strategys) - if self._strategys: - title = self.tool_mark(f"STRATEGYS({len(self._strategys)})") - strategys_info = {k: v.inspect() for k, v in self._strategys.items()} - self._logger.info(msc_utils.msg_block(title, strategys_info, width=0)) - return self._graphs, self._weights - - def _reset( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] - ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Reset the tool - - Parameters - ---------- - graphs: list - The msc graphs. - weights: dict - The weights. - - Returns - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - """ - - return graphs, weights - - def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]: - """Parse the strategy to get valid strategy - - Parameters - ------- - strategy_list: list - The given strategys. - - Returns - ------- - strategys: dict - The parsed strategy. - """ - - assert isinstance(strategy_list, list) and all( - isinstance(s, dict) for s in strategy_list - ), "ToolStrategy should be given as list of dict" - assert self._graphs, "graphs are needed to parse strategys" - all_tensor_names = set(t.name for t in self.get_tensors()) - all_tensor_ids = set(self.get_tensor_ids()) - all_op_types = set(n.optype for n in self.get_nodes()) - all_op_names = set(n.name for n in self.get_nodes()) - strategys = {} - - def _get_method(method_name): - if "." in method_name: - method_cls_name, method_name = method_name.split(".") - else: - method_cls_name = "default" - method_cls = msc_utils.get_registered_tool_method( - self.framework(), self.tool_type(), method_cls_name - ) - if hasattr(method_cls, method_name): - return getattr(method_cls, method_name) - default_cls = msc_utils.get_registered_tool_method( - MSCFramework.MSC, self.tool_type(), method_cls_name - ) - if hasattr(default_cls, method_name): - return getattr(default_cls, method_name) - method = msc_utils.get_registered_func(method_name) - assert method, "Can not find method with " + str(method_name) - return method - - for strategy in strategy_list: - meta_strategy = msc_utils.copy_dict(strategy) - for t_type, method_def in meta_strategy["methods"].items(): - if isinstance(method_def, str): - method_name, method_kwargs = method_def, {} - elif isinstance(method_def, dict): - assert "method_name" in method_def, "Can not find method_name" - method_name = method_def["method_name"] - method_kwargs = {k: v for k, v in method_def.items() if k != "method_name"} - else: - raise TypeError( - "Only support string and dict as method define, get " + str(method_def) - ) - method = _get_method(method_name) - if "marks" in strategy: - assert t_type == "mark", "mark strategy only support mark method, get " + str( - meta_strategy - ) - marks = strategy["marks"] - elif "tensor_names" in strategy: - assert t_type == "tensor", ( - "tensor strategy only support tensor method, get " + str(meta_strategy) - ) - marks = [t for t in strategy["tensor_names"] if t in all_tensor_names] - elif "tensor_ids" in strategy: - assert t_type == "tensor", ( - "tensor strategy only support tensor method, get " + str(meta_strategy) - ) - marks = [t for t in strategy["tensor_ids"] if t in all_tensor_ids] - elif "op_types" in strategy: - op_types = [t for t in strategy["op_types"] if t in all_op_types] - marks = [f"{t}.{t_type}" for t in op_types] - elif "op_names" in strategy: - op_names = [t for t in strategy["op_names"] if t in all_op_names] - marks = [f"{t}.{t_type}" for t in op_names] - else: - marks = ["default." + str(t_type)] - for mark, stage in product(marks, strategy.get("stages", ["default"])): - if mark not in strategys: - strategys[mark] = ToolStrategy(mark, t_type, self._stage) - strategys[mark].add_executor( - stage, ToolExecutor(method_name, method, copy.deepcopy(method_kwargs)) - ) - return strategys - - def change_strategys(self, strategy_list: List[dict]): - """Change the strategys - - Parameters - ------- - strategy_list: list - The given strategys. - """ - - self._meta_strategys = strategy_list - - def change_stage(self, stage: str): - """Change the stage of tool and strategy""" - - self._stage = stage - for strategy in self._strategys.values(): - strategy.change_stage(stage) - - def change_logger(self, logger: logging.Logger): - """Change the logger of tool""" - - self._logger = logger - - def destory(self): - """Destory tool""" - - self._graphs, self._weights = [], {} - - def export_config(self, config: dict, folder: msc_utils.MSCDirectory) -> dict: - """Export the config for tool - - Parameters - ------- - config: dict - The source config. - folder: MSCDirectory - The export folder. - - Returns - ------- - config: dict - The exported config. - """ - - plan_file = msc_utils.to_abs_path(config["plan_file"], msc_utils.get_config_dir()) - if os.path.isfile(plan_file): - return {"plan_file": folder.create_dir("tools").copy(plan_file)} - return {} - - def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict): - """Save runner to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - cache_info: dict - The cache_info - """ - - return None - - def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict: - """Save runner to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - - Returns - ------- - cache_info: dict - The cache_info. - """ - - return {} - - def execute_before_build(self, *args, **kwargs): - """Execute before model build - - Parameters - ---------- - args: list - The arguments for model build. - kwargs: dict - The key word arguments for model build. - """ - - if self._enabled: - self._graph_id = self._infer_graph_id(kwargs) - self._processed_tensor = {} - if self.on_debug(3, in_forward=False): - self._logger.debug(self.msg_mark("Start Build", in_forward=False)) - self._execute_before_build(*args, **kwargs) - - def _execute_before_build(self, *args, **kwargs): - """Execute before model build - - Parameters - ---------- - args: list - The arguments for model build. - kwargs: dict - The key word arguments for model build. - """ - - return None - - def execute_after_build(self, output: Any) -> Any: - """Execute after model build - - Parameters - ---------- - output: Any - The output reference of the model. - - Returns - ------- - output: Any - The modified output reference. - """ - - if self._enabled: - output = self._execute_after_build(output) - if self.on_debug(3, in_forward=False): - self._logger.debug(self.msg_mark("End Build", in_forward=False)) - return output - - def _execute_after_build(self, output: Any) -> Any: - """Execute after model build - - Parameters - ---------- - output: Any - The output reference of the model. - - Returns - ------- - output: Any - The modified output reference. - """ - - return output - - def execute_before_forward(self, *args, **kwargs): - """Execute before model forward - - Parameters - ---------- - args: list - The arguments for model forward. - kwargs: dict - The key word arguments for model forward. - """ - - if self._enabled: - self._graph_id = self._infer_graph_id(kwargs) - self._processed_tensor = {} - if self.on_debug(3): - self._logger.debug(self.msg_mark("Start Forward")) - self._execute_before_forward(*args, **kwargs) - - def _execute_before_forward(self, *args, **kwargs): - """Execute before model forward - - Parameters - ---------- - args: list - The arguments for model forward. - kwargs: dict - The key word arguments for model forward. - """ - - return None - - def execute_after_forward(self, output: Any) -> Any: - """Execute after model forward - - Parameters - ---------- - output: Any - The output reference of the model. - - Returns - ------- - output: Any - The modified output reference. - """ - - if self._enabled: - output = self._execute_after_forward(output) - if self.on_debug(3): - msg = f"End Forward, process {len(self._processed_tensor)} tensors" - self._logger.debug(self.msg_mark(msg)) - self._forward_cnt += 1 - return output - - def _execute_after_forward(self, output: Any) -> Any: - """Execute after model forward - - Parameters - ---------- - output: Any - The output reference of the model. - - Returns - ------- - output: Any - The modified output reference. - """ - - return output - - def process_tensor(self, tensor: Any, name: str, consumer: str, scope: str) -> Any: - """Process tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null - - Returns - ------- - tensor: Any - The processed tensor. - """ - - if not self._enabled: - return tensor - if not self._support_scope(scope): - return tensor - strategys = self._get_tensor_strategys(name, consumer) - t_mark = ".".join([s.get_executor().name for s in strategys]) - if scope: - t_mark += "." + scope - cached_tensor = self._get_processed(name, consumer, t_mark) - if cached_tensor is not None: - self.debug_tensors(name, consumer, t_mark, {"cached": cached_tensor}) - return cached_tensor - process = self._get_tensor_cache(name, consumer, "process") - if process is None: - process = self._check_tensor(name, consumer) - self._save_tensor_cache(name, consumer, "process", process) - if not process: - return tensor - if isinstance(tensor, dict): - new_tensor = self._process_tensor( - msc_utils.copy_dict(tensor), name, consumer, scope, strategys - ) - else: - new_tensor = self._process_tensor(tensor, name, consumer, scope, strategys) - self._save_processed(name, consumer, new_tensor, t_mark) - if msc_utils.is_array(tensor) and id(new_tensor) != id(tensor): - tensors = {"org": tensor, "new": new_tensor, "dif": tensor - new_tensor} - self.debug_tensors(name, consumer, t_mark, tensors) - elif isinstance(tensor, dict) and len(tensor.get("processed", [])) != len( - new_tensor.get("processed", []) - ): - tensors = {"org": tensor, "new": new_tensor} - self.debug_tensors(name, consumer, t_mark, tensors) - return new_tensor - - def _support_scope(self, scope: str) -> bool: - """Check if the scope si supported - - Parameters - ------- - scope: str - The scope mark, should be null or ToolScope - - Returns - ------- - vaild: bool - Whether to process the tensor. - """ - - if not scope: - return True - return scope != ToolScope.TEACHER - - def _get_processed(self, name: str, consumer: str, strategy_mark: str) -> Any: - """Get cached processed tensor - - Parameters - ------- - name: str - The name of the tensor. - consumer: str - The name of the consumer. - strategy_mark: str - The sstrategy mark. - - Returns - ------- - processed_tensor - The cached processed tensor. - """ - - if self._cache_processed: - return self._processed_tensor.get(name + "." + strategy_mark) - return None - - def _save_processed(self, name: str, consumer: str, tensor: Any, strategy_mark: str): - """Save cached processed tensor - - Parameters - ------- - name: str - The name of the tensor. - consumer: str - The name of the consumer. - tensor: Any - The processed tensor - strategy_mark: str - The sstrategy mark. - """ - - if self._cache_processed: - self._processed_tensor[name + "." + strategy_mark] = tensor - else: - self._processed_tensor[self.to_tensor_id(name, consumer)] = None - - def _check_tensor(self, name: str, consumer: str) -> bool: - """Check if the tensor should be processed - - Parameters - ------- - name: str - The name of the tensor. - consumer: str - The name of the consumer. - - Returns - ------- - vaild: bool - Whether to process the tensor. - """ - - strategys = self._get_tensor_strategys(name, consumer) - return len(strategys) > 0 - - def _process_tensor( - self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[ToolStrategy] - ) -> Any: - """Process tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - return tensor - - def create_tasks(self, **kwargs) -> List[dict]: - """Create tasks for gym - - Parameters - ---------- - kwargs: dict - The kwargs for create tasks. - - Returns - ------- - tasks: list - The tasks. - """ - - return [] - - def config_generate(self, generate_config: Dict[str, Any]) -> Dict[str, Any]: - """Update the generate configs - - Parameters - ---------- - generate_config: dict - The generate_config. - - Returns - ------- - generate_config: dict - The updated generate_config. - """ - - return generate_config - - def visualize(self, visual_dir: msc_utils.MSCDirectory): - """Visualize MSCGraphs - - Parameters - ------- - visual_dir: MSCDirectory - Visualize path for saving graph - """ - - return None - - def finalize(self) -> dict: - """Get the plan""" - - return self._plan - - def enable(self): - """Enable the tool""" - - self._enabled = True - - def disable(self): - """Disable the tool""" - - self._enabled = False - - def train(self): - """Set the tool to train mode""" - - self._training = True - - def eval(self): - """Set the tool to eval mode""" - - self._training = False - - def to_tensor_id(self, name: str, consumer: str) -> str: - """Concat name to unique id - - Parameters - ---------- - name: str - The name of tensor. - consumer: str - The name of consumer. - - Returns - ------- - tensor_id: str - The unique name of edge. - """ - - return f"{name}-c-{consumer}" - - def from_tensor_id(self, tensor_id: str) -> Tuple[str]: - """Split name from unique id - - Parameters - ---------- - tensor_id: str - The unique name of edge. - - Returns - ------- - name: str - The name of tensor. - consumer: str - The name of consumer. - """ - - return tensor_id.split("-c-") - - def is_weight(self, name: str) -> bool: - """Check if the tensor is weight - - Parameters - ---------- - name: str - The name of tensor. - - Returns - ------- - is_weight: bool - Whether the name is weight. - """ - - return name in self._weights - - def on_debug(self, debug_level: int = 1, in_forward: bool = True) -> bool: - """Check if should log - - Parameters - ------- - debug_level: int - The given debug_level. - in_forward: bool - Whether to check forward_cnt. - - Returns - ------- - on_debug: bool - Whether to log debug info. - """ - - if in_forward and self._forward_cnt % self._verbose_step != 0: - return False - return self._debug_level >= debug_level - - def tool_mark(self, msg: Any) -> str: - """Mark the message with tool info - - Parameters - ------- - msg: str - The message - - Returns - ------- - msg: str - The message with mark. - """ - - return f"{self.tool_type().upper()}[{self._tag}]({self.framework()} @ {self._stage}) {msg}" - - def msg_mark(self, msg: Any, in_forward: bool = True) -> str: - """Mark the message with debug info - - Parameters - ------- - msg: - The message - in_forward: bool - Whether to add forward mark. - - Returns - ------- - msg: str - The message with mark. - """ - - mark = f"{self.tool_type().upper()}({self._tag} @ {self._stage}) G[{self._graph_id}]" - if in_forward: - mark += f".F[{self._forward_cnt}]" - return mark + " " + str(msg) - - def debug_tensors( - self, name: str, consumer: str, t_mark: str, tensors: Dict[str, Any], debug_level: int = 3 - ) -> str: - """Get the debug tensor info - - Parameters - ------- - name: str - The name of tensor. - consumer: str - The name of consumer. - t_mark: str - The mark of tensor. - tensors: dict - The tensors. - debug_level: int - The given debug_level. - """ - - if self.on_debug(debug_level): - - def _t_info(tensor): - if msc_utils.is_array(tensor): - return msc_utils.inspect_array(tensor) - if isinstance(tensor, dict) and "processed" in tensor: - return "{}({} processed)".format( - self.find_tensor(name), len(tensor["processed"]) - ) - return str(tensor) - - msg = f"{name}-{consumer}({t_mark})" - tensor_des = "\n ".join([f"{k:6s}:{_t_info(v)}" for k, v in tensors.items()]) - self._logger.debug("%s\n %s", self.msg_mark(msg), tensor_des) - - def _infer_graph_id(self, kwargs: dict) -> int: - """Infer graph id from kwargs - - Parameters - ---------- - kwargs: dict - The kwargs for execute. - """ - - if "graph_id" in kwargs: - return kwargs.pop("graph_id") - if "graph_name" in kwargs: - name = kwargs.pop("graph_name") - for idx, g in enumerate(self._graphs): - if g.name == name: - return idx - return 0 - - def get_nodes(self) -> Iterable[MSCJoint]: - """Get all the nodes in the graphs. - - Returns - ------- - nodes: generator - The generator of nodes. - """ - - for g in self._graphs: - yield from g.get_nodes() - - def find_node(self, name: str) -> MSCJoint: - """Find node by name. - - Parameters - ---------- - name: string - The name of the node. - - Returns - ------- - node: MSCJoint - The found node. - """ - - for g in self._graphs: - if g.has_node(name): - return g.find_node(name) - raise Exception(f"Can not find node {name} from {len(self._graphs)} graphs") - - def get_tensors(self) -> Iterable[MSCTensor]: - """Get all the tensors in the graphs. - - Returns - ------- - tensors: generator - The generator of tensors. - """ - - for graph in self._graphs: - yield from graph.get_tensors() - - def get_tensor_ids(self) -> Iterable[MSCTensor]: - """Get all the tensor ids in the graphs. - - Returns - ------- - tensors: generator - The generator of tensor ids. - """ - - for graph in self._graphs: - for node in graph.get_nodes(): - for tensor in node.get_inputs(): - yield self.to_tensor_id(tensor.name, node.name) - for weight in node.get_weights().values(): - yield self.to_tensor_id(weight.name, node.name) - - def find_tensor(self, t_ref: Union[str, MSCTensor]) -> MSCTensor: - """Find tensor by tensor ref. - - Parameters - ---------- - t_ref: string| MSCTensor - The name of the tensor or tensor. - - Returns - ------- - node: MSCTensor - The found tensor. - """ - - t_name = t_ref.name if isinstance(t_ref, MSCTensor) else t_ref - for g in self._graphs: - if g.has_tensor(t_name): - return g.find_tensor(t_name) - raise Exception(f"Can not find tensor {t_name} from {len(self._graphs)} graphs") - - def find_producer(self, t_ref: Union[str, MSCTensor]) -> MSCJoint: - """Find producer by tensor ref. - - Parameters - ---------- - t_ref: string| MSCTensor - The name of the tensor or tensor. - - Returns - ------- - node: MSCJoint - The found prducer. - """ - - t_name = t_ref.name if isinstance(t_ref, MSCTensor) else t_ref - for g in self._graphs: - if g.has_tensor(t_name): - return g.find_producer(t_name) - raise Exception(f"Can not find producer of {t_name} from {len(self._graphs)} graphs") - - def find_consumers(self, t_ref: Union[str, MSCTensor]) -> List[MSCJoint]: - """Find consumers by tensor ref. - - Parameters - ---------- - t_ref: string| MSCTensor - The name of the tensor or tensor. - - Returns - ------- - node: list - The found consumers. - """ - - t_name = t_ref.name if isinstance(t_ref, MSCTensor) else t_ref - for g in self._graphs: - if g.has_tensor(t_name): - return g.find_consumers(t_name) - raise Exception(f"Can not find consumers of {t_name} from {len(self._graphs)} graphs") - - def get_data(self, name: str) -> np.ndarray: - """Get the data by name - - Parameters - ------- - name: str - The tensor name - - Returns - ------- - data: np.ndarray - The data. - """ - - if name in self._weights: - return msc_utils.cast_array(self._weights[name]) - raise Exception(f"Can not find data {name} from {len(self._weights)} weights") - - def _save_tensor_cache(self, name: str, consumer: str, key: str, value: Any) -> Any: - """Save the data to tensor cache - - Parameters - ------- - name: str - The tensor name. - consumer: str - The name of the consumer. - key: str - The data key. - value: any - The value to cache. - - Returns - ------- - value: any - The saved value. - """ - - tensor_id = self.to_tensor_id(name, consumer) - if tensor_id not in self._tensor_cache: - self._tensor_cache[tensor_id] = {} - self._tensor_cache[tensor_id][key] = value - return value - - def _get_tensor_cache(self, name: str, consumer: str, key: str) -> Any: - """Get the cached tensor data - - Parameters - ------- - name: str - The tensor name. - consumer: str - The name of the consumer. - key: str - The data key. - - Returns - ------- - value: any - The cached value. - """ - - tensor_id = self.to_tensor_id(name, consumer) - if tensor_id not in self._tensor_cache: - return None - return self._tensor_cache[tensor_id].get(key) - - def _get_tensor_strategys(self, name: str, consumer: str) -> List[ToolStrategy]: - """Get the strategys by name and consumer - - Parameters - ------- - name: str - The tensor name. - consumer: str - The name of the consumer. - - Returns - ------- - strategys: list - The strategys for the tensor. - """ - - tensor_id = self.to_tensor_id(name, consumer) - mark = f"strategy.{self._stage}" - if mark not in self._tensor_cache.get(tensor_id, {}): - strategys = [] - - def _add_strategy(ref): - if ref in self._strategys and self._strategys[ref].support_stage(self._stage): - strategys.append(self._strategys[ref]) - return True - return False - - tensor_strategy = self._strategys.get(tensor_id) or self._strategys.get(name) - if tensor_strategy and tensor_strategy.support_stage(self._stage): - strategys.append(tensor_strategy) - elif self.is_weight(name): - consumer = self.find_node(consumer) - for w_type in [consumer.weight_type(name), "weights"]: - for ref in [consumer.name, consumer.optype, "default"]: - if not strategys and _add_strategy(ref + "." + w_type): - break - elif consumer == "exit": - producer = self.find_producer(name) - for ref in [producer.name, producer.optype, "exit", "default"]: - if _add_strategy(ref + ".output"): - break - else: - producer = self.find_producer(name) - for ref in [producer.name, producer.optype, "default"]: - if _add_strategy(ref + ".output"): - break - consumer = self.find_node(consumer) - for ref in [consumer.name, consumer.optype, "default"]: - if _add_strategy(ref + ".input"): - break - self._save_tensor_cache(name, consumer, mark, strategys) - return self._get_tensor_cache(name, consumer, mark) - - def _get_tensor_strategy(self, name: str, consumer: str) -> ToolStrategy: - """Get the unique strategy by name and consumer - - Parameters - ------- - name: str - The tensor name. - consumer: str - The name of the consumer. - - Returns - ------- - strategy: ToolStrategy - The unique strategy for the tensor. - """ - - strategys = self._get_tensor_strategys(name, consumer) - if not strategys: - return None - assert len(strategys) == 1, f"{self._stage} should only has 1 strategy, get {strategys}" - return strategys[0] - - def get_graph(self): - return self._graphs[self._graph_id] - - @property - def plan(self): - return self._plan - - @classmethod - def tool_type(cls): - return ToolType.BASE - - @classmethod - def framework(cls): - return MSCFramework.MSC - - @classmethod - def tool_style(cls): - return "base" - - @classmethod - def apply_once(cls): - return False - - @classmethod - def exportable(cls): - return True - - -class WeightTool(BaseTool): - """Basic tool with weight graphs""" - - def setup(self) -> dict: - """Setup the tool - - Returns - ------- - info: dict - The setup info. - """ - - self._weight_graphs = [] - return super().setup() - - def _reset( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] - ) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Reset the tool - - Parameters - ---------- - graphs: list - The msc graphs. - weights: dict - The weights. - - Returns - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - """ - - graphs, weights = super()._reset(graphs, weights) - self._main_wtypes, self._relation_wtypes = self._get_wtypes() - assert self._main_wtypes, "main_wtypes should be given to build weight graphs" - if self._weight_graphs: - assert len(graphs) == len(self._weight_graphs), ( - f"Graphs {len(graphs)} mismatch with weight graphs {len(self._weight_graphs)}" - ) - else: - self._weight_graphs = [ - _ffi_api.WeightGraph(graph, self._main_wtypes, self._relation_wtypes) - for graph in graphs - ] - msg = f"build {len(self._weight_graphs)} weight graphs" - self._logger.debug(self.tool_mark(msg)) - if self.on_debug(2, in_forward=False): - weight_graphs = {g.name: g.inspect() for g in self._weight_graphs} - title = self.tool_mark(f"WEIGHT_GRAPHS({len(weight_graphs)})") - self._logger.debug(msc_utils.msg_block(title, weight_graphs)) - return graphs, weights - - def _get_wtypes(self) -> Tuple[Dict[str, List[str]], Dict[str, str]]: - """Get the weight types from options - - Returns - ------- - main_wtypes: dict> - The main weight types. - relation_wtypes: dict - The relation weight types - """ - - raise NotImplementedError("_get_wtypes is not implemented in WeightTool") - - def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict): - """Save runner to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - cache_info: dict - The cache_info - """ - - assert "weight_graphs" in cache_info, ( - "weight_graphs should be given in cache_info, get " + str(cache_info) - ) - self._weight_graphs = [ - WeightGraph.from_json(cache_dir.relpath(f)) for f in cache_info["weight_graphs"] - ] - msg = f"load {len(self._weight_graphs)} weight graphs from {cache_dir}" - self._logger.debug(self.tool_mark(msg)) - - def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict: - """Save runner to cache - - Parameters - ------- - cache_dir: MSCDirectory - cache path for save/load info - - Returns - ------- - cache_info: dict - The cache_info. - """ - - cache_info = {"weight_graphs": [g.name + "_graph.json" for g in self._weight_graphs]} - with cache_dir: - for graph, f_path in zip(self._weight_graphs, cache_info["weight_graphs"]): - with open(f_path, "w") as f_graph: - f_graph.write(graph.to_json()) - return cache_info - - def visualize(self, visual_dir: msc_utils.MSCDirectory): - """Visualize MSCGraphs - - Parameters - ------- - visual_dir: MSCDirectory - Visualize path for saving graph - """ - - for w_graph in self._weight_graphs: - w_graph.visualize(visual_dir.relpath(w_graph.name + ".prototxt")) - super().visualize(visual_dir) - - def get_w_nodes(self) -> Iterable[WeightJoint]: - """Get all the weight nodes in the weight_graphs. - - Returns - ------- - nodes: generator - The generator of weight nodes. - """ - - for g in self._weight_graphs: - yield from g.get_nodes() - - def has_w_node(self, name: str) -> bool: - """Check if name in weight_graphs. - - Parameters - ---------- - name: string - The name of the node. - - Returns - ------- - has_node: bool - Whether node in weight_graphs. - """ - - for g in self._weight_graphs: - if g.has_node(name): - return True - return False - - def find_w_node(self, name: str) -> WeightJoint: - """Find weight node by name. - - Parameters - ---------- - name: string - The name of the node. - - Returns - ------- - node: WeightJoint - The found node. - """ - - for g in self._weight_graphs: - if g.has_node(name): - return g.find_node(name) - raise Exception(f"Can not find node {name} from graphs") - - def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]: - """Get the input output axes - - Parameters - ---------- - w_node: WeightJoint - The weight node. - - Returns - ------- - axes: (int, int) - The input output axis. - """ - - if w_node.weight.ndim == 1: - return 0, 0 - if w_node.has_attr("in_axis") and w_node.has_attr("out_axis"): - return int(w_node.get_attr("in_axis")), int(w_node.get_attr("out_axis")) - in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O") - if in_axis >= 0 and out_axis >= 0: - return in_axis, out_axis - if w_node.weight.ndim == 2 and w_node.weight.dim_at("N") > 0: - io_axis = 1 - w_node.weight.layout_of("N") - return io_axis, io_axis - if w_node.weight.layout_of("C") >= 0: - return w_node.weight.layout_of("C"), w_node.weight.layout_of("C") - raise Exception("Can not infer in_axis/out_axis from " + str(w_node)) - - @classmethod - def tool_type(cls): - return ToolType.WEIGHT diff --git a/python/tvm/contrib/msc/core/tools/track/__init__.py b/python/tvm/contrib/msc/core/tools/track/__init__.py deleted file mode 100644 index da4f97731226..000000000000 --- a/python/tvm/contrib/msc/core/tools/track/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.track""" - -from .tracker import * -from .method import * -from .configer import * diff --git a/python/tvm/contrib/msc/core/tools/track/configer.py b/python/tvm/contrib/msc/core/tools/track/configer.py deleted file mode 100644 index 279e0a58f1c7..000000000000 --- a/python/tvm/contrib/msc/core/tools/track/configer.py +++ /dev/null @@ -1,66 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.track.configer""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.configer import ToolConfiger -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils import MSCStage - - -class TrackConfiger(ToolConfiger): - """Configer for track""" - - @classmethod - def tool_type(cls): - return ToolType.TRACKER - - -@msc_utils.register_tool_configer -class DefaultTrackConfiger(TrackConfiger): - """Default configer for track""" - - def config_tool(self) -> dict: - """Get the default config of tool - - Returns - ------- - config: dict - The default config. - """ - - return { - "plan_file": "msc_tracker.json", - "strategys": [ - { - "methods": { - "output": { - "method_name": "save_compared", - "compare_to": { - MSCStage.OPTIMIZE: [MSCStage.BASELINE], - MSCStage.COMPILE: [MSCStage.OPTIMIZE, MSCStage.BASELINE], - }, - } - }, - "op_types": ["nn.relu"], - } - ], - } - - @classmethod - def config_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/track/method.py b/python/tvm/contrib/msc/core/tools/track/method.py deleted file mode 100644 index 1076151ca443..000000000000 --- a/python/tvm/contrib/msc/core/tools/track/method.py +++ /dev/null @@ -1,99 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.core.tools.track.method""" - -from typing import Dict, List - -import numpy as np - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.tool import BaseTool, ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -@msc_utils.register_tool_method -class TrackMethod: - """Default track method""" - - @classmethod - def save_compared( - cls, - tracker: BaseTool, - data: np.ndarray, - name: str, - consumer: str, - compare_to: Dict[str, List[str]], - ) -> np.ndarray: - """Compare and save the data - - Parameters - ---------- - tracker: BaseTracker - The tracker - data: np.ndarray - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - compare_to: dict - The compare config - - Returns - ------- - plan: dict - The plan of the tensor. - """ - - data = msc_utils.cast_array(data) - config = {"info": msc_utils.inspect_array(data)} - # save the data - tracker._saver.save_datas({name: data}, tracker._forward_cnt) - tracker.debug_tensors(name, consumer, "save_compares", {"save": data}) - # compare datas - if tracker._stage in compare_to: - diffs = {} - for stage in compare_to[tracker._stage]: - if stage in tracker._loaders: - if not tracker._loaders[stage].has_data(name, tracker._forward_cnt): - continue - golden = tracker._loaders[stage].load_data(name, tracker._forward_cnt) - report = msc_utils.compare_arrays({name: golden}, {name: data}) - diff_msg = "{} to {} -> {}".format(name, stage, report["info"][name]) - if report["passed"] == 0: - tracker._logger.info(tracker.msg_mark(diff_msg)) - elif tracker.on_debug(): - tracker._logger.debug(tracker.msg_mark(diff_msg)) - diffs[stage] = { - "pass": report["passed"] == 1, - "info": msc_utils.inspect_array(np.abs(golden - data)), - } - config["diffs"] = diffs - return config - - @classmethod - def framework(cls): - return MSCFramework.MSC - - @classmethod - def tool_type(cls): - return ToolType.TRACKER - - @classmethod - def method_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/tools/track/tracker.py b/python/tvm/contrib/msc/core/tools/track/tracker.py deleted file mode 100644 index 28a316558b6e..000000000000 --- a/python/tvm/contrib/msc/core/tools/track/tracker.py +++ /dev/null @@ -1,197 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.tools.track.tracker""" - -from typing import Any, List - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.tool import BaseTool, ToolStrategy, ToolType - - -class BaseTracker(BaseTool): - """Base tracker for all""" - - def setup(self) -> dict: - """Setup the tool - - Returns - ------- - info: dict - The setup info. - """ - - suffix = "." + msc_utils.MSCStage.TRACK - if self._stage.endswith(suffix): - self.change_stage(self._stage[: -len(suffix)]) - - data_folder = msc_utils.get_dataset_dir().create_dir("Track") - self._loaders = {} - for folder in data_folder.listdir(): - if folder == self._stage: - continue - if msc_utils.is_simple_dataset(data_folder.relpath(folder)): - self._loaders[folder] = msc_utils.SimpleDataLoader(data_folder.relpath(folder)) - self._saver = msc_utils.SimpleDataSaver(data_folder.relpath(self._stage)) - self._max_iter, self._tracked = self._options.get("max_iter", 1), False - info = super().setup() - info.update({"saver": self._saver, "loaders": self._loaders}) - return info - - def finalize(self) -> dict: - """Get the plan""" - - self._saver.finalize() - return {} - - def _execute_after_forward(self, output: Any) -> Any: - """Execute after model forward - - Parameters - ---------- - output: Any - The output reference of the model. - - Returns - ------- - output: Any - The modified output reference. - """ - - if self._forward_cnt < self._max_iter: - passed = {} - for info in self._plan.values(): - if "diffs" not in info[self._stage]: - continue - for stage, p_info in info[self._stage]["diffs"].items(): - if stage not in passed: - passed[stage] = {"total": 0, "passed": 0} - passed[stage]["total"] += 1 - if p_info["pass"]: - passed[stage]["passed"] += 1 - msg = f"Track({self._stage})[{self._forward_cnt}] {len(self._plan)} datas" - if passed: - msg += ", passed -> " - msg += "; ".join( - ["{}: {}/{}".format(s, i["passed"], i["total"]) for s, i in passed.items()] - ) - self._logger.info(self.msg_mark(msg, in_forward=False)) - else: - self._tracked = True - return output - - def _check_tensor(self, name: str, consumer: str) -> bool: - """Check if the tensor should be processed - - Parameters - ------- - name: str - The name of the tensor. - consumer: str - The name of the consumer. - - Returns - ------- - vaild: bool - Whether to process the tensor. - """ - - if self._forward_cnt >= self._max_iter: - return False - strategy = self._get_tensor_strategy(name, consumer) - if not strategy: - return False - compare_to = strategy.get_config().get("compare_to", {}) - if self._stage in compare_to: - return True - for stages in compare_to.values(): - if self._stage in stages: - return True - return False - - def _process_tensor( - self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[ToolStrategy] - ) -> Any: - """Process tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - return self._track_tensor(tensor, name, consumer, strategys) - - def _track_tensor( - self, tensor: Any, name: str, consumer: str, strategys: List[ToolStrategy] - ) -> Any: - """Process tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - if self._stage in self._plan.get(name, {}): - return tensor - plan = self._plan.setdefault(name, {}).setdefault(self._stage, {}) - for strategy in strategys: - plan.update(strategy(self, tensor, name, consumer)) - return tensor - - @property - def tracked(self): - return self._tracked - - @classmethod - def tool_type(cls): - return ToolType.TRACKER - - @classmethod - def apply_once(cls): - return True - - -@msc_utils.register_tool -class DefaultTracker(BaseTracker): - @classmethod - def tool_style(cls): - return "default" diff --git a/python/tvm/contrib/msc/core/transform/__init__.py b/python/tvm/contrib/msc/core/transform/__init__.py deleted file mode 100644 index ec7459780359..000000000000 --- a/python/tvm/contrib/msc/core/transform/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.transform""" - -from .pattern import * -from .transform import * diff --git a/python/tvm/contrib/msc/core/transform/pattern.py b/python/tvm/contrib/msc/core/transform/pattern.py deleted file mode 100644 index 07b5cfd88fbf..000000000000 --- a/python/tvm/contrib/msc/core/transform/pattern.py +++ /dev/null @@ -1,620 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.core.transform.pattern""" - -from functools import partial -from typing import Dict, List, Mapping, Optional, Tuple - -import tvm -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.utils.namespace import MSCKey, MSCMap -from tvm.relax.backend.pattern_registry import register_patterns -from tvm.relax.dpl import pattern as relax_pattern -from tvm.relax.transform import PatternCheckContext - - -def msc_attrs_getter( - annotated_expr: Dict[str, tvm.relax.Expr], - anchor: str = "out", - output: Optional[str] = None, - inputs: Optional[List[str]] = None, -) -> Dict[str, str]: - """Get attributes for fused pattern - - Parameters - ---------- - annotated_expr: dict - The annotated exprs during fus pattern - anchor: str - The anchor key of expr - - Returns - ------- - attrs: dict - The extra attributes for msc. - """ - - attrs = {} - # get name - fused_cnt = MSCMap.get(MSCKey.FUSED_CNT, 0) - unique_name = "msc_fused_" + str(fused_cnt) - if anchor in annotated_expr: - name = msc_utils.get_expr_name(annotated_expr[anchor]) - if name: - unique_name = name - MSCMap.set(MSCKey.FUSED_CNT, fused_cnt + 1) - attrs[_ffi_api.ToAttrKey("unique")] = unique_name - # get output layout - output = output or anchor - if output in annotated_expr: - attrs[_ffi_api.ToAttrKey("layout")] = msc_utils.get_expr_layout(annotated_expr[output]) - if inputs: - layouts = {} - for i in inputs: - if i not in annotated_expr: - continue - in_name = msc_utils.get_expr_name(annotated_expr[i]) - if not in_name: - continue - layouts[in_name] = msc_utils.get_expr_layout(annotated_expr[i]) - if layouts: - attrs[_ffi_api.ToAttrKey("input_layouts")] = layouts - return attrs - - -def make_relax_conv_bias_pattern( - op_name: str, -) -> Tuple[relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]]: - """A simple utility to create patterns for an conv fused with bias. - - Parameters - ---------- - op_name: str - The name of a Relax op, such as "relax.nn.conv2d" - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing a conv_bias operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - data = relax_pattern.wildcard() - weight = relax_pattern.is_const() - conv = relax_pattern.is_op(op_name)(data, weight) - bias = relax_pattern.is_const() - shape = relax_pattern.wildcard() - reshape = relax_pattern.is_op("relax.reshape")(bias, shape) - out = relax_pattern.is_op("relax.add")(conv, reshape) - annotations = { - "data": data, - "weight": weight, - "conv": conv, - "bias": bias, - "reshape": reshape, - "out": out, - } - return out, annotations - - -def _check_relax_conv_bias(context: PatternCheckContext) -> bool: - """Check if conv_bias fuse pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - bias = context.annotated_expr["bias"] - reshape = context.annotated_expr["reshape"] - non_one_dims = len([i for i in reshape.struct_info.shape.values if i > 1]) - return non_one_dims <= 1 and bias.struct_info.ndim == 1 - - -def make_relax_linear_pattern() -> Tuple[ - relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] -]: - """A simple utility to create patterns for linear. - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing a linear operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - data = relax_pattern.wildcard() - weight = relax_pattern.is_const() - permute = relax_pattern.is_op("relax.permute_dims")(weight) - out = relax_pattern.is_op("relax.matmul")(data, permute) - annotations = {"data": data, "weight": weight, "permute": permute, "linear": out} - return out, annotations - - -def _check_relax_linear(context: PatternCheckContext) -> bool: - """Check if linear pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - weight = context.annotated_expr["weight"] - permute = context.annotated_expr["permute"] - return weight.struct_info.ndim == 2 and not permute.attrs["axes"] - - -def make_relax_linear_bias_pattern() -> Tuple[ - relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] -]: - """A simple utility to create patterns for linear with bias. - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing a linear_bias operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - - """ - - linear, annotations = make_relax_linear_pattern() - bias = relax_pattern.is_const() - out = relax_pattern.is_op("relax.add")(linear, bias) - annotations.update({"bias": bias, "bias_add": out}) - return out, annotations - - -def _check_relax_linear_bias(context: PatternCheckContext) -> bool: - """Check if linear_bias pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - if not _check_relax_linear(context): - return False - bias = context.annotated_expr["bias"] - return bias.struct_info.ndim == 1 - - -def make_relax_embedding_pattern() -> Tuple[ - relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] -]: - """A simple utility to create patterns for embedding. - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing a embedding operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - weight = relax_pattern.is_const() - data = relax_pattern.wildcard() - astype = relax_pattern.is_op("relax.astype")(data) - out = relax_pattern.is_op("relax.take")(weight, astype) - annotations = {"data": data, "weight": weight, "astype": astype, "take": out} - return out, annotations - - -def _check_relax_embedding(context: PatternCheckContext) -> bool: - """Check if 1d embedding pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - weight = context.annotated_expr["weight"] - astype = context.annotated_expr["astype"] - return ( - astype.attrs["dtype"] == "int32" - and weight.struct_info.ndim == 2 - and weight.struct_info.dtype == "float32" - ) - - -def make_relax_reshape_embedding_pattern() -> Tuple[ - relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] -]: - """A simple utility to create patterns for reshaped embedding. - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing a reshape_embedding operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - weight = relax_pattern.is_const() - data = relax_pattern.wildcard() - astype = relax_pattern.is_op("relax.astype")(data) - reduce_shape = relax_pattern.wildcard() - reduce_in = relax_pattern.is_op("relax.reshape")(astype, reduce_shape) - take = relax_pattern.is_op("relax.take")(weight, reduce_in) - expand_shape = relax_pattern.wildcard() - out = relax_pattern.is_op("relax.reshape")(take, expand_shape) - annotations = { - "data": data, - "weight": weight, - "astype": astype, - "reduce_in": reduce_in, - "take": take, - "out": out, - } - return out, annotations - - -def _check_relax_reshape_embedding(context: PatternCheckContext) -> bool: - """Check if reshape embedding pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - weight = context.annotated_expr["weight"] - if weight.struct_info.ndim != 2 or weight.struct_info.dtype != "float32": - return False - astype = context.annotated_expr["astype"] - reduce_in = context.annotated_expr["reduce_in"] - if astype.attrs["dtype"] != "int32" or reduce_in.struct_info.ndim != 1: - return False - return True - - -def make_relax_attention_pattern() -> Tuple[ - relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] -]: - """A simple utility to create patterns for attention. - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing a attention operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - weight_q = relax_pattern.wildcard() - weight_k = relax_pattern.wildcard() - weight_v = relax_pattern.wildcard() - q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q) - k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k) - v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v) - attention = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans) - out = relax_pattern.is_op("relax.permute_dims")(attention) - annotations = { - "weight_q": weight_q, - "weight_k": weight_k, - "weight_v": weight_v, - "q_trans": q_trans, - "k_trans": k_trans, - "v_trans": v_trans, - "attention": attention, - "out": out, - } - return out, annotations - - -def _check_relax_attention(context: PatternCheckContext) -> bool: - """Check if attention pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - return True - - -def make_relax_mask_attention_pattern() -> Tuple[ - relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] -]: - """A simple utility to create patterns for mask_attention. - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing a mask_attention operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - weight_q = relax_pattern.wildcard() - weight_k = relax_pattern.wildcard() - weight_v = relax_pattern.wildcard() - mask = relax_pattern.wildcard() - q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q) - k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k) - v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v) - attention = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask) - out = relax_pattern.is_op("relax.permute_dims")(attention) - annotations = { - "weight_q": weight_q, - "weight_k": weight_k, - "weight_v": weight_v, - "mask": mask, - "q_trans": q_trans, - "k_trans": k_trans, - "v_trans": v_trans, - "attention": attention, - "out": out, - } - return out, annotations - - -def _check_relax_mask_attention(context: PatternCheckContext) -> bool: - """Check if mask_attention pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - return True - - -def make_opt_relax_conv_bias_pattern( - op_name: str, -) -> Tuple[relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern]]: - """Create patterns for an conv2d fused with bias, for mod after optimize. - - Parameters - ---------- - op_name: str - The name of a Relax op, such as "relax.nn.conv2d" - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing a conv_bias operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - data = relax_pattern.wildcard() - weight = relax_pattern.is_const() - conv = relax_pattern.is_op(op_name)(data, weight) - bias = relax_pattern.is_const() - out = relax_pattern.is_op("relax.add")(conv, bias) - annotations = {"data": data, "weight": weight, "bias": bias, "conv": conv, "out": out} - return out, annotations - - -def _check_opt_relax_conv_bias(context: PatternCheckContext) -> bool: - """Check if conv_bias fuse pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - ndim_conv = len(context.annotated_expr["conv"].struct_info.shape.values) - ndim_bias = len(context.annotated_expr["bias"].struct_info.shape.values) - ndim_out = len(context.annotated_expr["out"].struct_info.shape.values) - return ndim_conv == ndim_bias and ndim_bias == ndim_out - - -def make_opt_relax_linear_pattern() -> Tuple[ - relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] -]: - """Create patterns for an linear, for mod after optimize. - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing a conv_bias operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - data = relax_pattern.wildcard() - weight = relax_pattern.is_const() - out = relax_pattern.is_op("relax.matmul")(data, weight) - annotations = {"data": data, "weight": weight, "linear": out} - return out, annotations - - -def _check_opt_relax_linear(context: PatternCheckContext) -> bool: - """Check if linear fuse pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - ndim_weight = len(context.annotated_expr["weight"].struct_info.shape.values) - return ndim_weight == 2 - - -def make_opt_relax_linear_bias_pattern() -> Tuple[ - relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] -]: - """Create patterns for an linear_bias, for mod after optimize. - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing a conv_bias operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - data = relax_pattern.wildcard() - weight = relax_pattern.is_const() - linear = relax_pattern.is_op("relax.matmul")(data, weight) - bias = relax_pattern.is_const() - out = relax_pattern.is_op("relax.add")(linear, bias) - annotations = {"data": data, "weight": weight, "bias": bias, "linear": linear, "out": out} - return out, annotations - - -def _check_opt_relax_linear_bias(context: PatternCheckContext) -> bool: - """Check if linear fuse pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - if not _check_opt_relax_linear(context): - return False - ndim_bias = len(context.annotated_expr["bias"].struct_info.shape.values) - ndim_out = len(context.annotated_expr["out"].struct_info.shape.values) - return ndim_bias == 1 or ndim_bias == ndim_out - - -# TODO(tong.meng): support patterns after optimize -register_patterns( - [ - ( - "msc.conv1d_bias", - *make_opt_relax_conv_bias_pattern( - "relax.nn.conv1d", - ), - _check_opt_relax_conv_bias, - partial(msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"]), - ), - ( - "msc.conv2d_bias", - *make_opt_relax_conv_bias_pattern( - "relax.nn.conv2d", - ), - _check_opt_relax_conv_bias, - partial(msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"]), - ), - ( - "msc.linear", - *make_opt_relax_linear_pattern(), - _check_opt_relax_linear, - partial(msc_attrs_getter, anchor="linear", inputs=["data", "weight"]), - ), - ( - "msc.linear_bias", - *make_opt_relax_linear_bias_pattern(), - _check_opt_relax_linear_bias, - partial(msc_attrs_getter, anchor="linear", inputs=["data", "weight", "bias"]), - ), - ( - "msc.conv1d_bias", - *make_relax_conv_bias_pattern( - "relax.nn.conv1d", - ), - _check_relax_conv_bias, - partial(msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"]), - ), - ( - "msc.conv2d_bias", - *make_relax_conv_bias_pattern( - "relax.nn.conv2d", - ), - _check_relax_conv_bias, - partial(msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"]), - ), - ( - "msc.linear", - *make_relax_linear_pattern(), - _check_relax_linear, - partial(msc_attrs_getter, anchor="linear", inputs=["data", "weight"]), - ), - ( - "msc.linear_bias", - *make_relax_linear_bias_pattern(), - _check_relax_linear_bias, - partial(msc_attrs_getter, anchor="linear", inputs=["data", "weight", "bias"]), - ), - ( - "msc.embedding", - *make_relax_embedding_pattern(), - _check_relax_embedding, - partial(msc_attrs_getter, anchor="take", inputs=["data", "weight"]), - ), - ( - "msc.embedding", - *make_relax_reshape_embedding_pattern(), - _check_relax_reshape_embedding, - partial(msc_attrs_getter, anchor="take", output="out", inputs=["data", "weight"]), - ), - ( - "msc.attention", - *make_relax_attention_pattern(), - _check_relax_attention, - partial( - msc_attrs_getter, anchor="attention", inputs=["weight_q", "weight_k", "weight_v"] - ), - ), - ( - "msc.attention", - *make_relax_mask_attention_pattern(), - _check_relax_mask_attention, - partial( - msc_attrs_getter, - anchor="attention", - inputs=["weight_q", "weight_k", "weight_v", "mask"], - ), - ), - ] -) diff --git a/python/tvm/contrib/msc/core/transform/transform.py b/python/tvm/contrib/msc/core/transform/transform.py deleted file mode 100644 index 2b8fb7fc571d..000000000000 --- a/python/tvm/contrib/msc/core/transform/transform.py +++ /dev/null @@ -1,142 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""tvm.contrib.msc.core.transform.transform""" - -from typing import Dict, Optional - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.relax.transform import _ffi_api - - -def SetExprName( - entry_name: str = "main", - target: str = "", - var_names: Optional[Dict[str, str]] = None, -) -> tvm.ir.transform.Pass: - """Set name for the call and constant in IRModule. - - Parameters - ---------- - entry_name: str - The entry name - target: str - The target prefix for target functions - var_names: dict - The var names. - - Returns - ------- - ret: tvm.ir.transform.Pass - """ - - var_names = var_names or {} - var_names = {k: msc_utils.legalize_expr_name(v) for k, v in var_names.items()} - return _ffi_api.SetRelaxExprName(entry_name, target, var_names) # type: ignore - - -def SetExprLayout(allow_missing: bool = True, entry_name: str = "main") -> tvm.ir.transform.Pass: - """Set layout for the var and constant in IRModule. - - Parameters - ---------- - allow_missing: bool - Whether allow missing layouts. - entry_name: str - The entry name - - Returns - ------- - ret: tvm.ir.transform.Pass - """ - - return _ffi_api.SetExprLayout(allow_missing, entry_name) # type: ignore - - -def InlineParams(entry_name: str = "main") -> tvm.ir.transform.Pass: - """Bind ShapeExpr to reshape - - Parameters - ---------- - entry_name: str - The entry name - - Returns - ------- - ret: tvm.ir.transform.Pass - """ - - return _ffi_api.InlineParams(entry_name) # type: ignore - - -def FuseTuple(target, entry_name: str = "main") -> tvm.ir.transform.Pass: - """Fuse Tuple and TupleGetItem to target - - Parameters - ---------- - target: str - The byoc target name - entry_name: str - The entry name - - Returns - ------- - ret: tvm.ir.transform.Pass - """ - - return _ffi_api.FuseTuple(target, entry_name) # type: ignore - - -def SetBYOCAttrs(target, entry_name: str = "main") -> tvm.ir.transform.Pass: - """set attributes for byoc - - Parameters - ---------- - target: str - The byoc target name - entry_name: str - The entry name - - Returns - ------- - ret: tvm.ir.transform.Pass - """ - - return _ffi_api.SetBYOCAttrs(target, entry_name) # type: ignore - - -def BindNamedParams( - func_name: str, - params: Dict[str, tvm.runtime.Tensor], -) -> tvm.ir.transform.Pass: - """Bind params of function of the module to constant tensors with span names. - - Parameters - ---------- - func_name: str - The function name to be bound - params: dict - The map from parameter or parameter name to constant - tensors. - - Returns - ------- - ret: tvm.ir.transform.Pass - """ - - return _ffi_api.BindNamedParams(func_name, params) # type: ignore diff --git a/python/tvm/contrib/msc/core/utils/__init__.py b/python/tvm/contrib/msc/core/utils/__init__.py deleted file mode 100644 index 641860e12e2a..000000000000 --- a/python/tvm/contrib/msc/core/utils/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.utils""" - -from .expr import * -from .info import * -from .file import * -from .namespace import * -from .register import * -from .dataset import * -from .log import * -from .message import * -from .arguments import * diff --git a/python/tvm/contrib/msc/core/utils/arguments.py b/python/tvm/contrib/msc/core/utils/arguments.py deleted file mode 100644 index 7fae33191a6c..000000000000 --- a/python/tvm/contrib/msc/core/utils/arguments.py +++ /dev/null @@ -1,263 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E722 -"""tvm.contrib.msc.core.utils.arguments""" - -import copy -import json -import os -from typing import Any - -from .info import MSCArray - - -def load_dict(str_dict: str, flavor: str = "json") -> dict: - """Load the string/file to dict. - - Parameters - ---------- - str_dict: string - The file_path or string object. - flavor: str - The flavor for load. - - Returns - ------- - dict_obj: dict - The loaded dict. - """ - - if not str_dict: - return {} - if isinstance(str_dict, str) and os.path.isfile(str_dict): - with open(str_dict) as f: - dict_obj = json.load(f) - elif isinstance(str_dict, str): - dict_obj = json.loads(str_dict) - elif isinstance(str_dict, dict): - dict_obj = copy_dict(str_dict) - else: - raise Exception(f"Unexpected str_dict {str_dict}({type(str_dict)})") - assert flavor == "json", "Unexpected flavor for load_dict: " + str(flavor) - return dict_obj - - -def save_dict(dict_obj: Any, path: str, indent: int = 2) -> str: - """Save dict object - - Parameters - ---------- - dict_obj: - The object that can be load as dict. - path: str - The output path. - indent: int - The indent - - Returns - ------- - path: str - The output path. - """ - - with open(path, "w") as f: - f.write(json.dumps(load_dict(dict_obj), indent=indent)) - return path - - -def update_dict(src_dict: dict, new_dict: dict, soft_update: bool = False) -> dict: - """Update src_dict with new_dict. - - Parameters - ---------- - src_dict: dict - The source dict. - new_dict: dict - The new dict. - soft_update: bool - Whether to update the source dict, False to force update. - - Returns - ------- - dict_obj: dict - The updated dict. - """ - - if not new_dict: - return src_dict - assert isinstance(src_dict, dict) and isinstance(new_dict, dict), ( - f"update_dict only support dict, get src {type(src_dict)} and new {type(new_dict)}" - ) - for k, v in new_dict.items(): - if not src_dict.get(k): - src_dict[k] = v - elif isinstance(v, dict): - v = update_dict(src_dict.get(k, {}), v, soft_update) - src_dict[k] = v - elif not soft_update: - src_dict[k] = v - return src_dict - - -def dump_dict(dict_obj: dict, flavor: str = "dmlc") -> str: - """Dump the config to string. - - Parameters - ---------- - src_dict: dict - The source dict. - flavor: str - The flavor for dumps. - - Returns - ------- - str_dict: string - The dumped string. - """ - - if not dict_obj: - return "" - if flavor == "dmlc": - return json.dumps({k: int(v) if isinstance(v, bool) else v for k, v in dict_obj.items()}) - if flavor.startswith("table:"): - - def _get_lines(value, indent=2): - max_size = int(flavor.split(":")[1]) - indent - 2 - lines = [] - for k, v in value.items(): - if v is None: - continue - if isinstance(v, (dict, tuple, list)) and not v: - continue - if isinstance(v, dict) and len(str(k) + str(v)) > max_size: - lines.append("{}{}:".format(indent * " ", k)) - lines.extend(_get_lines(v, indent + 2)) - elif isinstance(v, (tuple, list)) and len(str(k) + str(v)) > max_size: - if MSCArray.is_array(v): - lines.append("{}{}: {}".format(indent * " ", k, MSCArray(v).abstract())) - else: - lines.append("{}{}:".format(indent * " ", k)) - for idx, ele in enumerate(v): - if isinstance(ele, dict) and len(str(ele)) > max_size: - lines.append("{}[{}.{}]:".format((indent + 2) * " ", k, idx)) - lines.extend(_get_lines(ele, indent + 4)) - else: - lines.append("{}<{}>{}".format((indent + 2) * " ", idx, ele)) - elif isinstance(v, bool): - lines.append("{}{}: {}".format(indent * " ", k, "true" if v else "false")) - elif MSCArray.is_array(v): - lines.append("{}{}: {}".format(indent * " ", k, MSCArray(v).abstract())) - elif hasattr(v, "__name__"): - lines.append("{}{}: {}({})".format(indent * " ", k, v.__name__, type(v))) - else: - lines.append("{}{}: {}".format(indent * " ", k, v)) - return lines - - lines = _get_lines(dict_obj) or [f" {k}: {v}" for k, v in dict_obj.items()] - return "\n".join(lines) - return json.dumps(dict_obj) - - -def dict_equal(dict_a: dict, dict_b: dict) -> bool: - """Check if two dicts are the same. - - Parameters - ---------- - dict_a: dict - The A dict. - dict_b: dict - The B dict. - - Returns - ------- - equal: bool - Whether two dicts are the same. - """ - - if not isinstance(dict_a, dict) or not isinstance(dict_b, dict): - return False - if dict_a.keys() != dict_b.keys(): - return False - for k, v in dict_a.items(): - if not isinstance(v, type(dict_b[k])): - return False - if isinstance(v, dict) and not dict_equal(v, dict_b[k]): - return False - if v != dict_b[k]: - return False - return True - - -def copy_dict(dict_obj: dict) -> dict: - """Deepcopy dict object - - Parameters - ---------- - dict_obj: dict - The source dict. - - Returns - ------- - dict_obj: dict - The copied dict. - """ - - if not dict_obj: - return {} - try: - return copy.deepcopy(dict_obj) - except: # pylint: disable=bare-except - new_dict = {} - for k, v in dict_obj.items(): - if isinstance(v, (list, tuple)): - new_dict[k] = [copy_dict(e) for e in v] - elif isinstance(v, dict): - new_dict[k] = copy_dict(v) - else: - new_dict[k] = v - return new_dict - - -def map_dict(dict_obj: dict, mapper: callable) -> dict: - """Apply mapper to dict object - - Parameters - ---------- - dict_obj: dict - The source dict. - mapper: callable - The mapper function. - - Returns - ------- - new_dict: dict - The mapped dict. - """ - - if not dict_obj: - return {} - new_dict = {} - for k, v in dict_obj.items(): - if isinstance(v, (tuple, list)): - new_dict[k] = [ - map_dict(mapper(e), mapper) if isinstance(e, dict) else mapper(e) for e in v - ] - elif isinstance(v, dict): - new_dict[k] = map_dict(mapper(v), mapper) - else: - new_dict[k] = mapper(v) - return new_dict diff --git a/python/tvm/contrib/msc/core/utils/dataset.py b/python/tvm/contrib/msc/core/utils/dataset.py deleted file mode 100644 index 8bf5c856a255..000000000000 --- a/python/tvm/contrib/msc/core/utils/dataset.py +++ /dev/null @@ -1,609 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.core.utils.dataset""" - -import json -import os -import shutil -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np - -import tvm - -from .arguments import load_dict -from .info import cast_array, is_array -from .namespace import MSCFramework - - -def format_datas(datas: Union[List[Any], Dict[str, Any]], names: List[str], style="dict") -> Any: - """Format datas to style format - - Parameters - ---------- - datas: - The source datas. - names: list - The data names. - style: str - The style of format, dict|list. - - Returns - ------- - datas: - The formated datas. - """ - - if isinstance(datas, (list, tuple, tvm.ir.container.Array)): - assert len(datas) == len(names), f"datas({len(datas)}) mismatch with names {names}" - datas = dict(zip(names, datas)) - if not isinstance(datas, dict): - assert len(names) == 1, "Expect 1 names, get " + str(names) - datas = {names[0]: datas} - elif len(datas) > len(names): - datas = {n: datas[n] for n in datas} - assert all(is_array(d) for d in datas.values()), "Expected all tensors as array like" - if style == "dict": - return datas - if style == "list": - return [datas[n] for n in names] - raise TypeError("Unexpected style " + str(style)) - - -def random_data( - info: Union[List, Tuple, dict], - framework: str = MSCFramework.MSC, - device: str = "cpu", - max_val: Optional[int] = None, -) -> Any: - """Create random data from info - - Parameters - ---------- - info: list| tuple| dict - The data info. - framework: str - The framework. - device: str - The device. - """ - - if isinstance(info, (tuple, list)): - if len(info) == 1: - info = {"name": "data", "shape": info[0], "dtype": "float32"} - elif len(info) == 2: - info = {"name": "data", "shape": info[0], "dtype": info[1]} - elif len(info) == 3: - info = {"name": info[0], "shape": info[1], "dtype": info[2]} - else: - raise Exception("Unexpected info " + str(info)) - assert isinstance(info, dict) and all(key in info for key in ["shape", "dtype"]), ( - "shape and dtype should be given to create randome data" - ) - if info["dtype"] in ("int32", "int64"): - if max_val is None: - data = np.zeros(info["shape"]).astype(info["dtype"]) - else: - data = np.random.randint(0, high=max_val, size=info["shape"]).astype(info["dtype"]) - elif info["dtype"] == "bool": - data = np.random.rand(*info["shape"]).astype("float32") - data = np.where(data >= 0.5, True, False) - else: - data = np.random.rand(*info["shape"]).astype(info["dtype"]) - if max_val is not None: - data *= max_val - return cast_array(data, framework, device=device) - - -class BaseDataLoader: - """Basic dataset loader for MSC - - Parameters - ---------- - folder: string - The dataset folder path. - start: int - The start position. - end: int - The end position. - """ - - def __init__(self, folder: str, start: int = 0, end: int = -1): - self._folder = folder - self._start = start - self._current = 0 - assert os.path.isdir(folder), f"Dataset {folder} is not folder" - self._info = load_dict(os.path.join(folder, "datas_info.json")) - if end == -1: - self._end = self._info["num_datas"] - else: - self._end = min(end, self._info["num_datas"]) - - def __str__(self): - return f"<{self.__class__.__name__}> @ {self._folder}" - - def __getitem__(self, idx): - if idx + self._start >= self._end: - raise StopIteration("Reach End") - return self._load_batch(idx) - - def __next__(self): - if self._current + self._start >= self._end: - raise StopIteration("Reach End") - batch = self._load_batch(self._current) - self._current += 1 - return batch - - def __len__(self): - return self._end - self._start - - def reset(self): - self._current = 0 - - def has_data(self, name: str, index: int) -> bool: - """Check if data exist. - - Parameters - ------- - name: str - The name of the data. - index: int - The index of the data. - - Returns - ------- - has_data: bool - Whether the data can be load. - """ - - info = self._data_info(name) - if not info: - return False - save_name = info.get("save_name", name) - f_path = os.path.join(self._folder, save_name, f"batch_{self._start + index}.bin") - return os.path.isfile(f_path) - - def load_data(self, name: str, index: int) -> np.ndarray: - """Load data by name. - - Parameters - ------- - name: str - The name of the data. - index: int - The index of the data. - - Returns - ------- - data: np.ndarray - The loaded data. - """ - - return self._load_data(name, index, self._data_info(name)) - - def _load_data(self, name: str, index: int, info: dict) -> np.ndarray: - """Load data from file. - - Parameters - ------- - name: str - The name of the data. - index: int - The index of the data. - info: dict - The info of the data. - - Returns - ------- - data: np.ndarray - The loaded data. - """ - - save_name = info.get("save_name", name) - f_path = os.path.join(self._folder, save_name, f"batch_{self._start + index}.bin") - assert os.path.isfile(f_path), "Can not find data file " + str(f_path) - return np.fromfile(f_path, dtype=info["dtype"]).reshape(info["shape"]) - - def _load_batch(self, index: int) -> Any: - """Get batch data - - Parameters - ------- - index: int - The index for the batch. - - Returns - ------- - batch: Any - The batch data. - """ - - raise NotImplementedError("_load_batch is not implemented for BaseDataLoader") - - def _data_info(self, name: str) -> dict: - """Get info of data - - Parameters - ------- - name: str - The name of data. - - Returns - ------- - info: dict - The info of data. - """ - - raise NotImplementedError("_data_info is not implemented for BaseDataLoader") - - @property - def num_datas(self): - return self.info["num_datas"] - - @property - def folder(self): - return self._folder - - @property - def info(self): - return self._info - - -class SimpleDataLoader(BaseDataLoader): - """Dataset Loader for simple datas""" - - def _load_batch(self, index: int) -> Any: - """Get batch data - - Parameters - ------- - index: int - The index for the batch. - - Returns - ------- - batch: Any - The batch data. - """ - - assert "datas" in self._info, "datas shoule be given to load batch" - return {n: self._load_data(n, index, i) for n, i in self._info["datas"].items()} - - def _data_info(self, name: str) -> dict: - """Get info of data - - Parameters - ------- - name: str - The name of data. - - Returns - ------- - info: dict - The info of data. - """ - - return self._info["datas"].get(name) - - -class IODataLoader(BaseDataLoader): - """Dataset Loader for Input/Output datas""" - - def _load_batch(self, index: int) -> Any: - """Get batch data - - Parameters - ------- - index: int - The index for the batch. - - Returns - ------- - batch: Any - The batch data. - """ - - if "inputs" in self._info: - inputs = {n: self._load_data(n, index, i) for n, i in self._info["inputs"].items()} - else: - inputs = {} - if "outputs" in self._info: - outputs = {n: self._load_data(n, index, i) for n, i in self._info["outputs"].items()} - else: - outputs = {} - return inputs, outputs - - def _data_info(self, name: str) -> dict: - """Get info of data - - Parameters - ------- - name: str - The name of data. - - Returns - ------- - info: dict - The info of data. - """ - - if name in self._info["inputs"]: - return self._info["inputs"][name] - return self._info["outputs"].get(name) - - -class BaseDataSaver: - """Dataset Saver for MSC - - Parameters - ---------- - folder: string - The dataset folder path. - options: dict - The extra options for the data saver - start: int - The start position. - max_size: int - The max size for datas. - """ - - def __init__( - self, - folder: str, - options: Optional[dict] = None, - start: int = 0, - max_size: int = -1, - ): - if os.path.isdir(folder): - shutil.rmtree(folder) - os.mkdir(folder) - self._folder = folder - self._start = start - self._max_size = max_size - self._current = 0 - assert os.path.isdir(folder), f"Dataset {folder} is not folder" - self._info = self.setup(options) - - def setup(self, options: dict): - return {"num_datas": 0} - - def __str__(self): - return f"<{self.__class__.__name__}> @ {self._folder}" - - def __enter__(self): - return self - - def __exit__(self, exception_type, exception_value, traceback): - self.finalize() - - def finalize(self): - """Finalize the saver""" - - self._info["num_datas"] = self._current - with open(os.path.join(self._folder, "datas_info.json"), "w") as f: - f.write(json.dumps(self._info, indent=2)) - - def is_finalized(self) -> bool: - """Check if the saver is finalized - - Returns - ------- - is_finalized: bool - Whether the saver is finalized. - """ - - return os.path.isfile(os.path.join(self._folder, "datas_info.json")) - - def reset(self): - self._current = 0 - - def _save_data(self, index: int, name: str, data: np.ndarray, collect: str) -> str: - """Save data to file. - - Parameters - ------- - index: int - The index - name: str - The name of the data. - data: np.ndarray - The data to be saved. - collect: str - The collect of data. - - Returns - ------- - data_path: str - The folder that data saved to. - """ - - data = cast_array(data) - save_name = name.replace("/", "_").replace(":", "_") - sub_folder = f_path = os.path.join(self._folder, save_name) - if not os.path.isdir(sub_folder): - os.mkdir(sub_folder) - f_path = os.path.join(sub_folder, f"batch_{self._start + index}.bin") - ref_info = self._info[collect] - # TODO(mengtong): support dynamic datas shape - if name in ref_info: - assert ref_info[name]["dtype"] == data.dtype.name, ( - "dtype {} mismatch with saved {}".format(data.dtype.name, ref_info[name]["dtype"]) - ) - assert ref_info[name]["shape"] == list(data.shape), ( - "shape {} mismatch with saved {}".format(data.shape, ref_info[name]["shape"]) - ) - else: - ref_info[name] = { - "shape": list(data.shape), - "dtype": data.dtype.name, - "bytes": data.size * data.itemsize, - "save_name": save_name, - } - data.tofile(f_path) - return sub_folder - - def _save_batch(self, *args, **kwargs) -> dict: - """Save a batch data""" - - raise NotImplementedError("_save_batch is not implemented for BaseDataSaver") - - @property - def num_datas(self): - if self.is_finalized(): - return self.info["num_datas"] - return self._current - - @property - def folder(self): - return self._folder - - @property - def info(self): - return self._info - - -class SimpleDataSaver(BaseDataSaver): - """Dataset Saver for simple datas""" - - def save_datas(self, datas: Dict[str, np.ndarray], index: int = -1) -> Dict[str, str]: - """Save 1 simple datas. - - Parameters - ------- - datas: dict - The datas to be saved. - indec: int - The current index - - Returns - ------- - datas_path: dict - The data paths. - """ - - datas_path = {} - current = self._current if index < 0 else index - for name, data in datas.items(): - datas_path[name] = self._save_data(current, name, data, "datas") - if index > 0: - self._current = index - else: - self._current += 1 - return datas_path - - def setup(self, options: dict): - return {"datas": {}, "num_datas": 0} - - -class IODataSaver(BaseDataSaver): - """Dataset Saver for inputs/outputs""" - - def setup(self, options: dict): - assert "input_names" in options, "input_names should be given to setup IODataSaver" - self._input_names = options["input_names"] - self._output_names = options.get("output_names", []) - return { - "inputs": {}, - "outputs": {}, - "num_datas": 0, - "input_names": self._input_names, - "output_names": self._output_names, - } - - def finalize(self): - """Finalize the saver""" - - super().finalize() - if any(n not in self._info["inputs"] for n in self._input_names): - return - with open(os.path.join(self._folder, "datas_info.txt"), "w") as f: - for name in self._input_names: - info = self._info["inputs"][name] - f.write("{} {} {}\n".format(name, info.get("save_name", name), info["bytes"])) - for name in self._output_names: - if name not in self._info["outputs"]: - continue - info = self._info["outputs"][name] - f.write("{} {} {}\n".format(name, info.get("save_name", name), info["bytes"])) - - def is_finalized(self) -> bool: - """Check if the saver is finalized - - Returns - ------- - is_finalized: bool - Whether the saver is finalized. - """ - - if not super().is_finalized(): - return False - return os.path.isfile(os.path.join(self._folder, "datas_info.txt")) - - def save_batch( - self, - inputs: Union[Dict[str, np.ndarray], List[np.ndarray]], - outputs: Optional[Union[Dict[str, np.ndarray], List[np.ndarray]]] = None, - ) -> int: - """Save 1 batch inputs and outputs. - - Parameters - ------- - inputs: list/dict - The inputs datas. - outputs: list/dict - The outputs datas. - - Returns - ------- - current: int - The current batch cnt. - """ - - inputs = format_datas(inputs, self._input_names, style="dict") - for name, data in inputs.items(): - self._save_data(self._current, name, data, "inputs") - if outputs is not None: - outputs = format_datas(outputs, self._output_names, style="dict") - for name, data in outputs.items(): - self._save_data(self._current, name, data, "outputs") - self._current += 1 - return self._current - - -def is_io_dataset(folder: str) -> bool: - """Check if a folder is IO dataset""" - - if not isinstance(folder, str): - return False - if not os.path.isfile(os.path.join(folder, "datas_info.json")): - return False - data_info = load_dict(os.path.join(folder, "datas_info.json")) - if any(key not in data_info for key in ["inputs", "outputs", "num_datas"]): - return False - return data_info["num_datas"] > 0 - - -def is_simple_dataset(folder: str) -> bool: - """Check if a folder is simple dataset""" - - if not os.path.isfile(os.path.join(folder, "datas_info.json")): - return False - data_info = load_dict(os.path.join(folder, "datas_info.json")) - if any(key not in data_info for key in ["datas", "num_datas"]): - return False - return data_info["num_datas"] > 0 diff --git a/python/tvm/contrib/msc/core/utils/expr.py b/python/tvm/contrib/msc/core/utils/expr.py deleted file mode 100644 index e5aed4782e9b..000000000000 --- a/python/tvm/contrib/msc/core/utils/expr.py +++ /dev/null @@ -1,218 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.utils.expr""" - -import copy -from typing import Dict, List, Optional - -import tvm -from tvm import relax -from tvm.contrib.msc.core import _ffi_api -from tvm.relax import PyExprVisitor - - -def legalize_expr_name(name: str, symbols: Optional[List[str]] = None, dst: str = "_") -> str: - """Legalize expr name - - Parameters - ---------- - name: str - The source name. - symbols: list - The symbols to be replaced. - dst: str - The symbol for replace. - - Returns - ------- - name: str - The legialized name. - """ - - symbols = symbols or ["::", "/", "."] - for sym in symbols: - name = name.replace(sym, dst) - return name.strip(dst) - - -def get_expr_name(expr: relax.Expr) -> str: - """Get name hint for expr - - Parameters - ---------- - expr: Expr - The Expr of relax. - - Returns - ------- - name: str - The name_hint of expr - """ - - name = _ffi_api.SpanGetAttr(expr.span, _ffi_api.ToAttrKey("name")) - if not name and isinstance(expr, relax.Var): - return expr.name_hint - return name - - -def make_span(kwargs: Dict[str, str], span: relax.Span = None) -> relax.Span: - """Make a span from kwargs - - Parameters - ---------- - kwargs: dict - The attrs in span. - span: relax.Span - The source span. - - Returns - ------- - span: relax.Span - The span. - """ - - span = span or relax.Span(tvm.ir.SourceName(""), 0, 0, 0, 0) - for k, v in kwargs.items(): - span = _ffi_api.SpanSetAttr(span, _ffi_api.ToAttrKey(k), v) - return span - - -def set_expr_name(expr: relax.Expr, name: str): - """Set the name for expr - - Parameters - ---------- - expr: Expr - The Expr of relax. - name: str - The name. - - Returns - ------- - expr: Expr - The expr with name. - """ - - expr.span = make_span({"name": name}, expr.span) - return expr - - -def get_expr_layout(expr: relax.Expr) -> str: - """Get layout for expr - - Parameters - ---------- - expr: Expr - The Expr of relax. - - Returns - ------- - layout: str - The layout of expr - """ - - return _ffi_api.SpanGetAttr(expr.span, _ffi_api.ToAttrKey("layout")) - - -def get_span_attrs(mod: tvm.IRModule) -> dict: - """Extract the span attributes from relax.Function. - - Parameters - ---------- - mod: IRModule - The IRModule of relax. - - Returns - ------- - attrs: dict - """ - - @relax.expr_functor.visitor - class SpanVisitor(PyExprVisitor): - """Visitor for get attributes in span""" - - def extract(self, expr: relax.Expr) -> dict: - self._span_info = {} - self._local_funcs = {} - if isinstance(expr, relax.Expr): - self.visit_expr(expr) - elif isinstance(expr, relax.BindingBlock): - self.visit_binding_block(expr) - return self._span_info - - def _update_attrs(self, expr: relax.Expr, name: str = "") -> None: - if not expr.span: - return - name = name or get_expr_name(expr) - if not name: - return - self._span_info[name] = dict(_ffi_api.SpanGetAttrs(expr.span)) - - def visit_var_binding_(self, binding: relax.VarBinding) -> None: - if isinstance(binding.value, relax.expr.Function): - self._local_funcs[binding.var] = binding.value - elif ( - isinstance(binding.value, relax.expr.Call) and binding.value.op in self._local_funcs - ): - cache_info = copy.deepcopy(self._span_info) - func_info = self.extract(self._local_funcs[binding.value.op]) - self._span_info = cache_info - self._span_info[binding.value.op.name_hint] = func_info - else: - super().visit_var_binding_(binding) - self._update_attrs(binding.value, binding.var.name_hint) - - def visit_constant_(self, op: relax.Constant) -> None: - super().visit_constant_(op) - self._update_attrs(op) - - def visit_var_(self, op: relax.Var) -> None: - super().visit_var_(op) - self._update_attrs(op, op.name_hint) - - return {v.name_hint: SpanVisitor().extract(mod[v]) for v in mod.functions} - - -def msc_script(mod: tvm.IRModule, script: str = "") -> str: - """Add span attrs after lines. - - Parameters - ---------- - mod: IRModule - The IRModule of relax. - script: string - The script to be replaced - - Returns - ------- - script: string - The replaced script - """ - - script = script or str(mod) - attrs = get_span_attrs(mod) - cur_attr, lines = {}, [] - for line in script.split("\n"): - if line.strip().startswith("def "): - func_name = line.strip().split("def ")[1].split("(")[0] - cur_attr = attrs.get(func_name, {}) - if ": " in line: - v_name = line.strip().split(": ")[0] - if v_name in cur_attr: - line += " # " + ", ".join([f"{k}={v}" for k, v in cur_attr[v_name].items()]) + " #" - lines.append(line) - return "\n".join(lines) diff --git a/python/tvm/contrib/msc/core/utils/file.py b/python/tvm/contrib/msc/core/utils/file.py deleted file mode 100644 index 7323f076d367..000000000000 --- a/python/tvm/contrib/msc/core/utils/file.py +++ /dev/null @@ -1,536 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.utils.file""" - -import os -import shutil -import subprocess -import tempfile -import types -from functools import partial -from importlib.machinery import SourceFileLoader -from typing import Any, List, Optional, Union - -from .namespace import MSCFramework, MSCKey, MSCMap -from .register import get_registered_func - - -def is_callable(name: str, framework: str = MSCFramework.MSC) -> bool: - """Check if name is callable. - - Parameters - ---------- - name: string - The name of the registered func or path:f_name str. - framework: string - Should be from MSCFramework. - - Returns - ------- - is_callable: bool - Whether the name is callable - """ - - func = get_registered_func(name, framework) - if func: - return True - if ".py:" in name: - path, _ = name.split(":") - return os.path.isfile(path) - return False - - -def load_callable(name: str, framework: str = MSCFramework.MSC) -> callable: - """Load a callable object. - - Parameters - ---------- - name: string - The name of the registered func or path:f_name str. - framework: string - Should be from MSCFramework. - - Returns - ------- - func: callable - The function. - """ - - func = get_registered_func(name, framework) - if func: - return func - if ".py:" in name: - path, func_name = name.split(":") - loader = SourceFileLoader(path.replace(".py", ""), path) - mod = types.ModuleType(loader.name) - loader.exec_module(mod) - return getattr(mod, func_name) - raise Exception("Func {} is neighter registered nor path.py:name string") - - -class MSCDirectory: - """Create a directory manager for MSC""" - - def __init__( - self, path: Optional[str] = None, keep_history: bool = True, cleanup: bool = False - ): - self._path = os.path.abspath(path or tempfile.mkdtemp()) - self._cleanup = cleanup - self._cwd = os.getcwd() - if os.path.isdir(self._path) and not keep_history: - shutil.rmtree(self._path) - if not os.path.isdir(self._path): - os.mkdir(self._path) - - def __str__(self): - return f"{self._path}(Cleanup: {self._cleanup}): {len(self.listdir())} Files" - - def __enter__(self): - if not os.path.isdir(self._path): - os.mkdir(self._path) - os.chdir(self._path) - return self - - def __exit__(self, exception_type, exception_value, traceback): - os.chdir(self._cwd) - self.clean_up() - - def __del__(self): - self.clean_up() - - def clean_up(self): - """Clean up the dir""" - - if self._cleanup and os.path.isdir(self._path): - shutil.rmtree(self._path) - - def add_file(self, name: str, contains: str) -> str: - """Add a file under the folder - - Parameters - ---------- - name: str - The name of the file. - contains: str - The contains of the file. - - Returns - ------- - path: str - The abs file path. - """ - - file_path = self.relpath(name) - base_dir = os.path.dirname(name) - if base_dir and not os.path.isdir(base_dir): - os.makedirs(base_dir) - with open(file_path, "w") as f: - f.write(contains) - return file_path - - def move(self, src_path: str, dst_path: Optional[str] = None): - """Move a file or folder to another folder - - Parameters - ---------- - src_path: str - The name of the source file or folder. - dst_path: str - The target file name or folder path. - - Returns - ------- - path: str - The abs file path. - """ - - if src_path != os.path.abspath(src_path): - src_path = os.path.join(self.relpath(src_path)) - assert os.path.isfile(src_path), f"Source path {src_path} not exist" - if not dst_path: - dst_path = self.relpath(os.path.basename(src_path)) - if dst_path != os.path.abspath(dst_path): - dst_path = self.relpath(dst_path) - os.rename(src_path, dst_path) - return dst_path - - def copy(self, src_path: str, dst_path: Optional[str] = None) -> str: - """Copy a file to another folder - - Parameters - ---------- - src_path: str - The name of the source file or folder. - dst_path: str - The target file name or folder path. - - Returns - ------- - path: str - The abs file path. - """ - - if not src_path: - return None - if src_path != os.path.abspath(src_path): - src_path = os.path.join(self.relpath(src_path)) - assert os.path.exists(src_path), f"Source path {src_path} not exist" - if not dst_path: - dst_path = self.relpath(os.path.basename(src_path)) - if dst_path != os.path.abspath(dst_path): - dst_path = self.relpath(dst_path) - if os.path.isfile(src_path): - shutil.copy2(src_path, dst_path) - else: - if os.path.isdir(dst_path): - shutil.rmtree(dst_path) - shutil.copytree(src_path, dst_path) - return dst_path - - def copy_to(self, dst_path: str): - """Copy dir to another folder - - Parameters - ---------- - dst_path: str - The target folder path. - - Returns - ------- - path: str - The abs file path. - """ - - return self.copy(self._path, dst_path) - - def create_dir(self, name: str, keep_history: bool = True, cleanup: bool = False) -> Any: - """Add a dir under the folder - - Parameters - ---------- - name: str - The name of the file. - keep_history: bol - Whether to keep history. - cleanup: bool - Whether to clean up before exit. - - - Returns - ------- - dir: MSCDirectory - The created dir. - """ - - dir_path = self.relpath(name) - if os.path.isfile(dir_path): - os.remove(dir_path) - return self.__class__(dir_path, keep_history=keep_history, cleanup=cleanup) - - def relpath(self, name: str, keep_history: bool = True) -> str: - """Relative path in dir - - Parameters - ---------- - name: str - The name of the file. - - Returns - ------- - path: str - The concatenated path. - """ - - f_path = os.path.join(self._path, name) - if os.path.isfile(f_path) and not keep_history: - os.remove(f_path) - if os.path.isdir(f_path) and not keep_history: - shutil.rmtree(f_path) - return f_path - - def listdir(self, as_abs: bool = False) -> List[str]: - """List contents in the dir. - - Parameters - ---------- - as_abs: bool - Whether to show abs path. - - Returns - ------- - names: list - The content of directory - """ - - if not os.path.isdir(self._path): - return [] - if as_abs: - return [os.path.join(self._path, f) for f in os.listdir(self._path)] - return os.listdir(self._path) - - def finalize(self): - """Finalize the directory""" - - if not os.path.isdir(self._path): - return self._path - - def _remove_empty(path: str): - sub_paths = [os.path.join(path, f) for f in os.listdir(path)] - for s_path in sub_paths: - if not os.path.isdir(s_path): - continue - if len(os.listdir(s_path)) == 0: - shutil.rmtree(s_path) - else: - _remove_empty(s_path) - if len(os.listdir(path)) == 0: - shutil.rmtree(path) - return path - - return _remove_empty(self._path) - - def destory(self): - """Destory the dir.""" - - if os.path.isdir(self._path): - shutil.rmtree(self._path) - - @property - def path(self): - return self._path - - -def msc_dir( - path: Optional[str] = None, keep_history: bool = True, cleanup: bool = False -) -> MSCDirectory: - """Create MSCDirectory - - Parameters - ---------- - path: str - The path of the dir. - keep_history: bool - Whether to remove files before start. - cleanup: bool - Whether to clean up before exit. - - Returns - ------- - dir: MSCDirectory - The created dir. - """ - - return MSCDirectory(path, keep_history, cleanup) - - -def set_workspace( - path: Union[str, MSCDirectory] = None, keep_history: bool = True, cleanup: bool = False -) -> MSCDirectory: - """Create MSCDirectory as worksapce and set to map - - Parameters - ---------- - path: str - The path of the dir. - keep_history: bool - Whether to remove files before start. - cleanup: bool - Whether to clean up before exit. - - Returns - ------- - dir: MSCDirectory - The created dir. - """ - - if isinstance(path, MSCDirectory): - MSCMap.set(MSCKey.WORKSPACE, path) - return path - path = path or "msc_workspace" - workspace = MSCDirectory(path, keep_history, cleanup) - MSCMap.set(MSCKey.WORKSPACE, workspace) - return workspace - - -def get_workspace() -> MSCDirectory: - """Get workspace from MSCMap - - Returns - ------- - dir: MSCDirectory - The worksapce dir. - """ - - workspace = MSCMap.get(MSCKey.WORKSPACE) - assert workspace, "Can not find workspace, please call set_workspace" - return workspace - - -class ChangeWorkspace: - """Change the workspace - - Parameters - ---------- - new_workspace: MSCDirectory - The new workspace. - """ - - def __init__(self, new_workspace: MSCDirectory): - self._src_workspace = get_workspace() - self._new_workspace = new_workspace - - def __enter__(self): - set_workspace(self._new_workspace) - - def __exit__(self, exception_type, exception_value, traceback): - set_workspace(self._src_workspace) - - -def change_workspace(new_workspace: MSCDirectory): - """Change the workspace - - Parameters - ---------- - new_workspace: MSCDirectory - The new workspace. - """ - - return ChangeWorkspace(new_workspace) - - -def get_workspace_subdir( - name: Optional[str] = None, keep_history: bool = True, cleanup: bool = False -) -> MSCDirectory: - """Create sub dir for workspace - - Parameters - ---------- - name: str - The sub dir name under workspace. - keep_history: bool - Whether to remove files before start. - cleanup: bool - Whether to clean up before exit. - - Returns - ------- - dir: MSCDirectory - The created dir. - """ - - return get_workspace().create_dir(name, keep_history, cleanup) - - -def to_abs_path(path: str, root_dir: MSCDirectory = None, keep_history: bool = True) -> str: - """Change path to abs path - - Parameters - ---------- - path: str - The path of the file. - root_dir: MSCDirectory - Root dir to save the file. - keep_history: bool - Whether to remove files before start. - - Returns - ------- - abs_path: str - The abspath. - """ - - root_dir = root_dir or get_workspace() - if os.path.abspath(path) == path: - return path - return root_dir.relpath(path, keep_history) - - -def pack_folder(path: str, dst: Optional[str] = None, style="tar.gz"): - """Pack the folder - - Parameters - ---------- - path: str - The path of the folder. - dst: str - The pakced path. - style: str - The pack style. - - Returns - ------- - pack_path: str - The packed path. - """ - - dst = dst or path + "." + style - root = os.path.dirname(path) - if style == "tar.gz": - cmd = f"tar --exculde={dst} -zcvf {dst} {path} && rm -rf {path}" - else: - raise NotImplementedError(f"Pack style {style} is not supported") - if root: - with msc_dir(root): - retcode = subprocess.call(cmd, shell=True) - else: - retcode = subprocess.call(cmd, shell=True) - assert retcode == 0, f"Failed to pack the folder {path}->{dst}({style}): {retcode}" - return dst - - -def unpack_folder(path: str, dst: Optional[str] = None, style="tar.gz"): - """UnPack the folder - - Parameters - ---------- - path: str - The path of the folder. - dst: str - The pakced path. - style: str - The pack style. - - Returns - ------- - pack_path: str - The packed path. - """ - - dst = dst or path.split(".")[0] - root = os.path.dirname(path) - if style == "tar.gz": - cmd = f"tar -zxvf {path} {dst}" - else: - raise NotImplementedError(f"Pack style {style} is not supported") - if root: - with msc_dir(root): - retcode = subprocess.call(cmd, shell=True) - else: - retcode = subprocess.call(cmd, shell=True) - assert retcode == 0, f"Failed to unpack the folder {path}->{dst}({style}): {retcode}" - return dst - - -get_build_dir = partial(get_workspace_subdir, name="Build") -get_cache_dir = partial(get_workspace_subdir, name="Cache") -get_config_dir = partial(get_workspace_subdir, name="Config") -get_dataset_dir = partial(get_workspace_subdir, name="Dataset") -get_gym_dir = partial(get_workspace_subdir, name="Gym") -get_info_dir = partial(get_workspace_subdir, name="Info") -get_output_dir = partial(get_workspace_subdir, name="Output") -get_visual_dir = partial(get_workspace_subdir, name="Visual") -get_weights_dir = partial(get_workspace_subdir, name="Weights") diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py deleted file mode 100644 index 1b5b428c7207..000000000000 --- a/python/tvm/contrib/msc/core/utils/info.py +++ /dev/null @@ -1,431 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E722 -"""tvm.contrib.msc.core.utils.info""" - -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -from packaging.version import parse - -import tvm -import tvm.testing -from tvm.contrib.msc.core import _ffi_api - -from .namespace import MSCFramework - - -class MSCArray: - """MSC wrapper for array like object - - Parameters - ---------- - data: array_like: np.ndarray| torch.Tensor| tvm.ndarray| ... - The data object. - """ - - def __init__(self, data: Any): - self._meta_data = data - self._framework, self._type, self._device = self._analysis(data) - - def __str__(self): - return f"<{self._framework} @{self._device}>{self.abstract()}" - - def _analysis(self, data: Any) -> Tuple[str, str, np.ndarray]: - if isinstance(data, (list, tuple)) and all(isinstance(d, (int, float)) for d in data): - return MSCFramework.MSC, "list", "cpu" - if isinstance(data, np.ndarray): - return MSCFramework.MSC, "tensor", "cpu" - if isinstance(data, tvm.runtime.Tensor): - device = tvm.runtime.Device._DEVICE_TYPE_TO_NAME[data.device.dlpack_device_type()] - if data.device.index: - device += f":{data.device.index}" - return MSCFramework.TVM, "tensor", device - if isinstance(data, tvm.relax.Var): - return MSCFramework.TVM, "var", "cpu" - try: - import torch # pylint: disable=import-outside-toplevel - - if isinstance(data, torch.Tensor): - ref_dev = data.device - if ref_dev.index: - device = f"{ref_dev.type}:{ref_dev.index}" - else: - device = ref_dev.type - return MSCFramework.TORCH, "tensor", device - except: # pylint: disable=bare-except - pass - - raise Exception(f"Unkonwn data {data}({type(data)})") - - def abstract(self) -> str: - """Get abstract describe of the data""" - - data = self._to_tensor() - prefix = "[{},{}]".format(";".join([str(s) for s in data.shape]), data.dtype.name) - if data.size < 10: - return "{} {}".format(prefix, ",".join([str(i) for i in data.flatten()])) - return f"{prefix} Max {data.max():g}, Min {data.min():g}, Avg {data.sum() / data.size:g}" - - def _to_tensor(self) -> np.ndarray: - """Cast array like object to np.ndarray - - Returns - ------- - data: np.ndarray - The data as np.ndarray. - """ - - if self._framework == MSCFramework.MSC: - if self._type == "list": - return np.array(self._meta_data) - return self._meta_data - if self._framework == MSCFramework.TVM: - if self._type == "var": - shape = [int(s) for s in self._meta_data.struct_info.shape] - return np.zeros(shape, dtype=self._meta_data.struct_info.dtype) - return self._meta_data.numpy() - if self._framework == MSCFramework.TORCH: - return self._meta_data.detach().cpu().numpy() - return self._meta_data - - def _to_device(self, device: str) -> Any: - """Cast array like object to array like object - - Parameters - ---------- - device: str - The device for tensor. - - Returns - ------- - output: - The output as framework tensor. - """ - - if self._device == device: - return self._meta_data - if self._framework == MSCFramework.TORCH: - return self._meta_data.to(self.get_device(device)) - if self._framework == MSCFramework.TVM: - return tvm.runtime.tensor(self._cast_data(), device=self.get_device(device)) - return self._meta_data - - def cast(self, framework: str, device: str = "cpu") -> Any: - """Cast array like object to array like object - - Parameters - ---------- - framework: str - The target framework. - device: str - The device for tensor. - - Returns - ------- - output: - The output as framework tensor. - """ - - device = device or self._device - if framework == self._framework and device == self._device and self._type == "tensor": - return self._meta_data - if framework == self._framework: - return self._to_device(device) - data = self._to_tensor() - if framework == MSCFramework.TORCH: - import torch # pylint: disable=import-outside-toplevel - - return torch.from_numpy(data).to(self.get_device(device, framework)) - if framework == MSCFramework.TVM: - return tvm.runtime.tensor(data, device=self.get_device(device, framework)) - return data - - def get_device(self, device: str, framework: Optional[str] = None) -> Any: - """Change device from name to device obj - - Parameters - ---------- - device: str - The device for tensor. - framework: str - The target framework. - - Returns - ------- - device: any - The device object. - """ - - framework = framework or self._framework - if framework == MSCFramework.TVM: - if device.startswith("cpu"): - return tvm.cpu() - if device.startswith("cuda"): - dev_id = int(device.split(":")[1]) if ":" in device else 0 - return tvm.cuda(dev_id) - raise TypeError("Unexpected tvm device " + str(device)) - if framework == MSCFramework.TORCH: - import torch # pylint: disable=import-outside-toplevel - - return torch.device(device) - return device - - @classmethod - def is_array(cls, data: Any) -> bool: - """Check if the data is array like - - Parameters - ---------- - data: array_like: np.ndarray| torch.Tensor| tvm.ndarray| ... - The data object. - - Returns - ------- - is_array: bool - Whether the data is array like. - """ - - normal_types = (np.ndarray, tvm.runtime.Tensor, tvm.relax.Var) - if isinstance(data, normal_types): - return True - if isinstance(data, (list, tuple)) and all(isinstance(d, (int, float)) for d in data): - return True - try: - import torch # pylint: disable=import-outside-toplevel - - if isinstance(data, torch.Tensor): - return True - except: # pylint: disable=bare-except - pass - - return False - - @property - def framework(self): - return self._framework - - @property - def device(self): - return self._device - - @property - def type(self): - return self._type - - -def is_array(data: Any) -> bool: - """Check if the data is array - - Parameters - ---------- - data: array_like: np.ndarray| torch.Tensor| tvm.ndarray| ... - The data object. - - Returns - ------- - is_array: bool - Whether the data is array. - """ - - return MSCArray.is_array(data) - - -def cast_array(data: Any, framework: str = MSCFramework.MSC, device: str = "cpu") -> Any: - """Cast array like object to np.ndarray - - Parameters - ---------- - data: array_like: np.ndarray| torch.Tensor| tvm.ndarray| ... - The data object. - framework: str - The target framework. - device: str - The device for tensor. - - Returns - ------- - output: np.ndarray - The output as numpy array or framework tensor(if given). - """ - - assert MSCArray.is_array(data), f"{data} is not array like" - return MSCArray(data).cast(framework, device) - - -def inspect_array(data: Any, as_str: bool = True) -> Union[Dict[str, Any], str]: - """Inspect the array - - Parameters - ---------- - data: array like - The data to inspect - as_str: bool - Whether inspect the array as string. - - Returns - ------- - info: dict - The data info. - """ - - if not MSCArray.is_array(data): - return str(data) - if as_str: - return str(MSCArray(data)) - data = cast_array(data) - return { - "shape": list(data.shape), - "dtype": data.dtype.name, - "max": float(data.max()), - "min": float(data.min()), - "avg": float(data.sum() / data.size), - } - - -def compare_arrays( - golden: Dict[str, Any], - datas: Dict[str, Any], - atol: float = 1e-2, - rtol: float = 1e-2, - report_detail: bool = False, -) -> dict: - """Compare elements in array - - Parameters - ---------- - golden: dict - The golden datas. - datas: dict - The datas to be compared. - atol: float - The atol for compare. - rtol: float - The rtol for compare. - report_detail: bool - Whether to report detail - - Returns - ------- - report: dict - The compare results. - """ - - assert golden.keys() == datas.keys(), ( - f"golden {golden.keys()} and datas {datas.keys()} mismatch" - ) - golden = {k: cast_array(v) for k, v in golden.items()} - datas = {k: cast_array(v) for k, v in datas.items()} - report = {"total": 0, "passed": 0, "info": {}} - - def _add_report(name: str, gol: Any, data: Any, passed: bool): - diff = MSCArray(gol - data) - if passed: - if report_detail: - report["info"][name] = { - "data": MSCArray(data).abstract(), - "d_pass": diff.abstract(), - } - else: - report["info"][name] = f"d_pass: {diff.abstract()}" - report["passed"] += 1 - else: - if report_detail: - report["info"][name] = { - "gold": MSCArray(gol).abstract(), - "data": MSCArray(data).abstract(), - "d_fail": diff.abstract(), - } - else: - report["info"][name] = f"d_fail: {diff.abstract()}" - - for name, gol in golden.items(): - report["total"] += 1 - data = datas[name] - if list(gol.shape) != list(data.shape): - report["info"][name] = f"fail: shape mismatch [G]{gol.shape} vs [D]{data.shape}" - continue - if gol.dtype != data.dtype: - report["info"][name] = f"fail: dtype mismatch [G]{gol.dtype} vs [D]{data.dtype}" - continue - if gol.dtype.name in ("int32", "int64"): - passed = np.abs(gol - data).max() == 0 - _add_report(name, gol, data, passed) - continue - try: - tvm.testing.assert_allclose(gol, data, rtol=rtol, atol=atol, verbose=False) - _add_report(name, gol, data, True) - except: # pylint: disable=bare-except - _add_report(name, gol, data, False) - return report - - -def get_version(framework: str) -> List[int]: - """Get the version list of framework. - - Parameters - ---------- - framework: string - Should be from MSCFramework. - - Returns - ------- - version: list - The version in . - """ - - try: - if framework in (MSCFramework.MSC, MSCFramework.TVM): - raw_version = tvm.__version__ - elif framework == MSCFramework.TORCH: - import torch # pylint: disable=import-outside-toplevel - - raw_version = torch.__version__ - elif framework == MSCFramework.TENSORFLOW: - import tensorflow # pylint: disable=import-outside-toplevel - - raw_version = tensorflow.__version__ - if framework == MSCFramework.TENSORRT: - raw_version = ".".join( - [str(v) for v in tvm.get_global_func("relax.get_tensorrt_version")()] - ) - else: - raw_version = "1.0.0" - except: # pylint: disable=bare-except - raw_version = "1.0.0" - version = parse(raw_version or "1.0.0") - return [version.major, version.minor, version.micro] - - -def compare_version(given_version: List[int], target_version: List[int]) -> int: - """Compare version - - Parameters - ---------- - given_version: list - The version in . - - target_version: list - The version in . - - Returns - ------- - compare_res: int - The compare result: 0 for same version, 1 for greater version, -1 for less version - """ - - return int(_ffi_api.CompareVersion(given_version, target_version)) diff --git a/python/tvm/contrib/msc/core/utils/log.py b/python/tvm/contrib/msc/core/utils/log.py deleted file mode 100644 index 278e231a5bce..000000000000 --- a/python/tvm/contrib/msc/core/utils/log.py +++ /dev/null @@ -1,190 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.core.utils.log""" - -import logging -import os -from typing import Optional, Union - -from .file import get_workspace -from .namespace import MSCKey, MSCMap - - -class IOLogger: - """IO Logger for MSC""" - - def __init__(self): - self._printers = { - "red": (lambda m: print(f"\033[91m {m}\033[00m")), - "green": (lambda m: print(f"\033[92m {m}\033[00m")), - "yellow": (lambda m: print(f"\033[93m {m}\033[00m")), - "purple": (lambda m: print(f"\033[95m {m}\033[00m")), - "cyan": (lambda m: print(f"\033[96m {m}\033[00m")), - "gray": (lambda m: print(f"\033[97m {m}\033[00m")), - "black": (lambda m: print(f"\033[98m {m}\033[00m")), - } - - def info(self, msg): - self._printers["green"]("[MSC_INFO] " + str(msg)) - - def debug(self, msg): - self._printers["green"]("[MSC_DEBUG] " + str(msg)) - - def warning(self, msg): - self._printers["yellow"]("[MSC_WARNING] " + str(msg)) - - def error(self, msg): - self._printers["red"]("[MSC_ERROR] " + str(msg)) - raise Exception(msg) - - -def create_file_logger( - level: Union[str, int] = logging.INFO, path: Optional[str] = None -) -> logging.Logger: - """Create file logger - - Parameters - ---------- - level: logging level - The logging level. - path: str - The file path. - - Returns - ------- - logger: logging.Logger - The logger. - """ - - if isinstance(level, str): - if level.startswith("debug"): - level = logging.DEBUG - elif level == "info": - level = logging.INFO - elif level == "warn": - level = logging.WARN - elif level == "error": - level = logging.ERROR - elif level == "critical": - level = logging.CRITICAL - else: - raise Exception("Unexcept verbose {}, should be debug| info| warn") - - path = path or os.path.join(get_workspace(), "MSC_LOG") - log_name = os.path.basename(path) - logger = logging.getLogger(log_name) - logger.setLevel(level) - if any(isinstance(h, logging.FileHandler) and h.baseFilename == path for h in logger.handlers): - return logger - formatter = logging.Formatter( - "%(asctime)s %(filename)s[ln:%(lineno)d]<%(levelname)s> %(message)s" - ) - handlers = [ - logging.FileHandler(path, mode="a", encoding=None, delay=False), - logging.StreamHandler(), - ] - for handler in handlers: - handler.setLevel(level) - handler.setFormatter(formatter) - logger.addHandler(handler) - return logger - - -def set_global_logger( - level: Union[str, int] = logging.INFO, path: Optional[str] = None -) -> logging.Logger: - """Create file logger and set to global - - Parameters - ---------- - level: logging level - The logging level. - path: str - The file path. - - Returns - ------- - logger: logging.Logger - The logger. - """ - - logger = create_file_logger(level, path) - MSCMap.set(MSCKey.GLOBALE_LOGGER, logger) - return logger - - -def get_global_logger() -> logging.Logger: - """Get the global logger - - Returns - ------- - logger: logging.Logger - The logger. - """ - - if not MSCMap.get(MSCKey.GLOBALE_LOGGER): - MSCMap.set(MSCKey.GLOBALE_LOGGER, IOLogger()) - return MSCMap.get(MSCKey.GLOBALE_LOGGER) - - -def get_log_file(logger: logging.Logger) -> str: - """Get the log file from logger - - Parameters - ---------- - logger: logging.Logger - The logger. - - Returns - ------- - log_file: str - The log file. - """ - - for log_h in logger.handlers: - if isinstance(log_h, logging.FileHandler): - return log_h.baseFilename - return None - - -def remove_loggers(): - """Remove the logger handlers""" - - logger = MSCMap.get(MSCKey.GLOBALE_LOGGER) - if logger: - logger.handlers.clear() - - -def split_line(msg: str, symbol: str = "#", width: int = 100) -> str: - """Mark message to split line - - Parameters - ---------- - msg: str - The message. - symbol: str - The split symbol. - width: int - The line width. - - Returns - ------- - split_line: str - The split line with message. - """ - - return f"\n{20 * symbol}{msg.center(width - 40)}{20 * symbol}" diff --git a/python/tvm/contrib/msc/core/utils/message.py b/python/tvm/contrib/msc/core/utils/message.py deleted file mode 100644 index 830ef23cbc3b..000000000000 --- a/python/tvm/contrib/msc/core/utils/message.py +++ /dev/null @@ -1,173 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: RUF012 -"""tvm.contrib.msc.core.utils.message""" - -import datetime -import logging -from typing import List, Optional, Tuple - -from .arguments import dump_dict, map_dict -from .log import get_global_logger, split_line -from .namespace import MSCKey, MSCMap - - -class MSCStage: - """Enum all msc stage names""" - - SETUP = "setup" - PREPARE = "prepare" - PARSE = "parse" - PRUNE = "prune" - QUANTIZE = "quantize" - DISTILL = "distill" - TRACK = "track" - BASELINE = "baseline" - OPTIMIZE = "optimize" - COMPILE = "compile" - SUMMARY = "summary" - EXPORT = "export" - ALL = [ - SETUP, - PREPARE, - PARSE, - PRUNE, - QUANTIZE, - DISTILL, - TRACK, - BASELINE, - OPTIMIZE, - COMPILE, - SUMMARY, - EXPORT, - ] - - @classmethod - def all_stages(cls) -> List[str]: - """Get all stage names""" - return cls.ALL - - -def time_stamp(stage: str, log_stage: bool = True, logger: Optional[logging.Logger] = None): - """Mark the stamp and record time. - - Parameters - ---------- - stage: str - The stage name. - log_stage: bool - Whether to log the stage. - logger: logging.Logger - The logger. - """ - - logger = logger or get_global_logger() - time_stamps = MSCMap.get(MSCKey.TIME_STAMPS, []) - time_stamps.append((stage, datetime.datetime.now())) - MSCMap.set(MSCKey.TIME_STAMPS, time_stamps) - if stage in MSCStage.all_stages(): - if log_stage: - last_stage = MSCMap.get(MSCKey.MSC_STAGE) - if last_stage: - end_msg = f"End {last_stage.upper()}" - logger.info("%s\n", split_line(end_msg)) - start_msg = f"Start {stage.upper()}" - logger.info(split_line(start_msg)) - MSCMap.set(MSCKey.MSC_STAGE, stage.upper()) - elif log_stage: - start_msg = f"Start {stage}" - logger.debug(split_line(start_msg, "+")) - - -def get_duration() -> dict: - """Get duration of the whole process. - - Returns - ------- - duration: dict - The duration of the process. - """ - - time_stamps = MSCMap.get(MSCKey.TIME_STAMPS, []) - if not time_stamps: - return {} - - def _get_duration(idx): - return (time_stamps[idx + 1][1] - time_stamps[idx][1]).total_seconds() - - def _set_stage(stage: str, info: Tuple[float, dict], collect: dict): - if "." in stage: - main_stage, sub_stage = stage.split(".", 1) - _set_stage(sub_stage, info, collect.setdefault(main_stage, {})) - else: - collect[stage] = info - - def _set_total(collect: dict): - collect["total"] = 0 - for dur in collect.values(): - collect["total"] += _set_total(dur) if isinstance(dur, dict) else dur - return collect["total"] - - duration, depth = {}, 1 - left_durs = {time_stamps[i][0]: _get_duration(i) for i in range(len(time_stamps) - 1)} - while left_durs: - current_durs = {s: dur for s, dur in left_durs.items() if len(s.split(".")) == depth} - left_durs = {k: v for k, v in left_durs.items() if k not in current_durs} - for stage, dur in current_durs.items(): - info = {"init": dur} if any(s.startswith(stage + ".") for s in left_durs) else dur - _set_stage(stage, info, duration) - depth += 1 - - _set_total(duration) - - def _to_str(dur): - if not isinstance(dur, float): - return dur - return "{:.2f} s({:.2f}%)".format(dur, dur * 100 / duration["total"]) - - return map_dict(duration, _to_str) - - -def msg_block(title: str, msg: str, width: int = 100, symbol: str = "-"): - """Log message in block format - - Parameters - ---------- - title: str - The title of the block - msg: str - The message to log. - width: int - The max width of block message - symbol: str - The split symbol. - - Returns - ------- - msg: str - The block message. - """ - - if isinstance(msg, dict): - msg = dump_dict(msg, "table:" + str(width)) - return f"{split_line(title, symbol)}\n{msg}" - - -def current_stage(): - """Get the current stage""" - - return MSCMap.get(MSCKey.MSC_STAGE, "Unknown") diff --git a/python/tvm/contrib/msc/core/utils/namespace.py b/python/tvm/contrib/msc/core/utils/namespace.py deleted file mode 100644 index e3e663b8f2fc..000000000000 --- a/python/tvm/contrib/msc/core/utils/namespace.py +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: RUF012 -"""tvm.contrib.msc.core.utils.namespace""" - -import copy -from typing import Any, Optional - - -class MSCMap: - """Global Namespace map for MSC""" - - MAP = {} - - @classmethod - def set(cls, key: str, value: Any): - cls.MAP[key] = value - - @classmethod - def get(cls, key: str, default: Optional[Any] = None): - return cls.MAP.get(key, default) - - @classmethod - def clone(cls, key: str, default: Optional[Any] = None): - return copy.deepcopy(cls.get(key, default)) - - @classmethod - def delete(cls, key: str): - if key in cls.MAP: - return cls.MAP.pop(key) - return None - - @classmethod - def contains(cls, key: str): - return key in cls.MAP - - @classmethod - def reset(cls): - cls.MAP = {} - - -class MSCKey: - """Keys for the MSCMap""" - - WORKSPACE = "workspace" - VERBOSE = "verbose" - GLOBALE_LOGGER = "global_logger" - MSC_STAGE = "msc_stage" - TIME_STAMPS = "time_stamps" - - PRUNERS = "pruners" - QUANTIZERS = "quantizers" - DISTILLERS = "distillers" - TRACKERS = "trackers" - - FUSED_CNT = "fused_cnt" - ROOT_MARK = "$" - - -class MSCFramework: - """Framework type for the MSC""" - - MSC = "msc" - TVM = "tvm" - TORCH = "torch" - TENSORFLOW = "tensorflow" - TENSORRT = "tensorrt" diff --git a/python/tvm/contrib/msc/core/utils/register.py b/python/tvm/contrib/msc/core/utils/register.py deleted file mode 100644 index 23b9238baa5c..000000000000 --- a/python/tvm/contrib/msc/core/utils/register.py +++ /dev/null @@ -1,404 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: RUF012 -"""tvm.contrib.msc.core.utils.register""" - -from typing import Any, Optional - -from .namespace import MSCFramework - - -class MSCRegistery: - """The registery for MSC""" - - REGISTERY = {} - MSC_FUNCS = "msc_funcs" - TOOL_CLASSES = "tool_classes" - TOOL_METHODS = "tool_methods" - TOOL_CONFIGERS = "tool_configers" - GYM_CONFIGERS = "gym_configers" - GYM_CONTROLLERS = "gym_controllers" - GYM_OBJECTS = "gym_objects" - GYM_METHODS = "gym_agents_method" - RUNNER_HOOKS = "runner_hooks" - - @classmethod - def register(cls, key: str, value: Any): - cls.REGISTERY[key] = value - return value - - @classmethod - def unregister(cls, key: str): - if key in cls.REGISTERY: - return cls.REGISTERY.pop(key) - return None - - @classmethod - def get(cls, key: str, default: Optional[Any] = None) -> Any: - return cls.REGISTERY.get(key, default) - - @classmethod - def contains(cls, key: str): - return key in cls.REGISTERY - - @classmethod - def reset(cls): - cls.REGISTERY = {} - - -def register_global_func(name: str, func: callable, framework: str = MSCFramework.MSC): - """Register a func for framework. - - Parameters - ---------- - name: string - The name for the func. - func: callable - The function to be registered. - framework: string - Should be from MSCFramework. - """ - - funcs = MSCRegistery.get(MSCRegistery.MSC_FUNCS, {}) - if framework not in funcs: - funcs[framework] = {} - funcs[framework][name] = func - MSCRegistery.register(MSCRegistery.MSC_FUNCS, funcs) - - -def get_registered_func(name: str, framework: str = MSCFramework.MSC): - """Get the registered func of framework. - - Parameters - ---------- - name: string - The name for the func. - framework: string - Should be from MSCFramework. - - Returns - ------- - func: callable - The registered function. - """ - - funcs = MSCRegistery.get(MSCRegistery.MSC_FUNCS, {}) - if framework not in funcs: - return None - return funcs[framework].get(name) - - -def register_tool(tool: Any): - """Register a tool class. - - Parameters - ---------- - tool: class - The tool class to be registered. - """ - - for key in ["framework", "tool_type", "tool_style"]: - assert hasattr(tool, key), f"{key} should be given to register tool" - tools_classes = MSCRegistery.get(MSCRegistery.TOOL_CLASSES, {}) - col = tools_classes.setdefault(tool.framework(), {}).setdefault(tool.tool_type(), {}) - col[tool.tool_style()] = tool - MSCRegistery.register(MSCRegistery.TOOL_CLASSES, tools_classes) - return tool - - -def get_registered_tool(framework: str, tool_type: str, tool_style: str) -> Any: - """Get the registered tool class. - - Parameters - ---------- - framework: string - Should be from MSCFramework. - tool_type: string - The type of the tool prune| quantize| distill| debug. - tool_style: string - The style of the tool. - - Returns - ------- - tool: class - The registered tool class. - """ - - tools_classes = MSCRegistery.get(MSCRegistery.TOOL_CLASSES, {}) - if tool_style == "all": - return tools_classes.get(framework, {}).get(tool_type, {}) - return tools_classes.get(framework, {}).get(tool_type, {}).get(tool_style) - - -def register_tool_method(method: Any): - """Register a tool method. - - Parameters - ---------- - method: class - The method class. - """ - - for key in ["framework", "tool_type", "method_style"]: - assert hasattr(method, key), f"{key} should be given to register tool method" - tool_methods = MSCRegistery.get(MSCRegistery.TOOL_METHODS, {}) - col = tool_methods.setdefault(method.framework(), {}).setdefault(method.tool_type(), {}) - col[method.method_style()] = method - MSCRegistery.register(MSCRegistery.TOOL_METHODS, tool_methods) - return method - - -def get_registered_tool_method( - framework: str, tool_type: str, method_style: str = "default" -) -> Any: - """Get the registered tool method. - - Parameters - ---------- - framework: string - Should be from MSCFramework. - tool_type: string - The type of the tool prune| quantize| distill| debug. - method_style: string - The style of the method. - - Returns - ------- - method_cls: class - The method class. - """ - - tool_methods = MSCRegistery.get(MSCRegistery.TOOL_METHODS, {}) - return tool_methods.get(framework, {}).get(tool_type, {}).get(method_style) - - -def register_tool_configer(configer: Any): - """Register a tool configer. - - Parameters - ---------- - configer: class - The configer class. - """ - - for key in ["tool_type", "config_style"]: - assert hasattr(configer, key), f"{key} should be given to register tool configer" - tool_configers = MSCRegistery.get(MSCRegistery.TOOL_CONFIGERS, {}) - col = tool_configers.setdefault(configer.tool_type(), {}) - col[configer.config_style()] = configer - MSCRegistery.register(MSCRegistery.TOOL_CONFIGERS, tool_configers) - return configer - - -def get_registered_tool_configer(tool_type: str, config_style: str) -> Any: - """Get the registered configer. - - Parameters - ---------- - tool_type: string - The type of tool. - config_style: string - The style of tool. - - Returns - ------- - configer: class - The configer class. - """ - - tool_configers = MSCRegistery.get(MSCRegistery.TOOL_CONFIGERS, {}) - return tool_configers.get(tool_type, {}).get(config_style) - - -def register_gym_configer(configer: Any): - """Register a gym configer. - - Parameters - ---------- - configer: class - The configer class. - """ - - assert hasattr(configer, "config_type"), "config_type should be given to register configer" - gym_configers = MSCRegistery.get(MSCRegistery.GYM_CONFIGERS, {}) - gym_configers[configer.config_type()] = configer - MSCRegistery.register(MSCRegistery.GYM_CONFIGERS, gym_configers) - return configer - - -def get_registered_gym_configer(config_type: str) -> Any: - """Get the registered configer. - - Parameters - ---------- - config_type: string - The type of configer. - - Returns - ------- - configer: class - The configer class. - """ - - gym_configers = MSCRegistery.get(MSCRegistery.GYM_CONFIGERS, {}) - return gym_configers.get(config_type) - - -def register_gym_controller(controller: Any): - """Register a gym controller. - - Parameters - ---------- - controller: class - The controller class. - """ - - assert hasattr(controller, "control_type"), ( - "control_type should be given to register controller" - ) - gym_controllers = MSCRegistery.get(MSCRegistery.GYM_CONTROLLERS, {}) - gym_controllers[controller.control_type()] = controller - MSCRegistery.register(MSCRegistery.GYM_CONTROLLERS, gym_controllers) - return controller - - -def get_registered_gym_controller(control_type: str) -> Any: - """Get the registered controller. - - Parameters - ---------- - control_type: string - The type of controller. - - Returns - ------- - controller: class - The controller class. - """ - - gym_controllers = MSCRegistery.get(MSCRegistery.GYM_CONTROLLERS, {}) - return gym_controllers.get(control_type) - - -def register_gym_object(obj: Any): - """Register a gym object. - - Parameters - ---------- - obj: class - The object class. - """ - - for key in ["role", "role_type"]: - assert hasattr(obj, key), f"{key} should be given to register gym object" - gym_objects = MSCRegistery.get(MSCRegistery.GYM_OBJECTS, {}) - col = gym_objects.setdefault(obj.role(), {}) - col[obj.role_type()] = obj - MSCRegistery.register(MSCRegistery.GYM_OBJECTS, gym_objects) - return obj - - -def get_registered_gym_object(role: str, role_type: str) -> Any: - """Get the registered object. - - Parameters - ---------- - role: string - The role. - role_type: string - The type of the role. - - Returns - ------- - object: class - The object class. - """ - - gym_objects = MSCRegistery.get(MSCRegistery.GYM_OBJECTS, {}) - return gym_objects.get(role, {}).get(role_type) - - -def register_gym_method(method: Any): - """Register a gym method. - - Parameters - ---------- - method: class - The method class. - """ - - for key in ["role", "method_type"]: - assert hasattr(method, key), f"{key} should be given to register gym method" - gym_methods = MSCRegistery.get(MSCRegistery.GYM_METHODS, {}) - col = gym_methods.setdefault(method.role(), {}) - col[method.method_type()] = method - MSCRegistery.register(MSCRegistery.GYM_METHODS, gym_methods) - return method - - -def get_registered_gym_method(role: str, method_type: str) -> Any: - """Get the registered gym method. - - Parameters - ---------- - role: str - The role. - method_type: str - The type of method. - - Returns - ------- - method: class - The method class. - """ - - gym_methods = MSCRegistery.get(MSCRegistery.GYM_METHODS, {}) - return gym_methods.get(role, {}).get(method_type) - - -def register_runner_hook(hook: Any): - """Register a runner hook. - - Parameters - ---------- - hook: class - The hook class. - """ - - assert hasattr(hook, "name"), "name should be given to register hook" - hooks = MSCRegistery.get(MSCRegistery.RUNNER_HOOKS, {}) - hooks[hook.name()] = hook - MSCRegistery.register(MSCRegistery.RUNNER_HOOKS, hooks) - return hook - - -def get_registered_runner_hook(name: str) -> Any: - """Get the registered runner hook. - - Parameters - ---------- - name: str - The name hook. - - Returns - ------- - method: class - The method class. - """ - - hooks = MSCRegistery.get(MSCRegistery.RUNNER_HOOKS, {}) - return hooks.get(name) diff --git a/python/tvm/contrib/msc/framework/__init__.py b/python/tvm/contrib/msc/framework/__init__.py deleted file mode 100644 index fcdf0c886c24..000000000000 --- a/python/tvm/contrib/msc/framework/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm""" diff --git a/python/tvm/contrib/msc/framework/tensorflow/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/__init__.py deleted file mode 100644 index 5f52aeb42845..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow""" - -import tensorflow as tf - -try: - tf_v1 = tf.compat.v1 -except (ImportError, AttributeError): - tf_v1 = tf diff --git a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py deleted file mode 100644 index f7cd2ea43e3e..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow._ffi_api""" - -import tvm_ffi - -tvm_ffi.init_ffi_api("msc.framework.tensorflow", __name__) diff --git a/python/tvm/contrib/msc/framework/tensorflow/codegen/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/codegen/__init__.py deleted file mode 100644 index 66399e6400e4..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/codegen/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.codegen""" - -from .codegen import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py deleted file mode 100644 index 5ea57f390d7e..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py +++ /dev/null @@ -1,72 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: RUF005 -"""tvm.contrib.msc.framework.tensorflow.codegen.codegen""" - -from typing import Any, Dict, Optional - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.codegen import CodeGen -from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.framework.tensorflow import _ffi_api, tf_v1 - - -def to_tensorflow( - graph: MSCGraph, - weights: Optional[Dict[str, tvm.runtime.Tensor]] = None, - codegen_config: Optional[Dict[str, str]] = None, - print_config: Optional[Dict[str, str]] = None, - build_folder: msc_utils.MSCDirectory = None, - plugin: Any = None, -) -> tf_v1.Graph: - """Change MSCGraph to tensorflow graph. - - Parameters - ---------- - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - weights: dict of - The parameters of the IRModule. - codegen_config: dict - The config for codegen. - print_config: dict - The config for print. - build_folder: MSCDirectory - The folder for saving scripts and datas. - plugin: PluginManager - The plugin manager. - - Returns - ------- - tf_graph: tf_v1.Graph - The tensorflow Graph. - """ - - def _save_weights(folder: msc_utils.MSCDirectory): - if weights: - with open(folder.relpath(graph.name + "_params.bin"), "wb") as f_params: - f_params.write(tvm.runtime.save_param_dict(weights)) - - inputs = [tf_v1.placeholder(i.dtype_name, i.get_shape(), i.alias) for i in graph.get_inputs()] - codegen = CodeGen( - graph, _ffi_api.GetTensorflowSources, codegen_config, print_config, build_folder - ) - model_args = inputs + [weights] - if plugin: - model_args = model_args + [plugin] - return codegen.load(model_args, pre_load=_save_weights) diff --git a/python/tvm/contrib/msc/framework/tensorflow/frontend/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/frontend/__init__.py deleted file mode 100644 index 419b4f8fef6f..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/frontend/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.frontend""" - -from .translate import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py b/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py deleted file mode 100644 index 4b7433e2babe..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/frontend/translate.py +++ /dev/null @@ -1,64 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=unused-argument -"""tvm.contrib.msc.framework.torch.frontend.translate""" - -from typing import Dict, List, Optional, Tuple, Union - -import tvm -from tvm.contrib.msc.core.ir.graph import MSCGraph -from tvm.contrib.msc.framework.tensorflow import tf_v1 - - -def from_tensorflow( - graph_def: tf_v1.GraphDef, - shape_dict: Dict[str, List[int]], - outputs: List[str], - trans_config: Optional[Dict[str, str]] = None, - build_config: Optional[Dict[str, str]] = None, - opt_config: Optional[Dict[str, str]] = None, - as_msc: bool = True, -) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.runtime.Tensor]]: - """Change tensorflow GraphDef to MSCGraph. - - Parameters - ---------- - graph_def: tf_v1.GraphDef - The graph define of tensorflow. - shape_dict: dict> - The shape dict of inputs. - outputs: list - The output names. - trans_config: dict - The config for transform IRModule. - build_config: dict - The config for build MSCGraph. - opt_config: dict - The config for optimize before translate. - as_msc: bool - Set to to return msc graph, otherwise relax mod - - Returns - ------- - graph/mod: tvm.contrib.msc.core.ir.MSCGraph/tvm.IRModule - The translated graph/IRModule. - weights: dict of - The weights from the IRModule. - """ - - raise NotImplementedError("translate relax module from tensorflow is not implemented") diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/runtime/__init__.py deleted file mode 100644 index de9dea8e4b07..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/runtime/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.runtime""" - -from .runner import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py deleted file mode 100644 index 6cf5ef5fac8f..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py +++ /dev/null @@ -1,305 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=not-context-manager,unused-import -# ruff: noqa: F401 -"""tvm.contrib.msc.framework.tensorflow.runtime.runner""" - -import time -from typing import Any, Dict, List, Tuple, Union - -import numpy as np -from tensorflow.python.client import device_lib -from tensorflow.python.ops import variables - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.core.runtime import ModelRunner -from tvm.contrib.msc.core.utils.message import MSCStage -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.framework.tensorflow import tf_v1, tools -from tvm.contrib.msc.framework.tensorflow.codegen import to_tensorflow -from tvm.contrib.msc.framework.tensorflow.frontend import from_tensorflow - - -class WrapSession(tf_v1.Session): - """Wrapped session for MSC""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._inputs, self._outputs = None, None - - def set_bindings(self, inputs: List[Dict[str, str]], outputs: List[Dict[str, str]]): - """Set inputs and outputs for session - - Parameters - ------- - inputs: list - The inputs info of the model. - outputs: list - The outputs info of the model. - """ - - self._inputs = inputs - self._outputs = outputs - - def run(self, fetches, *args, **kwargs): # pylint: disable=useless-parent-delegation - return super().run(fetches, *args, **kwargs) - - -class TensorflowRunner(ModelRunner): - """Runner of Tensorflow""" - - def setup(self) -> dict: - """Setup the runner - - Returns - ------- - info: dict - The setup info. - """ - - self._tf_graph = None - self._tf_outputs = None - self._session = None - return super().setup() - - def destory(self): - """Destory runner""" - - self._session.close() - self._tf_graph = None - self._tf_outputs = None - self._session = None - super().destory() - - def _generate_model( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] - ) -> tf_v1.Graph: - """Codegen the model according to framework - - Parameters - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - - Returns - ------- - model: tf_v1.Graph - The runnable model - """ - - if self._tf_graph: - del self._tf_graph - self._tf_graph = tf_v1.Graph() - with self._tf_graph.as_default(): - self._tf_outputs = super()._generate_model(graphs, weights) - return self._tf_graph - - def _build_runnable(self, model: Any) -> Any: - """Build runnable object - - Parameters - ------- - model: Any - The meta model. - - Returns - ------- - runnable: Any - The runnable - """ - - if self._session: - self._session.close() - del self._session - self._session = WrapSession(graph=self._tf_graph) - self._session.set_bindings(self.get_inputs(), self.get_outputs()) - with self._tf_graph.as_default(): - self._session.run(variables.global_variables_initializer()) - return self._session - - def _call_runnable( - self, runnable: WrapSession, inputs: Dict[str, np.ndarray], device: str - ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]: - """Call the runnable to get outputs - - Parameters - ------- - runnable: WrapSession - The wrapped session. - inputs: dict - The inputs in dict. - device: str - The device. - - Returns - ------- - outputs: list or dict - The outputs in list or dict. - """ - - input_names = [i["name"] for i in self.get_inputs()] - feed_dict = {i + ":0": msc_utils.cast_array(inputs[i]) for i in input_names} - return runnable.run(self._tf_outputs, feed_dict) - - @property - def codegen_func(self): - return to_tensorflow - - @property - def framework(self): - return MSCFramework.TENSORFLOW - - @classmethod - def load_native(cls, model: Any, config: dict) -> Tuple[tf_v1.GraphDef, str, bool]: - """Load the native model - - Parameters - ------- - model: - The native model. - config: dict - The config for pipeline. - - Returns - ------- - model: tf_v1.GraphDef - The loaded native model. - device: str - The device of the model. - training: - Whether the model is for training. - """ - - if isinstance(model, tf_v1.GraphDef): - native_model = model - else: - raise NotImplementedError( - f"Load native model {model} with type {type(model)} is not supported" - ) - device_protos = device_lib.list_local_devices() - if any(dev.dlpack_device_type() == "GPU" for dev in device_protos): - device = "cuda" - else: - device = "cpu" - return native_model, device, False - - @classmethod - def run_native( - cls, - model: tf_v1.GraphDef, - inputs: Dict[str, np.ndarray], - input_names: List[str], - output_names: List[str], - warm_up: int = 10, - repeat: int = 0, - ) -> Tuple[Dict[str, np.ndarray], float]: - """Run the datas and get outputs - - Parameters - ------- - model: tf_v1.GraphDef - The graph def. - inputs: dict - The inputs in dict. - input_names: list - The input names. - output_names: list - The outut names. - warm_up: int - The warm_up num for profile. - repeat: int - The repeat num for profile. - - Returns - ------- - outputs: dict - The outputs in dict. - avg_time: float - The average time. - """ - - feed_dict = {i_name + ":0": inputs[i_name] for i_name in input_names} - with tf_v1.Graph().as_default(): - tf_v1.import_graph_def(model, name="") - with tf_v1.Session() as sess: - if repeat > 0: - for _ in range(warm_up): - outputs = sess.run(output_names, feed_dict) - start = time.time() - for _ in range(repeat): - outputs = sess.run(output_names, feed_dict) - avg_time = (time.time() - start) * 1000 / repeat - else: - outputs = sess.run(output_names, feed_dict) - avg_time = -1 - outputs = dict(zip(output_names, outputs)) - return outputs, avg_time - - @classmethod - def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: - """Update the config for parse - - Parameters - ------- - stage: str - The stage to be updated - config: dict - The config for pipeline. - model: - The native model. - - Returns - ------- - config: dict - The updated config. - """ - - config = ModelRunner.update_config(stage, config, model) - if stage not in config: - return config - if stage == MSCStage.PARSE: - config["parse"]["parser"] = from_tensorflow - parse_config = config["parse"].get("parse_config", {}) - parse_config.update( - { - "shape_dict": {i[0]: i[1] for i in config["inputs"]}, - "outputs": config["outputs"], - } - ) - config["parse"]["parse_config"] = parse_config - return config - - @classmethod - def support_device(cls, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. - """ - - if device == "cpu": - return True - if device.startswith("cuda"): - device_protos = device_lib.list_local_devices() - return any(dev.dlpack_device_type() == "GPU" for dev in device_protos) - return False diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py deleted file mode 100644 index e044077367fa..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.tools""" - -from .prune import * -from .quantize import * -from .distill import * -from .track import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/distill/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/__init__.py deleted file mode 100644 index 1c89122c0a7d..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/distill/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.tools.distill""" - -from .distiller import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py b/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py deleted file mode 100644 index 6ae2b73e3967..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/distill/distiller.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.tools.distill.distiller""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.distill import BaseDistiller -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TensorflowDistillerFactory: - """Distiller factory for tensorflow""" - - def create(self, base_cls: BaseDistiller) -> BaseDistiller: - """Create adaptive distiller - - Parameters - ---------- - base_cls: BaseDistiller - The base distiller class - - Returns - ------- - distiller_cls: BaseDistiller - The distiller class. - """ - - @msc_utils.register_tool - class Distiller(base_cls): - """Adaptive distiller for tensorflow""" - - @classmethod - def framework(cls): - return MSCFramework.TENSORFLOW - - return Distiller - - -factory = TensorflowDistillerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/prune/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/tools/prune/__init__.py deleted file mode 100644 index 8bdd61d3aa12..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/prune/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.tools.prune""" - -from .pruner import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py deleted file mode 100644 index 27ca7d650c46..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/prune/pruner.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.tools.prune.pruner""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.prune import BasePruner -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TensorflowPrunerFactory: - """Pruner factory for tensorflow""" - - def create(self, base_cls: BasePruner) -> BasePruner: - """Create adaptive pruner - - Parameters - ---------- - base_cls: BasePruner - The base pruner class - - Returns - ------- - pruner_cls: BasePruner - The pruner class. - """ - - @msc_utils.register_tool - class Pruner(base_cls): - """Adaptive pruner for tensorflow""" - - @classmethod - def framework(cls): - return MSCFramework.TENSORFLOW - - return Pruner - - -factory = TensorflowPrunerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py deleted file mode 100644 index ed458ef8381d..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.tools.quantize""" - -from .quantizer import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py deleted file mode 100644 index 1a642c09673a..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.tools.quantize.quantizer""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.quantize import BaseQuantizer -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TensorflowQuantizerFactory: - """Quantizer factory for tensorflow""" - - def create(self, base_cls: BaseQuantizer) -> BaseQuantizer: - """Create adaptive quantizer - - Parameters - ---------- - base_cls: BaseQuantizer - The base quantizer class - - Returns - ------- - quantizer_cls: BaseQuantizer - The quantizer class. - """ - - @msc_utils.register_tool - class Quantizer(base_cls): - """Adaptive quantizer for tensorflow""" - - @classmethod - def framework(cls): - return MSCFramework.TENSORFLOW - - return Quantizer - - -factory = TensorflowQuantizerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/track/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/tools/track/__init__.py deleted file mode 100644 index e8787fb666a9..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/track/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.tools.track""" - -from .tracker import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/track/tracker.py b/python/tvm/contrib/msc/framework/tensorflow/tools/track/tracker.py deleted file mode 100644 index d90bb3dd832c..000000000000 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/track/tracker.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorflow.tools.track.tracker""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.tools.track import BaseTracker -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TensorflowTrackerFactory: - """Tracker factory for tensorflow""" - - def create(self, base_cls: BaseTracker) -> BaseTracker: - """Create adaptive tracker - - Parameters - ---------- - base_cls: BaseTracker - The base tracker class - - Returns - ------- - tracker_cls: BaseTracker - The tracker class. - """ - - @msc_utils.register_tool - class Tracker(base_cls): - """Adaptive tracker for tensorflow""" - - @classmethod - def framework(cls): - return MSCFramework.TENSORFLOW - - return Tracker - - -factory = TensorflowTrackerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorrt/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/__init__.py deleted file mode 100644 index a1c3d532efc6..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt""" diff --git a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py deleted file mode 100644 index a09ab875fbed..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt._ffi_api""" - -import tvm_ffi - -tvm_ffi.init_ffi_api("msc.framework.tensorrt", __name__) diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/__init__.py deleted file mode 100644 index 618a178a2d5b..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/codegen/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.codegen""" - -from .codegen import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py deleted file mode 100644 index b60908774dbc..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py +++ /dev/null @@ -1,210 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.codegen.codegen""" - -import os -import subprocess -from typing import Any, Dict, List, Optional, Union - -import numpy as np - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.codegen import CodeGen -from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.core.utils import MSCFramework -from tvm.contrib.msc.framework.tensorrt import _ffi_api - -from .sources import get_trt_sources -from .utils import write_weight - - -def to_sub_tensorrt( - graph: MSCGraph, - weights: Dict[str, tvm.runtime.Tensor], - codegen_config: Optional[Dict[str, str]] = None, - print_config: Optional[Dict[str, str]] = None, - build_folder: msc_utils.MSCDirectory = None, - output_folder: msc_utils.MSCDirectory = None, - plugin: Any = None, -) -> str: - """Change MSCGraph to TensorRT engine file. - - Parameters - ---------- - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - weights: dict of - The parameters of the IRModule. - codegen_config: dict - The config for codegen. - print_config: dict - The config for print. - build_folder: MSCDirectory - The folder for saving sources and datas. - export_folder: MSCDirectory - The folder for saving outputs. - plugin: PluginManager - The plugin manager. - - Returns - ------- - engine: str - The engine file. - """ - - codegen_config = msc_utils.copy_dict(codegen_config) - codegen_config["version"] = msc_utils.get_version(MSCFramework.TENSORRT) - if "tensorrt_root" not in codegen_config: - codegen_config["tensorrt_root"] = _ffi_api.GetTensorRTRoot() - build_folder = build_folder or msc_utils.msc_dir(keep_history=False, cleanup=True) - output_folder = output_folder or msc_utils.msc_dir("msc_output") - depends = {} - if "range_file" in codegen_config: - range_file = codegen_config["range_file"] - codegen_config["range_file"] = os.path.basename(range_file) - depends[codegen_config["range_file"]] = {"src": range_file, "copy_back": True} - - def _create_depends(folder: msc_utils.MSCDirectory) -> str: - if weights: - # gather weights - engine_wts = {} - for node in graph.get_nodes(): - for weight in node.get_weights().values(): - engine_wts[weight.name] = weights[weight.name] - if node.optype in ("nn.conv2d", "msc.linear"): - weight = node.weight_at("weight") - bias = np.zeros([weight.dim_at("O")], dtype=weight.dtype_name) - engine_wts[node.name + ".bias"] = bias - # write weights file - with open(folder.relpath(graph.name + ".wts"), "w") as f: - f.write(f"{len(engine_wts)}\n") - for name, data in engine_wts.items(): - write_weight(name, msc_utils.cast_array(data), f) - # copy plugin - if plugin: - plugin.copy_libs("plugin_lib") - plugin.copy_includes("plugin") - # save utils sources - with folder.create_dir("utils") as utils_folder: - for name, source in get_trt_sources().items(): - utils_folder.add_file(name, source) - # copy depends - for path, info in depends.items(): - if os.path.exists(info["src"]): - folder.copy(info["src"], path) - - def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str: - with open("engine.log", "w") as log_f: - process = subprocess.Popen("./" + engine_name, stdout=log_f, stderr=log_f, shell=True) - process.wait() - assert process.returncode == 0, ( - f"Failed to test engine {engine_name} under {os.getcwd()}, check engine.log for detail" - ) - for path, info in depends.items(): - if info.get("copy_back", False) and os.path.exists(path): - folder.copy(path, info["src"]) - return folder.move(engine_name + ".trt", output_folder.relpath(engine_name + ".trt")) - - with build_folder as folder: - sub_folder = folder.create_dir(graph.name) - if plugin: - codegen_config["extern_libs"] = [ - sub_folder.create_dir("plugin_lib").relpath(f) for f in plugin.list_libs() - ] - codegen = CodeGen( - graph, - _ffi_api.GetTensorRTSources, - codegen_config, - print_config, - sub_folder, - code_format="cpp", - ) - engine_file = codegen.load([], pre_load=_create_depends, post_load=_build_engine) - return { - "graph_json": graph.to_json(), - "graph_name": graph.name, - "engine": engine_file, - } - - -def to_tensorrt( - mod: tvm.IRModule, - graphs: List[MSCGraph], - weights: Dict[str, tvm.runtime.Tensor], - codegen_configs: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, - print_configs: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, - extra_options: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, - build_folder: msc_utils.MSCDirectory = None, - output_folder: msc_utils.MSCDirectory = None, - plugin: Any = None, -) -> Dict[str, str]: - """Change all MSCGraphs to TensorRT engine files. - - Parameters - ---------- - mod: IRModule - The IRModule of relax. - graphs: list - The translated graphs. - weights: dict - The weights. - codegen_configs: dict or list - The config for codegen. - print_configs: dict ot list - The config for print. - extra_option: dict - The extra option for sub engine. - build_folder: MSCDirectory - The folder for saving sources and datas. - export_folder: MSCDirectory - The folder for saving outputs. - plugin: PluginManager - The plugin manager. - - Returns - ------- - mod: IRModule - The translated mod with target func. - """ - - target_options = {} - if not isinstance(codegen_configs, (list, tuple)): - codegen_configs = [codegen_configs] * len(graphs) - if not isinstance(print_configs, (list, tuple)): - print_configs = [print_configs] * len(graphs) - if not isinstance(extra_options, (list, tuple)): - extra_options = [extra_options] * len(graphs) - for idx, graph in enumerate(graphs): - options = to_sub_tensorrt( - graph, - weights, - codegen_configs[idx], - print_configs[idx], - build_folder, - output_folder, - plugin=plugin, - ) - if extra_options[idx]: - options.update(extra_options[idx]) - target_options[graph.name] = msc_utils.dump_dict(options) - mod = tvm.transform.Sequential( - [ - tvm.relax.transform.RunCodegen({"msc_tensorrt": target_options}), - ] - )(mod) - return mod diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py deleted file mode 100644 index cbf84eb4c504..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py +++ /dev/null @@ -1,488 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.codegen.sources""" - -from typing import Dict - -from tvm.contrib.msc.core.codegen import get_base_sources - - -def get_trt_common_h_code() -> str: - """Create trt_common header file codes - - Returns - ------- - source: str - The trt_common header source. - """ - - return """#ifndef TVM_CONTRIB_MSC_UTILS_TRT_COMMON_H_ -#define TVM_CONTRIB_MSC_UTILS_TRT_COMMON_H_ - -#include -#include -#include -#include -#include -#include - -#include "NvInfer.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace nvinfer1; - -#ifndef TRT_VERSION_GE -#define TRT_VERSION_GE(major, minor, patch) \\ - ((TRT_MAJOR > major) || (TRT_MAJOR == major && TRT_MINOR > minor) || \\ - (TRT_MAJOR == major && TRT_MINOR == minor && TRT_PATCH >= patch)) -#endif - -#if TRT_VERSION_GE(8, 0, 0) -#define TRT_NOEXCEPT noexcept -#else -#define TRT_NOEXCEPT -#endif - -#define CHECK(status) \\ - do { \\ - auto ret = (status); \\ - if (ret != 0) { \\ - std::cout << "CUDA failure: " << ret << std::endl; \\ - abort(); \\ - } \\ - } while (0) - -class TRTLogger : public ILogger { - public: - TRTLogger() : TRTLogger(Severity::kINFO) {} - explicit TRTLogger(Severity severity) { severity_ = severity; } - void log(Severity severity, const char* msg) noexcept override { - if (severity > severity_) return; - - switch (severity) { - case Severity::kINTERNAL_ERROR: - std::cout << "[MSC.INTERNAL_ERROR]: " << msg << std::endl; - break; - case Severity::kERROR: - std::cout << "[MSC.ERROR]: " << msg << std::endl; - break; - case Severity::kWARNING: - std::cout << "[MSC.WARNING]: " << msg << std::endl; - break; - case Severity::kINFO: - std::cout << "[MSC.INFO]: " << msg << std::endl; - break; - case Severity::kVERBOSE: - std::cout << "[MSC.VERBOSE]: " << msg << std::endl; - break; - default: - std::cout << "[MSC.UNKNOWN]: " << msg << std::endl; - break; - } - } - - void setLogSeverity(Severity severity) { severity_ = severity; } - - private: - Severity severity_; -}; - -struct InferDeleter { - template - void operator()(T* obj) const { - if (obj) { -#if TRT_VERSION_GE(8, 0, 0) - delete obj; -#else - obj->destroy(); -#endif - } - } -}; - -template -using TRTPtr = std::unique_ptr; - -class TRTUtils { - public: - static const std::string TensorInfo(ILayer* layer, size_t id = 0); - - static std::map LoadWeights(const std::string& file); - -#if TRT_VERSION_GE(6, 0, 0) - static bool SerializeEngineToFile(const std::string& file, TRTPtr& builder, - TRTPtr& network, - TRTPtr& config, TRTLogger& logger); -#else - static bool SerializeEngineToFile(const std::string& file, TRTPtr& builder, - TRTPtr& network, TRTLogger& logger); - -#endif - - static bool DeserializeEngineFromFile(const std::string& file, - std::shared_ptr& engine, TRTLogger& logger); -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_UTILS_TRT_COMMON_H_ -""" - - -def get_trt_common_cc_code() -> str: - """Create trt_common cc file codes - - Returns - ------- - source: str - The trt_common cc source. - """ - - return """#include "trt_common.h" - -namespace tvm { -namespace contrib { -namespace msc { - -const std::string TRTUtils::TensorInfo(ILayer* layer, size_t id) { - std::string info = "S:"; - Dims dims = layer->getOutput(id)->getDimensions(); - for (int i = 0; i < dims.nbDims; i++) { - info += std::to_string(dims.d[i]) + ';'; - } - DataType dtype = layer->getOutput(id)->getType(); - info += " D:"; - if (dtype == DataType::kFLOAT) { - info += "float32"; - } else if (dtype == DataType::kHALF) { - info += "float16"; - } else if (dtype == DataType::kINT32) { - info += "int32"; - } else if (dtype == DataType::kINT8) { - info += "int8"; - } else if (dtype == DataType::kBOOL) { - info += "bool"; - } else { - info += "unknown"; - } - return info; -} - -std::map TRTUtils::LoadWeights(const std::string& file) { - std::map weightMap; - // Open weights file - std::ifstream input(file, std::ios::binary); - assert(input.is_open() && ("Failed to open file " + file).c_str()); - - // Read number of weight blobs - int32_t count; - input >> count; - assert(count > 0 && "Invalid weight map file."); - std::cout << "Find " << count << " weigths in the file : " << file << std::endl; - - while (count--) { - Weights wt{DataType::kFLOAT, nullptr, 0}; - uint32_t type, size; - // Read name and type of blob - std::string name; - input >> name >> std::dec >> type >> size; - wt.type = static_cast(type); - - // Load blob - if (wt.type == DataType::kFLOAT) { - uint32_t* val = reinterpret_cast(malloc(sizeof(val) * size)); - for (uint32_t x = 0; x < size; ++x) { - input >> std::hex >> val[x]; - } - wt.values = val; - } else if (wt.type == DataType::kHALF) { - uint16_t* val = reinterpret_cast(malloc(sizeof(val) * size)); - for (uint32_t x = 0; x < size; ++x) { - input >> std::hex >> val[x]; - } - wt.values = val; - } - wt.count = size; - weightMap[name] = wt; - } - input.close(); - return weightMap; -} - -#if TRT_VERSION_GE(6, 0, 0) -bool TRTUtils::SerializeEngineToFile(const std::string& file, TRTPtr& builder, - TRTPtr& network, - TRTPtr& config, TRTLogger& logger) { -#if TRT_VERSION_GE(8, 0, 0) - auto plan = TRTPtr(builder->buildSerializedNetwork(*network, *config)); -#else - auto engine = TRTPtr(builder->buildEngineWithConfig(*network, *config)); - if (!engine) { - logger.log(ILogger::Severity::kERROR, "Failed to build engine"); - return false; - } - auto plan = TRTPtr(engine->serialize()); -#endif - if (!plan) { - logger.log(ILogger::Severity::kERROR, "Failed to serialize network"); - return false; - } - std::ofstream ofs(file, std::ios::out | std::ios::binary); - assert(ofs.is_open() && ("Failed to open file " + file).c_str()); - ofs.write((char*)(plan->data()), plan->size()); - ofs.close(); - return true; -} -#else -bool TRTUtils::SerializeEngineToFile(const std::string& file, TRTPtr& builder, - TRTPtr& network, TRTLogger& logger) { - auto engine = TRTPtr(builder->buildCudaEngine(*network)); - if (!engine) { - logger.log(ILogger::Severity::kERROR, "Failed to build engine"); - return false; - } - auto plan = TRTPtr(engine->serialize()); - if (!plan) { - logger.log(ILogger::Severity::kERROR, "Failed to serialize network"); - return false; - } - std::ofstream ofs(file, std::ios::out | std::ios::binary); - assert(ofs.is_open() && ("Failed to open file " + file).c_str()); - ofs.write((char*)(plan->data()), plan->size()); - ofs.close(); - return true; -} -#endif - -bool TRTUtils::DeserializeEngineFromFile(const std::string& file, - std::shared_ptr& engine, TRTLogger& logger) { - std::vector stream; - size_t size{0}; - std::ifstream input(file, std::ifstream::binary); - assert(input.is_open() && ("Failed to open file " + file).c_str()); - if (input.good()) { - input.seekg(0, input.end); - size = input.tellg(); - input.seekg(0, input.beg); - stream.resize(size); - input.read(stream.data(), size); - input.close(); - } - logger.log(ILogger::Severity::kINFO, - ("size of engine from " + file + " is " + std::to_string(size)).c_str()); - auto runtime = TRTPtr(createInferRuntime(logger)); - engine = std::shared_ptr( - runtime->deserializeCudaEngine(stream.data(), size, nullptr), InferDeleter()); - input.close(); - return true; -} - -} // namespace msc -} // namespace contrib -} // namespace tvm -""" - - -def get_trt_quantize_h_code(): - """Create trt_quantize header file codes - - Returns - ------- - source: str - The trt_quantize header source. - """ - - return """#ifndef TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_ -#define TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "NvInfer.h" -#include "base.h" -#include "trt_common.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace nvinfer1; - -class CalibrateHelper { - public: - CalibrateHelper(const std::string& range_file, const std::string& folder, int max_size = -1); - - ~CalibrateHelper() { - for (const auto& buffer : cpu_buffers_) { - free(buffer); - } - for (const auto& buffer : gpu_buffers_) { - CHECK(cudaFree(buffer)); - } - } - - bool GetBatch(void* bindings[], const char* names[], int nbBindings); - - const void* ReadCache(size_t& length); - - void WriteCache(const void* cache, size_t length); - - private: - std::unique_ptr reader_; - std::string range_file_; - std::vector cache_; - std::vector cpu_buffers_; - std::vector gpu_buffers_; -}; - -#define CALIBRATE_MEMBERS(Calibrator) \\ - public: \\ - Calibrator(const std::string& range_file, const std::string& folder, int max_size = -1) { \\ - helper_.reset(new CalibrateHelper(range_file, folder, max_size)); \\ - } \\ - \\ - virtual ~Calibrator() {} \\ - \\ - int getBatchSize() const noexcept override { return 1; } \\ - \\ - bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override { \\ - return helper_->GetBatch(bindings, names, nbBindings); \\ - } \\ - \\ - const void* readCalibrationCache(size_t& length) noexcept override { \\ - return helper_->ReadCache(length); \\ - } \\ - \\ - void writeCalibrationCache(const void* cache, size_t length) noexcept override { \\ - return helper_->WriteCache(cache, length); \\ - } \\ - \\ - private: \\ - std::unique_ptr helper_; - -class MSCInt8EntropyCalibrator : public IInt8EntropyCalibrator { - CALIBRATE_MEMBERS(MSCInt8EntropyCalibrator) -}; - -class MSCInt8EntropyCalibrator2 : public IInt8EntropyCalibrator2 { - CALIBRATE_MEMBERS(MSCInt8EntropyCalibrator2) -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_ -""" - - -def get_trt_quantize_cc_code(): - """Create trt_quantize cc file codes - - Returns - ------- - source: str - The trt_quantize cc source. - """ - - return """#include "trt_quantize.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace nvinfer1; - -CalibrateHelper::CalibrateHelper(const std::string& range_file, const std::string& folder, - int max_size) { - range_file_ = range_file; - reader_.reset(new DatasetReader(folder, max_size)); - const auto& tensor_names = reader_->GetTensorNames(); - cpu_buffers_.resize(tensor_names.size()); - gpu_buffers_.resize(tensor_names.size()); - for (size_t i = 0; i < tensor_names.size(); i++) { - size_t tensor_size = reader_->GetTensorSize(tensor_names[i]); - cpu_buffers_[i] = malloc(tensor_size); - CHECK(cudaMalloc(&gpu_buffers_[i], tensor_size)); - } -} - -bool CalibrateHelper::GetBatch(void* bindings[], const char* names[], int nbBindings) { - if (!reader_->ReadNext(cpu_buffers_.data())) { - return false; - } - for (size_t i = 0; i < nbBindings; i++) { - CHECK(cudaMemcpy(gpu_buffers_[i], cpu_buffers_[i], reader_->GetTensorSize(names[i]), - cudaMemcpyHostToDevice)); - bindings[i] = gpu_buffers_[i]; - } - return true; -} - -const void* CalibrateHelper::ReadCache(size_t& length) { - cache_.clear(); - std::ifstream in_file(range_file_, std::ifstream::binary); - if (!in_file.is_open()) { - return nullptr; - } - in_file >> std::noskipws; - std::copy(std::istream_iterator(in_file), std::istream_iterator(), - std::back_inserter(cache_)); - length = cache_.size(); - return length > 0 ? &cache_[0] : nullptr; -} - -void CalibrateHelper::WriteCache(const void* cache, size_t length) { - std::ofstream output(range_file_, std::ios::binary); - output.write(reinterpret_cast(cache), length); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm -""" - - -def get_trt_sources() -> Dict[str, str]: - """Create trt sources for cpp codegen - - Returns - ------- - sources: dict - The trt utils sources. - """ - - sources = get_base_sources() - sources.update( - { - "trt_common.h": get_trt_common_h_code(), - "trt_common.cc": get_trt_common_cc_code(), - "trt_quantize.h": get_trt_quantize_h_code(), - "trt_quantize.cc": get_trt_quantize_cc_code(), - } - ) - return sources diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/utils.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/utils.py deleted file mode 100644 index 11b0f7a5444b..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/codegen/utils.py +++ /dev/null @@ -1,97 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.codegen.utils""" - -import io -import struct - -import numpy as np - - -def enum_dtype(array: np.ndarray) -> int: - """Get TensorRT DType enum from array. - - Parameters - ---------- - array: np.ndarray - The source array. - - Returns - ------- - dtype: int - The dtype enum. - """ - - if array.dtype == np.float32: - return 0 - if array.dtype == np.float16: - return 1 - if array.dtype == np.int8: - return 2 - if array.dtype == np.int32: - return 3 - raise Exception(f"Unexpected dtype {array.dtype}, no matching tensorrt dtype") - - -def float_to_hex(value: float) -> str: - """Change float to hex. - - Parameters - ---------- - value: float - The float value. - - Returns - ------- - hex: str - The hex format string. - """ - - return hex(struct.unpack(" str: - """Change array to hex. - - Parameters - ---------- - array: np.ndarray - The source array. - - Returns - ------- - hex: str - The hex format string. - """ - - return " ".join([float_to_hex(float(f))[2:] for f in array.flatten()]) - - -def write_weight(name: str, weight: np.ndarray, f_handler: io.TextIOWrapper): - """Write array to file in TensorRT format. - - Parameters - ---------- - name: str - The array name - weight: np.ndarray - The weight data. - f_handler: io.TextIOWrapper - The file handler - """ - - f_handler.write(f"{name} {enum_dtype(weight)} {weight.size} {array_to_hex(weight)}\n") diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py deleted file mode 100644 index f91719511519..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/frontend/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.frontend""" - -from .translate import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py deleted file mode 100644 index 7f7e081622b2..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py +++ /dev/null @@ -1,89 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.frontend.translate""" - -from typing import Dict, List, Optional, Tuple - -import tvm -from tvm import relax -from tvm.contrib.msc.core import transform as msc_transform -from tvm.contrib.msc.core.frontend import byoc_partition -from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.framework.tensorrt import transform as trt_transform - - -def transform_for_tensorrt( - mod: tvm.IRModule, - trans_config: Optional[Dict[str, str]] = None, -) -> tvm.IRModule: - """Transform module to tensorrt. - - Parameters - ---------- - mod: IRModule - The IRModule of relax. - trans_config: dict - The config for transform IRModule. - - Returns - ------- - mod: IRModule - The transformed IRModule of relax. - """ - - trans_config = trans_config or {} - return tvm.transform.Sequential( - [ - msc_transform.SetExprName(), - trt_transform.TransformTensorRT( - version=trans_config.get("version"), - linear_to_conv=trans_config.get("linear_to_conv", False), - ), - relax.transform.FoldConstant(), - ] - )(mod) - - -def partition_for_tensorrt( - mod: tvm.IRModule, - params: Optional[Dict[str, tvm.runtime.Tensor]] = None, - trans_config: Optional[Dict[str, str]] = None, - build_config: Optional[Dict[str, str]] = None, -) -> Tuple[tvm.IRModule, List[Tuple[MSCGraph, Dict[str, tvm.runtime.Tensor]]]]: - """Partition module to tensorrt sub functions. - - Parameters - ---------- - mod: IRModule - The IRModule of relax. - trans_config: dict - The config for transform IRModule. - params: dict of - The parameters of the IRModule. - build_config: dict - The config for build MSCGraph. - - Returns - ------- - mod: IRModule - The IRModule of partitioned relax. - graphs_info: list<> - The func list, each element for a sub graph. - """ - - mod = transform_for_tensorrt(mod, trans_config) - return byoc_partition("msc_tensorrt", mod, params, trans_config, build_config) diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/runtime/__init__.py deleted file mode 100644 index 56203292e418..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/runtime/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.runtime""" - -from .runner import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py deleted file mode 100644 index a1c1ec4966e8..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py +++ /dev/null @@ -1,189 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-import -# ruff: noqa: F401 -"""tvm.contrib.msc.framework.tensorrt.runtime.runner""" - -import os -from typing import Any, Dict, List - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.core.runtime import BYOCRunner -from tvm.contrib.msc.core.tools import ToolType -from tvm.contrib.msc.core.utils.message import MSCStage -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.framework.tensorrt import tools -from tvm.contrib.msc.framework.tensorrt.codegen import to_tensorrt -from tvm.contrib.msc.framework.tensorrt.frontend import ( - partition_for_tensorrt, - transform_for_tensorrt, -) - - -class TensorRTRunner(BYOCRunner): - """Runner of tensorrt""" - - def setup(self) -> dict: - """Setup the runner - - Returns - ------- - info: dict - The setup info. - """ - - if not self._device.startswith("cuda"): - self._device = "cuda" - assert not self._training, "TensorRT only support eval" - return super().setup() - - def train(self): - """Change status to train""" - - raise Exception("TensorRT only support eval") - - def make_plan(self, tool_type: str, data_loader: Any = None) -> dict: - """Execute tool and get plan - - Parameters - ------- - tool_type: str - The tool type, should be in ToolType - data_loader: - The data loader - """ - - assert tool_type in self._tools, "Can not find tool " + str(tool_type) - if tool_type == ToolType.QUANTIZER: - quantizer = self.get_tool(ToolType.QUANTIZER) - assert data_loader, "data_loader should be given to plan prune" - for inputs in data_loader(): - self.run(inputs) - self._generate_model(self._graphs, self._weights) - quantizer.calibrate() - assert quantizer.calibrated, "Failed to calibrate the tenosrrt quantizer" - return super().make_plan(tool_type, data_loader) - - def _generate_model( - self, graphs: List[MSCGraph], weights: Dict[str, tvm.runtime.Tensor] - ) -> Any: - """Codegen the model according to framework - - Parameters - ------- - graphs: list - The msc graphs. - weights: dict - The weights. - - Returns - ------- - model: Any - The meta model - """ - - codegen = self._generate_config.get("codegen") - if not isinstance(codegen, (list, tuple)): - self._generate_config["codegen"] = [msc_utils.copy_dict(codegen)] * len(self._graphs) - for tool in self.get_tools(): - self._generate_config = tool.config_generate(self._generate_config) - - return super()._generate_model(graphs, weights) - - def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: - """Export the runnable - - Parameters - ------- - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The runnable info. - """ - - def _get_engine(graph: MSCGraph) -> str: - engine_file = msc_utils.get_output_dir().relpath(graph.name + ".trt") - assert os.path.isfile(engine_file), "Missing engine file " + engine_file - return engine_file - - info = super().export_runnable(folder) - info["engines"] = {g.name: _get_engine(g) for g in self._graphs} - return info - - @classmethod - def target_transform(cls, mod: tvm.IRModule): - """Transform the mod by target. - - Parameters - ---------- - mod: IRModule - The IRModule of relax. - - Returns - ------- - mod: IRModule - The IRModule of partitioned relax. - """ - - return transform_for_tensorrt(mod) - - @property - def codegen_func(self): - return to_tensorrt - - @property - def partition_func(self): - return partition_for_tensorrt - - @property - def framework(self): - return MSCFramework.TENSORRT - - @classmethod - def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: - """Update the config for parse - - Parameters - ------- - stage: str - The stage to be updated - config: dict - The config for pipeline. - model: - The native model. - - Returns - ------- - config: dict - The updated config. - """ - - config = BYOCRunner.update_config(stage, config, model) - if stage not in config: - return config - if stage in (MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE): - run_config = config[stage].get("run_config", {}) - if "extra_option" not in run_config["generate_config"]: - run_config["generate_config"]["extra_option"] = {} - run_config["generate_config"]["extra_option"]["stage"] = stage - config[stage]["run_config"] = run_config - return config diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py deleted file mode 100644 index 2ca27ebed83e..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.tools""" - -from .prune import * -from .quantize import * -from .distill import * -from .track import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/distill/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/__init__.py deleted file mode 100644 index 4d14e35c4151..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/distill/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.tools.distill""" - -from .distiller import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py b/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py deleted file mode 100644 index 5a93e0f6c7dd..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/distill/distiller.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.tools.distill.distiller""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.distill import BaseDistiller -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TensorRTDistillerFactory: - """Distiller factory for tensorrt""" - - def create(self, base_cls: BaseDistiller) -> BaseDistiller: - """Create adaptive distiller - - Parameters - ---------- - base_cls: BaseDistiller - The base distiller class - - Returns - ------- - distiller_cls: BaseDistiller - The distiller class. - """ - - @msc_utils.register_tool - class Distiller(base_cls): - """Adaptive distiller for tensorrt""" - - @classmethod - def framework(cls): - return MSCFramework.TENSORRT - - return Distiller - - -factory = TensorRTDistillerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/prune/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/tools/prune/__init__.py deleted file mode 100644 index 24ef6a62b24b..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/prune/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.tools.prune""" - -from .pruner import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py deleted file mode 100644 index 087c7a005ad0..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/prune/pruner.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.tools.prune.pruner""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.prune import BasePruner -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TensorRTPrunerFactory: - """Pruner factory for tensorrt""" - - def create(self, base_cls: BasePruner) -> BasePruner: - """Create adaptive pruner - - Parameters - ---------- - base_cls: BasePruner - The base pruner class - - Returns - ------- - pruner_cls: BasePruner - The pruner class. - """ - - @msc_utils.register_tool - class Pruner(base_cls): - """Adaptive pruner for tensorrt""" - - @classmethod - def framework(cls): - return MSCFramework.TENSORRT - - return Pruner - - -factory = TensorRTPrunerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py deleted file mode 100644 index 5dc586aeff4a..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.tools.quantize""" - -from .quantizer import * -from .method import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py deleted file mode 100644 index e506cb4fd1e9..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py +++ /dev/null @@ -1,147 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.framework.tensorrt.tools.quantize.method""" - -from typing import Dict - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.quantize import BaseQuantizer, QuantizeMethod -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -@msc_utils.register_tool_method -class TensorRTQuantizeMethod(QuantizeMethod): - """Default quantize method for tensorrt""" - - @classmethod - def quantize_normal( - cls, - quantizer: BaseQuantizer, - tensor_ctx: Dict[str, str], - name: str, - consumer: str, - scale: float, - nbits: int = 8, - axis: int = -1, - sign: bool = True, - rounding: str = "round", - epsilon: float = 1.0 / (1 << 24), - ) -> Dict[str, str]: - """Calibrate the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - tensor_ctx: dict - Tensor describe items. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scale: float - The scale factor - nbits: int - The number bits for quantize. - axis: int - The axis. - sign: bool - Whether to use sign. - rounding str - The rounding method. - epsilon: float - The epsilon for get scale. - - Returns - ------- - tensor_ctx: dict - Tensor describe items. - """ - - if quantizer.is_weight(name): - return tensor_ctx - dtype = quantizer.find_tensor(name).dtype_name - precision = "DataType::k" - if nbits == 8: - precision += "INT8" - elif dtype == "float16": - precision += "HALF" - elif dtype == "float32": - precision += "FLOAT" - else: - raise TypeError(f"nbits {nbits} is not supported") - tensor_ctx["processed"].extend( - [ - "{}->setPrecision({})".format(tensor_ctx["producer"], precision), - "{0}->setDynamicRange(-{1}, {1})".format(tensor_ctx["tensor"], scale), - ] - ) - return tensor_ctx - - @classmethod - def dequantize_normal( - cls, - quantizer: BaseQuantizer, - tensor_ctx: Dict[str, str], - name: str, - consumer: str, - scale: float, - nbits: int = 8, - axis: int = -1, - sign: bool = True, - rounding: str = "round", - epsilon: float = 1.0 / (1 << 24), - ) -> Dict[str, str]: - """Calibrate the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - tensor_ctx: dict - Tensor describe items. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scale: float - The scale factor - nbits: int - The number bits for quantize. - axis: int - The axis. - sign: bool - Whether to use sign. - rounding str - The rounding method. - epsilon: float - The epsilon for get scale. - - Returns - ------- - tensor_ctx: dict - Tensor describe items. - """ - - return cls.quantize_normal( - quantizer, tensor_ctx, name, consumer, scale, nbits, axis, sign, rounding, epsilon - ) - - @classmethod - def framework(cls): - return MSCFramework.TENSORRT diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py deleted file mode 100644 index f0a0cbd978d6..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/quantizer.py +++ /dev/null @@ -1,360 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E501 -"""tvm.contrib.msc.framework.tensorrt.tools.quantize.quantizer""" - -import os -import struct -from typing import Any, Dict, List, Tuple - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.core.tools.quantize import BaseQuantizer, QuantizeStage -from tvm.contrib.msc.core.tools.tool import ToolStrategy, ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TensorRTQuantizerFactory: - """Quantizer factory for tensorrt""" - - def create(self, base_cls: BaseQuantizer) -> BaseQuantizer: - """Create adaptive quantizer - - Parameters - ---------- - base_cls: BaseQuantizer - The base quantizer class - - Returns - ------- - quantizer_cls: BaseQuantizer - The quantizer class. - """ - - @msc_utils.register_tool - class Quantizer(base_cls): - """Adaptive quantizer for tensorrt""" - - def setup(self) -> dict: - """Setup the tool - - Returns - ------- - info: dict - The setup info. - """ - - if self._plan: - self._use_range = all( - info.get("use_range", False) for info in self._plan.values() - ) - else: - self._use_range = True - return super().setup() - - def _reset( - self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.runtime.Tensor]] - ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.runtime.Tensor]]]: - """Reset the tool - - Parameters - ---------- - graphs: list - The msc graphs. - weights: list> - The weights - - Returns - ------- - graphs: list - The msc graphs. - weights: list> - The weights - """ - - config_folder = msc_utils.get_config_dir() - self._range_files = [config_folder.relpath(g.name + ".range") for g in graphs] - calibrate_root = msc_utils.get_dataset_dir().create_dir("Calibrate") - self._calibrate_folders = [calibrate_root.relpath(g.name) for g in graphs] - if self._calibrated: - if self._use_range: - for r_file, graph in zip(self._range_files, graphs): - if not os.path.isfile(r_file): - self._plan_to_range(graph, r_file) - self._logger.debug( - "G[%s](%s) use range file: %s", - graph.name, - self._stage, - r_file, - ) - else: - self._quantized_tensors = set() - elif self._stage == QuantizeStage.GATHER: - self._calibrate_savers = [] - for folder, graph in zip(self._calibrate_folders, graphs): - saver_options = {"input_names": [i.name for i in graph.get_inputs()]} - saver = msc_utils.IODataSaver(folder, saver_options) - self._calibrate_savers.append(saver) - self._logger.debug( - "G[%s](%s) create calibrate saver: %s", - graph.name, - self._stage, - saver, - ) - else: - assert all(msc_utils.is_io_dataset(f) for f in self._calibrate_folders), ( - "Some IODataset missing: " + str(self._calibrate_folders) - ) - return super()._reset(graphs, weights) - - def _execute_after_build(self, codegen_context: dict) -> dict: - """Execute after model build - - Parameters - ---------- - codegen_context: dict - The context. - - Returns - ---------- - codegen_context: dict - The processed context. - """ - - if self._stage == QuantizeStage.GATHER and self._forward_cnt == 0: - return codegen_context - if not self._use_range: - return codegen_context - processed = ["// Set int8 calibrator"] - range_file = self.get_graph().name + ".range" - version = [int(v) for v in codegen_context["version"].split(".")] - if msc_utils.compare_version(version, [6, 0, 0]) >= 0: - configer = codegen_context["config"] - else: - configer = codegen_context["builder"] - # check the range file if calibrated - if self._calibrated: - processed.extend( - [ - f'if (!FileUtils::FileExist("{range_file}")) {{', - f' logger.log(ILogger::Severity::kERROR, "{range_file} not exist!");', - " return -1;", - "}", - ] - ) - processed.extend( - [ - f'MSCInt8EntropyCalibrator2 calibrator("{range_file}", "{self._calibrate_folders[self._graph_id]}");', - f"{configer}->setInt8Calibrator(&calibrator);", - ] - ) - codegen_context["processed"].extend(processed) - return codegen_context - - def _execute_before_forward(self, step_context: dict) -> dict: - """Execute before model forward - - Parameters - ---------- - step_context: dict - The context. - - Returns - ---------- - step_context: dict - The processed context. - """ - - if self._stage == QuantizeStage.GATHER: - saver = self._calibrate_savers[self._graph_id] - saver.save_batch( - {name: data.numpy() for name, data in step_context["datas"].items()} - ) - for name, data in step_context["datas"].items(): - self.debug_tensors(name, "any", "ctx_gather", {"gather": data}) - super()._execute_before_forward(step_context) - - def _quantize_tensor( - self, - tensor_ctx: Dict[str, str], - name: str, - consumer: str, - strategys: List[ToolStrategy], - ) -> Dict[str, str]: - """Quantize tensor - - Parameters - ------- - tensor_ctx: dict - Tensor describe items. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor_ctx: dict - Tensor items with processed. - """ - - if not self._use_range and name not in self._quantized_tensors: - self._quantized_tensors.add(name) - return super()._quantize_tensor(tensor_ctx, name, consumer, strategys) - return tensor_ctx - - def calibrate(self) -> dict: - """Calibrate the datas - - Returns - ------- - plan: dict - The calibrated plan. - """ - - for r_file, graph in zip(self._range_files, self._graphs): - self._range_to_plan(graph, r_file) - self._calibrated, self._forward_cnt = True, 0 - self.change_stage("quantize") - return self._plan - - def config_generate(self, generate_config: Dict[str, Any]) -> Dict[str, Any]: - """Update the generate configs - - Parameters - ---------- - generate_config: dict - The generate_config. - - Returns - ------- - generate_config: dict - The updated generate_config. - """ - - if self._calibrated: - if self._use_range: - for config, r_file in zip(generate_config["codegen"], self._range_files): - if os.path.isfile(r_file): - config.update({"range_file": r_file, "precision": "int8"}) - elif self._stage == QuantizeStage.GATHER and self._forward_cnt > 0: - for config, saver, r_file in zip( - generate_config["codegen"], self._calibrate_savers, self._range_files - ): - saver.finalize() - msg = f"Save {self._forward_cnt} batch to {saver.folder}" - self._logger.debug(self.msg_mark(msg, in_forward=False)) - config.update( - {"dataset": saver.folder, "range_file": r_file, "precision": "int8"} - ) - self.change_stage(QuantizeStage.CALIBRATE) - return generate_config - - def _plan_to_range(self, graph: MSCGraph, range_file: str, title="MSCCalibrate"): - """Extract plan config to range_file - - Parameters - ---------- - plan: dict - The plan. - graph: MSCGraph - The graph. - range_file: str - The output range_file path. - title: str - The title of the range file. - """ - - def _scale_to_hex(scale): - return hex(struct.unpack(" BaseTracker: - """Create adaptive tracker - - Parameters - ---------- - base_cls: BaseTracker - The base tracker class - - Returns - ------- - tracker_cls: BaseTracker - The tracker class. - """ - - @msc_utils.register_tool - class Tracker(base_cls): - """Adaptive tracker for tensorrt""" - - def _execute_before_build(self, codegen_context: dict) -> dict: - """Execute before model build - - Parameters - ---------- - codegen_context: dict - The context. - - Returns - ---------- - codegen_context: dict - The processed context. - """ - - self._track_tensors = {} - super()._execute_before_build(codegen_context) - - def _execute_before_forward(self, step_context: dict) -> dict: - """Execute before model forward - - Parameters - ---------- - step_context: dict - The context. - - Returns - ---------- - step_context: dict - The processed context. - """ - - for name, data in step_context["datas"].items(): - if name not in self._track_tensors: - continue - consumer = self._track_tensors[name]["consumer"] - strategys = self._get_tensor_strategys(name, consumer) - self._track_tensor(data.numpy(), name, consumer, strategys) - return super()._execute_before_forward(step_context) - - def _execute_after_forward(self, step_context: dict) -> dict: - """Execute after model forward - - Parameters - ---------- - step_context: dict - The context. - - Returns - ---------- - step_context: dict - The processed context. - """ - - for name, data in step_context["datas"].items(): - if name not in self._track_tensors: - continue - consumer = self._track_tensors[name]["consumer"] - strategys = self._get_tensor_strategys(name, consumer) - self._track_tensor(data.numpy(), name, consumer, strategys) - return super()._execute_after_forward(step_context) - - def _process_tensor( - self, - tensor_ctx: Dict[str, str], - name: str, - consumer: str, - scope: str, - strategys: List[ToolStrategy], - ) -> Dict[str, str]: - """Process tensor - - Parameters - ------- - tensor_ctx: dict - Tensor describe items. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor_ctx: dict - Tensor items with processed. - """ - - if self.is_weight(name): - return self._track_tensor(self.get_data(name), name, consumer, strategys) - if name not in self._track_tensors: - self._track_tensors[name] = { - "consumer": consumer, - } - tensor_ctx["processed"].append( - "{}->markOutput(*{});".format(tensor_ctx["ctx"], tensor_ctx["tensor"]) - ) - return tensor_ctx - - @classmethod - def framework(cls): - return MSCFramework.TENSORRT - - return Tracker - - -factory = TensorRTTrackerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/transform/__init__.py deleted file mode 100644 index 7bf054f5461b..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/transform/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tensorrt.transform""" - -from .pattern import * -from .transform import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py deleted file mode 100644 index 726303b7e5e7..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py +++ /dev/null @@ -1,474 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.framework.tensorrt.transform.pattern""" - -from functools import partial, wraps -from typing import Callable, Dict, List, Mapping, Optional, Tuple, Union - -import tvm -from tvm import relax -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core.transform import pattern as msc_pattern -from tvm.relax.backend.pattern_registry import register_patterns -from tvm.relax.dpl import pattern -from tvm.relax.transform import FusionPattern, PatternCheckContext - - -def basic_pattern( - op_name: str, input_types: Optional[List[str]] = None -) -> Tuple[pattern.DFPattern, Mapping[str, pattern.DFPattern]]: - """create basic pattern for tensorrt support ops. - - Parameters - ---------- - op_name: str - The name of a Relax op, such as "relax.nn.conv2d" - input_types: list - The input types, elach element can be input| constant - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing the operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - input_types = input_types or ["input"] - inputs = [] - for i_type in input_types: - if i_type == "input": - inputs.append(pattern.wildcard()) - elif i_type == "constant": - inputs.append(pattern.is_const()) - else: - raise Exception("Unexpected input type " + str(i_type)) - out = pattern.is_op(op_name)(*inputs) - annotations = {"input_" + str(idx): arg for idx, arg in enumerate(inputs)} - annotations["out"] = out - return out, annotations - - -def elemwise_pattern(op_name: str) -> Tuple[pattern.DFPattern, Mapping[str, pattern.DFPattern]]: - """create elemwise pattern for tensorrt support ops. - - Parameters - ---------- - op_name: str - The name of a Relax op, such as "relax.add" - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing the operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - return basic_pattern(op_name, ["input", "input"]) - - -def argmaxmin_pattern(op_name: str) -> Tuple[pattern.DFPattern, Mapping[str, pattern.DFPattern]]: - """create argmaxmin pattern for tensorrt support ops. - - Parameters - ---------- - op_name: str - The name of a Relax op, such as "relax.argmax" - - Returns - ------- - out: tvm.relax.dpl.pattern.DFPattern - The resulting pattern describing the operation. - - annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] - A mapping from name to sub pattern. It can be used to extract - important expressions from match result, to power the partition - check function and codegen. - """ - - data = pattern.wildcard() - argmaxmin = pattern.is_op(op_name)(data) - out = pattern.is_op("relax.astype")(argmaxmin) - return out, {"input": data, "argmaxmin": argmaxmin, "out": out} - - -def _check_expr(expr: relax.Expr, dtypes: Optional[Tuple[str]] = None) -> bool: - """Check if the expr can be fused on tensorrt. - - Parameters - ---------- - expr: relax.Expr - The expr to be check - dtype: tuple - The accept dtypes - - Returns - ------- - pass: bool - Whether the expr is correct. - """ - - if isinstance(expr, relax.ShapeExpr): - return True - if isinstance(expr, relax.PrimValue): - return True - if isinstance(expr, relax.Tuple): - return all(_check_expr(field) for field in expr.fields) - dtypes = dtypes or ("float32", "float16", "int64", "int32", "bool") - - def _check(sinfo): - if not sinfo.shape or sinfo.dtype not in dtypes: - return False - unknown_dim = 0 - for s in sinfo.shape.values: - if isinstance(s, tvm.tir.IntImm) and s < 0: - unknown_dim += 1 - return unknown_dim <= 1 - - if isinstance(expr.struct_info, relax.TupleStructInfo): - return all(_check(s) for s in expr.struct_info.fields) - return _check(expr.struct_info) - - -def _basic_check(context: PatternCheckContext) -> bool: - """Check if the basic pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - for _, expr in context.annotated_expr.items(): - if not _check_expr(expr): - return False - return True - - -def _argmaxmin_check(context: PatternCheckContext) -> bool: - """Check if the argmaxmin pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - if not _check_expr(context.annotated_expr["input"]): - return False - return _check_expr(context.annotated_expr["out"], ("int32")) - - -def _compare_check(context: PatternCheckContext) -> bool: - """Check if the compare pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - if any(not _check_expr(context.annotated_expr[key]) for key in ["input_0", "input_1"]): - return False - if not _check_expr(context.annotated_expr["out"], ("bool")): - return False - ndim_a = len(context.annotated_expr["input_0"].struct_info.shape.values) - ndim_b = len(context.annotated_expr["input_1"].struct_info.shape.values) - return ndim_a == ndim_b - - -def _elemwise_check(context: PatternCheckContext) -> bool: - """Check if the elemwise pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - if not _basic_check(context): - return False - ndim_a = len(context.annotated_expr["input_0"].struct_info.shape.values) - ndim_b = len(context.annotated_expr["input_1"].struct_info.shape.values) - return ndim_a == ndim_b - - -def _reshape_check(context: PatternCheckContext) -> bool: - """Check if the reshape pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - if any(not _check_expr(context.annotated_expr[key]) for key in ["input_0", "out"]): - return False - return True - - -def _take_check(context: PatternCheckContext) -> bool: - """Check if the take pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - if any(not _check_expr(context.annotated_expr[key]) for key in ["input_0", "out"]): - return False - return _check_expr(context.annotated_expr["input_1"], ("int32")) - - -def _plugin_check(context: PatternCheckContext) -> bool: - """Check if the plugin pattern is correct. - - Returns - ------- - pass: bool - Whether the pattern is correct. - """ - - ext_func = context.annotated_expr["out"].args[0] - return bool(_ffi_api.IsPlugin(ext_func.global_symbol)) - - -def plugin_attrs_getter( - annotated_expr: Dict[str, tvm.relax.Expr], -) -> Dict[str, str]: - """Get attributes for plugin pattern - - Parameters - ---------- - annotated_expr: dict - The annotated exprs during fus pattern - anchor: str - The anchor key of expr - - Returns - ------- - attrs: dict - The extra attributes for msc. - """ - - attrs = msc_pattern.msc_attrs_getter(annotated_expr, anchor="out") - ext_func = annotated_expr["out"].args[0] - attrs[_ffi_api.ToAttrKey("optype")] = ext_func.global_symbol - return attrs - - -def wrap_basic_check( - func: Callable[[PatternCheckContext], bool], -) -> Callable[[PatternCheckContext], bool]: - """Wrapper a checker with basic check - - Returns - ------- - checker: PatternCheckContext - The wrapped checker. - """ - - @wraps(func) - def wrapper(context): - if not _basic_check(context): - return False - return func(context) - - return wrapper - - -CheckFunc = Callable[[Mapping[pattern.DFPattern, relax.Expr], relax.Expr], bool] -GetterFunc = Callable[[Mapping[pattern.DFPattern, relax.Expr], relax.Expr], Dict[str, str]] -Pattern = Union[ - FusionPattern, - Tuple[str, pattern.DFPattern], - Tuple[str, pattern.DFPattern, Mapping[str, pattern.DFPattern]], - Tuple[str, pattern.DFPattern, Mapping[str, pattern.DFPattern], CheckFunc], - Tuple[str, pattern.DFPattern, Mapping[str, pattern.DFPattern], CheckFunc, GetterFunc], -] - - -def get_patterns(target) -> List[Pattern]: - """Get all the tensorrt patterns. - - Parameters - ---------- - target: str - The target name for tensorrt patterns. - - Returns - ------- - patterns: list - The patterns - """ - - basic_ops = { - "nn.adaptive_avg_pool2d": ["input"], - "nn.avg_pool2d": ["input"], - "nn.conv2d": ["input", "constant"], - "nn.max_pool2d": ["input"], - "astype": ["input"], - "concat": ["input"], - "clip": ["input", "input", "input"], - "image.resize2d": ["input", "input"], - "matmul": ["input", "input"], - "permute_dims": ["input"], - "strided_slice": ["input", "input", "input", "input", "input"], - "topk": ["input"], - } - activation_ops = ["nn.relu", "nn.softmax", "sigmoid", "tanh"] - reduce_ops = ["max", "min", "mean", "sum"] - unary_ops = ["cos", "erf", "exp", "negative", "round", "sin", "square", "sqrt", "tan"] - elemwise_ops = [ - "add", - "divide", - "floor_divide", - "maximum", - "minimum", - "multiply", - "power", - "subtract", - ] - compare_ops = ["greater", "less"] - patterns = [] - # basic ops - for op, in_types in basic_ops.items(): - inputs = ["input_" + str(i) for i in range(len(in_types))] - patterns.append( - ( - target + "." + op, - *basic_pattern("relax." + op, in_types), - _basic_check, - partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=inputs), - ) - ) - # activation ops - for op in activation_ops: - patterns.append( - ( - target + "." + op, - *basic_pattern("relax." + op, ["input"]), - _basic_check, - partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0"]), - ) - ) - # reduce ops - for op in reduce_ops: - patterns.append( - ( - target + "." + op, - *basic_pattern("relax." + op, ["input"]), - _basic_check, - partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0"]), - ) - ) - # unary ops - for op in unary_ops: - patterns.append( - ( - target + "." + op, - *basic_pattern("relax." + op, ["input"]), - _basic_check, - partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0"]), - ) - ) - # elemwise ops - for op in elemwise_ops: - patterns.append( - ( - target + "." + op, - *elemwise_pattern("relax." + op), - _elemwise_check, - partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0", "input_1"]), - ) - ) - # compare ops - for op in compare_ops: - patterns.append( - ( - target + "." + op, - *elemwise_pattern("relax." + op), - _compare_check, - partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0", "input_1"]), - ) - ) - - # special ops - patterns.extend( - [ - ( - target + ".take", - *basic_pattern("relax.take", ["input", "input"]), - _take_check, - partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0", "input_1"]), - ), - ( - target + ".argmax", - *argmaxmin_pattern("relax.argmax"), - _argmaxmin_check, - partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input"]), - ), - ( - target + ".argmin", - *argmaxmin_pattern("relax.argmin"), - _argmaxmin_check, - partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input"]), - ), - ( - target + ".reshape", - *basic_pattern("relax.reshape", ["input", "input"]), - _reshape_check, - partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0"]), - ), - ] - ) - # fusable ops - patterns.extend( - [ - ( - target + ".msc.conv2d_bias", - *msc_pattern.make_opt_relax_conv_bias_pattern("relax.nn.conv2d"), - wrap_basic_check(msc_pattern._check_opt_relax_conv_bias), - partial( - msc_pattern.msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"] - ), - ), - ] - ) - # plugin ops - patterns.append( - ( - target + ".plugin", - *basic_pattern("relax.call_dps_packed", ["input", "input"]), - _plugin_check, - plugin_attrs_getter, - ) - ) - - return patterns - - -register_patterns(get_patterns("msc_tensorrt")) diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py deleted file mode 100644 index c92df6c4bf57..000000000000 --- a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py +++ /dev/null @@ -1,49 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""tvm.contrib.msc.framework.tensorrt.transform.transform""" - -from typing import List, Optional - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.utils import MSCFramework -from tvm.relax.transform import _ffi_api as relax_api - - -def TransformTensorRT( - version: Optional[List[int]] = None, linear_to_conv: bool = False -) -> tvm.ir.transform.Pass: - """Transform the Function to fit TensorRT. - - Parameters - ---------- - version: list - The tensorrt version. - linear_to_conv: bool - Whether to cast linear to conv2d - - Returns - ------- - ret: tvm.ir.transform.Pass - """ - - config = { - "version": version or msc_utils.get_version(MSCFramework.TENSORRT), - "linear_to_conv": linear_to_conv, - } - return relax_api.TransformTensorRT(msc_utils.dump_dict(config)) # type: ignore diff --git a/python/tvm/contrib/msc/framework/torch/__init__.py b/python/tvm/contrib/msc/framework/torch/__init__.py deleted file mode 100644 index 1e211d8d0bd9..000000000000 --- a/python/tvm/contrib/msc/framework/torch/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch""" diff --git a/python/tvm/contrib/msc/framework/torch/_ffi_api.py b/python/tvm/contrib/msc/framework/torch/_ffi_api.py deleted file mode 100644 index d1f27a53bdcf..000000000000 --- a/python/tvm/contrib/msc/framework/torch/_ffi_api.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch._ffi_api""" - -import tvm_ffi - -tvm_ffi.init_ffi_api("msc.framework.torch", __name__) diff --git a/python/tvm/contrib/msc/framework/torch/codegen/__init__.py b/python/tvm/contrib/msc/framework/torch/codegen/__init__.py deleted file mode 100644 index 9b56c3181589..000000000000 --- a/python/tvm/contrib/msc/framework/torch/codegen/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.codegen""" - -from .codegen import * diff --git a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py deleted file mode 100644 index e2530ba206ef..000000000000 --- a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.codegen.codegen""" - -from typing import Any, Dict, Optional - -import torch - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.codegen import CodeGen -from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.framework.torch import _ffi_api - - -def to_torch( - graph: MSCGraph, - weights: Optional[Dict[str, tvm.runtime.Tensor]] = None, - codegen_config: Optional[Dict[str, str]] = None, - print_config: Optional[Dict[str, str]] = None, - build_folder: msc_utils.MSCDirectory = None, - plugin: Any = None, -) -> torch.nn.Module: - """Change MSCGraph to torch nn.Module. - - Parameters - ---------- - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - weights: dict of - The parameters of the IRModule. - codegen_config: dict - The config for codegen. - print_config: dict - The config for print. - build_folder: MSCDirectory - The folder for saving scripts and datas. - plugin: PluginManager - The plugin manager. - - Returns - ------- - model: torch.nn.Module - The torch.nn.Module. - """ - - def _save_weights(folder: msc_utils.MSCDirectory): - if weights: - state_dict = {} - for name, data in weights.items(): - w_producer = graph.find_producer(name) - if w_producer.optype == "constant" and w_producer.has_attr("scalar"): - continue - w_tensor = graph.find_tensor(name) - w_name = w_tensor.alias or name - state_dict[w_name] = torch.from_numpy(data.numpy()) - torch.save(state_dict, folder.relpath(graph.name + ".pth")) - - def _bind_weights(model: torch.nn.Module, folder: msc_utils.MSCDirectory) -> torch.nn.Module: - if weights: - state_dict = torch.load(folder.relpath(graph.name + ".pth"), weights_only=False) - model.load_state_dict(state_dict) - return model - - codegen = CodeGen(graph, _ffi_api.GetTorchSources, codegen_config, print_config, build_folder) - model_args = [plugin] if plugin else [] - return codegen.load(model_args, pre_load=_save_weights, post_load=_bind_weights) diff --git a/python/tvm/contrib/msc/framework/torch/frontend/__init__.py b/python/tvm/contrib/msc/framework/torch/frontend/__init__.py deleted file mode 100644 index 5572720a6980..000000000000 --- a/python/tvm/contrib/msc/framework/torch/frontend/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.frontend""" - -from .translate import * diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py b/python/tvm/contrib/msc/framework/torch/frontend/translate.py deleted file mode 100644 index 7c5489424493..000000000000 --- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py +++ /dev/null @@ -1,109 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.frontend.translate""" - -from typing import Dict, List, Optional, Tuple, Union - -import torch - -import tvm -from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs -from tvm.contrib.msc.core.ir.graph import MSCGraph -from tvm.relax.frontend.torch import from_fx - - -def set_weight_alias(graph: MSCGraph) -> MSCGraph: - """Set weight with alias in MSCGraph. - - Parameters - ---------- - graph: MSCGraph - The graph. - - - Returns - ------- - graph: MSCGraph - The graph with weight alias. - """ - - for node in graph.get_nodes(): - for ref, weight in node.get_weights().items(): - if node.optype == "constant": - alias = node.name.replace(".", "_") - elif node.optype in ("nn.batch_norm", "nn.layer_norm", "nn.group_norm"): - if ref == "gamma": - alias = node.name.replace(".", "_") + ".weight" - elif ref == "beta": - alias = node.name.replace(".", "_") + ".bias" - elif ref == "mean": - alias = node.name.replace(".", "_") + ".running_mean" - elif ref == "var": - alias = node.name.replace(".", "_") + ".running_var" - else: - alias = node.name.replace(".", "_") + "." + ref - graph.set_tensor_alias(weight, alias) - return graph - - -def from_torch( - model: torch.nn.Module, - input_info: List[Tuple[Tuple[int], str]], - trans_config: Optional[Dict[str, str]] = None, - build_config: Optional[Dict[str, str]] = None, - as_msc: bool = True, - custom_convert_map: Optional[dict] = None, -) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.runtime.Tensor]]: - """Change torch nn.Module to MSCGraph. - - Parameters - ---------- - model: torch.nn.Module - The torch module. - input_info: list - The input info in format [(shape, dtype)]. - input_names: list - The input names. - trans_config: dict - The config for transform IRModule. - build_config: dict - The config for build MSCGraph. - opt_config: dict - The config for optimize before translate. - as_msc: bool - Set to to return msc graph, otherwise relax mod - custom_convert_map: dict - The convert map for plugin - build_folder: MSCDirectory - The folder for saving scripts and datas. - - Returns - ------- - graph/mod: tvm.contrib.msc.core.ir.MSCGraph/tvm.IRModule - The translated graph/IRModule. - weights: dict of - The weights from the IRModule. - """ - - graph_model = torch.fx.symbolic_trace(model) - input_info, params = normalize_inputs(input_info), None - with torch.no_grad(): - relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map) - if not as_msc: - return relax_mod, params - graph, weights = from_relax(relax_mod, trans_config=trans_config, build_config=build_config) - return set_weight_alias(graph), weights diff --git a/python/tvm/contrib/msc/framework/torch/runtime/__init__.py b/python/tvm/contrib/msc/framework/torch/runtime/__init__.py deleted file mode 100644 index b58184fb059b..000000000000 --- a/python/tvm/contrib/msc/framework/torch/runtime/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.runtime""" - -from .runner import * -from .jit import * diff --git a/python/tvm/contrib/msc/framework/torch/runtime/jit.py b/python/tvm/contrib/msc/framework/torch/runtime/jit.py deleted file mode 100644 index 36d6db4797ba..000000000000 --- a/python/tvm/contrib/msc/framework/torch/runtime/jit.py +++ /dev/null @@ -1,217 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-import -"""tvm.contrib.msc.framework.torch.runtime.jit_model""" - -from functools import partial -from typing import Any, Dict, List, Optional, Tuple - -import torch -from torch import _dynamo as dynamo -from torch import fx - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.runtime import BaseJIT -from tvm.contrib.msc.core.utils.namespace import MSCFramework - -from .runner import TorchRunner - - -class TorchJIT(BaseJIT): - """JIT of Torch""" - - def _call_jit(self, inputs: Dict[str, Any]) -> Any: - """Run the jit model - - Parameters - ---------- - inputs: - The inputs of model. - """ - - torch_inputs = [ - msc_utils.cast_array(inputs[i], MSCFramework.TORCH, self._device) for i in self._inputs - ] - return self._jit_model(*torch_inputs) - - def _build(self, model: Any) -> Any: - """Build the jit model - - Parameters - ---------- - model: - The model. - - Returns - ------- - jit_model: - The jit model. - """ - - # pylint: disable=unused-argument - def _compile(graph_module: fx.GraphModule, example_inputs): - graph_module = graph_module.train() if self._training else graph_module.eval() - name = "jit_" + str(len(self._runner_ctxs)) - self._runner_ctxs[name] = {"model": graph_module} - return partial(self._redirect_run, runner_name=name) - - dynamo.reset() - return torch.compile(self._model, backend=_compile) - - def _to_msc_inputs(self, runner_name: str, *args, **kwargs) -> List[Tuple[str, Any]]: - """Change inputs to msc format - - Parameters - ---------- - runner_name: str - The runner name. - args: - The arguments. - kwargs: - The kwargs. - - Returns - ------- - inputs: - The msc format inputs. - """ - - assert not kwargs, "TorchJIT do not support kwargs" - return [("input_" + str(i), d) for i, d in enumerate(args)] - - def _from_msc_outputs(self, runner_name: str, outputs: List[Tuple[str, Any]]) -> Any: - """Change inputs from msc format - - Parameters - ---------- - runner_name: str - The runner name. - outputs: list<(str, tensor)> - The msc format outputs. - - Returns - ------- - outputs: - The framework outputs. - """ - - torch_outputs = [o[1] for o in outputs] - unpack_outputs = self.get_runner_ctx(runner_name).get("unpack_outputs", True) - if not unpack_outputs: - return torch_outputs - return torch_outputs[0] if len(torch_outputs) == 1 else torch_outputs - - def _run_ctx(self, runner_ctx: dict, inputs: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]: - """Forward by runner context - - Parameters - ---------- - runner_ctx: dict - The runner context - inputs: list<(str, tensor)> - The inputs. - - Returns - ------- - outputs: list<(str, tensor)> - The outputs. - """ - - if "runner" in runner_ctx: - runner = runner_ctx["runner"] - if runner.framework == MSCFramework.TORCH: - outputs = runner.run({i[0]: i[1] for i in inputs}, ret_type="native") - else: - outputs = runner.run({i[0]: i[1] for i in inputs}, ret_type="list") - outputs = [ - msc_utils.cast_array(o, MSCFramework.TORCH, runner.device) for o in outputs - ] - else: - torch_inputs = [i[1] for i in inputs] - outputs = runner_ctx["model"](*torch_inputs) - if isinstance(outputs, (list, tuple)) and len(outputs) == 1: - runner_ctx["unpack_outputs"] = False - if isinstance(outputs, (list, tuple)): - return [("output_" + str(i), o) for i, o in enumerate(outputs)] - return [("output", outputs)] - - @property - def framework(self): - return MSCFramework.TORCH - - @classmethod - def load_native(cls, model: Any, config: dict) -> Tuple[torch.nn.Module, str, bool]: - """Load the native model - - Parameters - ------- - model: - The native model. - config: dict - The config for pipeline. - - Returns - ------- - model: torch.nn.Module - The loaded native model. - device: str - The device of the model. - training: - Whether the model is for training. - """ - - return TorchRunner.load_native(model, config) - - @classmethod - def dump_nativate( - cls, - model: torch.nn.Module, - folder: msc_utils.MSCDirectory, - dump_config: Optional[dict] = None, - ) -> str: - """Dump the nativate model - - Parameters - ------- - model: torch.nn.Module - The runnable model. - folder: MSCDirectory - The export folder. - dump_config: dict - The dump config. - - Returns - ------- - export_path: str - The exported path - """ - - dump_config = dump_config or {} - assert dump_config.get("mode", "fx") == "fx", "TorchJIT only support dump nativate as fx" - return TorchRunner.dump_nativate(model, folder, dump_config) - - @classmethod - def support_device(cls, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. - """ - - return TorchRunner.support_device(device) diff --git a/python/tvm/contrib/msc/framework/torch/runtime/runner.py b/python/tvm/contrib/msc/framework/torch/runtime/runner.py deleted file mode 100644 index 0664b9222a1b..000000000000 --- a/python/tvm/contrib/msc/framework/torch/runtime/runner.py +++ /dev/null @@ -1,340 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-import -# ruff: noqa: F401 -"""tvm.contrib.msc.framework.torch.runtime.runner""" - -import time -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph -from tvm.contrib.msc.core.runtime import ModelRunner -from tvm.contrib.msc.core.utils.message import MSCStage -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.framework.torch import tools -from tvm.contrib.msc.framework.torch.codegen import to_torch -from tvm.contrib.msc.framework.torch.frontend import from_torch, set_weight_alias - - -class TorchRunner(ModelRunner): - """Runner of Torch""" - - def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.runtime.Tensor]]: - """Translate IRModule to MSCgraphs - - Parameters - ------- - mod: tvm.IRModule - The module to be translated. - - Returns - ------- - graph_list: list - The translated graphs - weights_list: list> - The translated weights - """ - graphs, weights = super()._translate(mod) - return [set_weight_alias(graphs[0])], weights - - def _build_runnable(self, model: Any) -> Any: - """Build runnable object - - Parameters - ------- - model: Any - The meta model. - - Returns - ------- - runnable: Any - The runnable - """ - - if self._device.startswith("cpu"): - pass - elif self._device.startswith("cuda"): - model = model.to(torch.device(self._device)) - else: - raise NotImplementedError("Unsupported device " + str(self._device)) - if self._training: - model = model.train() - else: - model = model.eval() - return model - - def _call_runnable( - self, runnable: torch.nn.Module, inputs: Dict[str, np.ndarray], device: str - ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]: - """Call the runnable to get outputs - - Parameters - ------- - runnable: torch.nn.Module - The runnable model. - inputs: dict - The inputs in dict. - device: str - The device. - - Returns - ------- - outputs: list - The outputs in list. - """ - - input_names = [i["name"] for i in self.get_inputs()] - torch_inputs = [ - msc_utils.cast_array(inputs[i], MSCFramework.TORCH, device) for i in input_names - ] - return runnable(*torch_inputs) - - def _get_runtime_params(self) -> Dict[str, tvm.runtime.Tensor]: - """Get the runtime parameters - - Returns - ------- - params: dict - The parameters from runtime. - """ - - assert self._runnable, "runnable is needed to get params" - state_dict = self._runnable.state_dict() - params = {} - for graph in self._graphs: - for weight in graph.get_weights(): - assert weight.alias in state_dict, f"Missing weight {weight.alias} in state_dict" - params[weight.name] = msc_utils.cast_array( - state_dict[weight.alias], MSCFramework.TVM, "cpu" - ) - return params - - @property - def codegen_func(self): - return to_torch - - @property - def framework(self): - return MSCFramework.TORCH - - @classmethod - def load_native(cls, model: Any, config: dict) -> Tuple[torch.nn.Module, str, bool]: - """Load the native model - - Parameters - ------- - model: - The native model. - config: dict - The config for pipeline. - - Returns - ------- - model: torch.nn.Module - The loaded native model. - device: str - The device of the model. - training: - Whether the model is for training. - """ - - if isinstance(model, str) and ":" in model: - native_model = msc_utils.load_callable(model) - elif isinstance(model, torch.nn.Module): - native_model = model - else: - raise NotImplementedError( - f"Load native model {model} with type {type(model)} is not supported" - ) - parameters = list(model.parameters()) - if parameters: - ref_device = parameters[0].device - if ref_device.index: - device = f"{ref_device.type}:{ref_device.index}" - else: - device = ref_device.type - else: - device = "cpu" - return native_model, device, model.training - - @classmethod - def run_native( - cls, - model: torch.nn.Module, - inputs: Dict[str, np.ndarray], - input_names: List[str], - output_names: List[str], - warm_up: int = 10, - repeat: int = 0, - ) -> Tuple[Dict[str, np.ndarray], float]: - """Run the datas and get outputs - - Parameters - ------- - model: torch.nn.Module - The runnable model. - inputs: dict - The inputs in dict. - input_names: list - The input names. - output_names: list - The outut names. - warm_up: int - The warm_up num for profile. - repeat: int - The repeat num for profile. - - Returns - ------- - outputs: dict - The outputs in dict. - avg_time: float - The average time. - """ - - parameters = list(model.parameters()) - if parameters: - ref_dev = parameters[0].device - if ref_dev.index: - device = f"{ref_dev.type}:{ref_dev.index}" - else: - device = ref_dev.type - else: - device = "cpu" - torch_inputs = [ - msc_utils.cast_array(inputs[i], MSCFramework.TORCH, device) for i in input_names - ] - - def _run_once(): - return model(*torch_inputs) - - if repeat > 0: - for _ in range(warm_up): - _run_once() - start = time.time() - for _ in range(repeat): - outputs = _run_once() - avg_time = (time.time() - start) * 1000 / repeat - else: - outputs = _run_once() - avg_time = -1 - if isinstance(outputs, torch.Tensor): - assert len(output_names) == 1, "Expect 1 outputs, get " + str(output_names) - return {output_names[0]: msc_utils.cast_array(outputs)}, avg_time - assert len(output_names) == len(outputs), ( - f"Outputs mismatch, {output_names} with {len(outputs)}" - ) - outputs = { - o_name: msc_utils.cast_array(o_data) for o_name, o_data in zip(output_names, outputs) - } - return outputs, avg_time - - @classmethod - def dump_nativate( - cls, - model: torch.nn.Module, - folder: msc_utils.MSCDirectory, - dump_config: Optional[dict] = None, - ) -> str: - """Dump the nativate model - - Parameters - ------- - model: torch.nn.Module - The runnable model. - folder: MSCDirectory - The export folder. - dump_config: dict - The dump config. - - Returns - ------- - export_path: str - The exported path - """ - - dump_config = dump_config or {} - mode = dump_config.get("mode", "fx") - if mode == "fx": - graph_model = torch.fx.symbolic_trace(model) - exp_path = folder.create_dir("model") - graph_model.to_folder(exp_path.path, "native_model") - return exp_path.relpath("module.py") + ":native_model" - if mode == "pt": - assert "inputs" in dump_config, "inputs are needed for torch.jit.trace" - parameters = list(model.parameters()) - device = parameters[0].device if parameters else torch.device("cpu") - datas = [np.random.rand(i[1]).astype(i[2]) for i in dump_config["inputs"]] - torch_datas = [torch.from_numpy(d).to(device) for d in datas] - with torch.no_grad(): - scriptde_model = torch.jit.trace(model, tuple(torch_datas)).eval() - exp_path = folder.relpath("model.pt") - torch.jit.save(scriptde_model, exp_path) - return exp_path - raise TypeError("Unexpeceted dump mode " + str(mode)) - - @classmethod - def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: - """Update the config for parse - - Parameters - ------- - stage: str - The stage to be updated - config: dict - The config for pipeline. - model: - The native model. - - Returns - ------- - config: dict - The updated config. - """ - - config = ModelRunner.update_config(stage, config, model) - if stage not in config: - return config - if stage == MSCStage.PARSE: - config["parse"]["parser"] = from_torch - parse_config = config["parse"].get("parse_config", {}) - parse_config.update( - {"input_info": [[i[1], "float" if len(i) < 2 else i[2]] for i in config["inputs"]]} - ) - config["parse"]["parse_config"] = parse_config - return config - - @classmethod - def support_device(cls, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. - """ - - if device == "cpu": - return True - if device.startswith("cuda"): - return torch.cuda.is_available() - return False diff --git a/python/tvm/contrib/msc/framework/torch/tools/__init__.py b/python/tvm/contrib/msc/framework/torch/tools/__init__.py deleted file mode 100644 index 87dfade83f08..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.tools""" - -from .prune import * -from .quantize import * -from .distill import * -from .track import * diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/__init__.py b/python/tvm/contrib/msc/framework/torch/tools/distill/__init__.py deleted file mode 100644 index 61ff8cc3ef1a..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/distill/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.tools.distill""" - -from .distiller import * -from .method import * diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py deleted file mode 100644 index 8c5849500970..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py +++ /dev/null @@ -1,146 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.tools.distill.distiller""" - -from typing import Any, Dict - -import torch -from torch import optim - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.distill import BaseDistiller -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TorchDistillerFactory: - """Distiller factory for torch""" - - def create(self, base_cls: BaseDistiller) -> BaseDistiller: - """Create adaptive distiller - - Parameters - ---------- - base_cls: BaseDistiller - The base distiller class - - Returns - ------- - distiller_cls: BaseDistiller - The distiller class. - """ - - @msc_utils.register_tool - class Distiller(base_cls): - """Adaptive distiller for torch""" - - def build_model(self, teacher: Any, student: Any) -> Any: - """Build the model with teacher and student - - Parameters - ------- - teacher: Any - The teacher model - student: Any - The student model - - Returns - ------- - model: Any - The built model. - """ - - optimizer = self._options.get("optimizer", "sgd") - opt_config = {"lr": 0.0001, "weight_decay": 1e-4} - opt_config.update(self._options.get("opt_config", {})) - self._logger.debug( - "%s build model with optimizer %s(%s)", - self.tool_type().upper(), - optimizer, - opt_config, - ) - if optimizer == "sgd": - self._optimizer = optim.SGD(student.parameters(), **opt_config) - elif optimizer == "adam": - self._optimizer = optim.Adam(student.parameters(), **opt_config) - else: - raise NotImplementedError(f"optimizer {optimizer} is not supported") - - # Get loss function - loss_strategy = self._strategys.get("loss") - assert loss_strategy, "Can not find loss in strategys" - - def get_loss(teacher_outputs, student_outputs): - return loss_strategy(self, teacher_outputs, student_outputs) - - # Build model - class DistillModel(torch.nn.Module): - """Common distill model class""" - - def __init__(self): - super().__init__() - self.teacher = teacher - self.student = student - - def forward(self, *inputs): - with torch.no_grad(): - teacher_outputs = self.teacher.forward(*inputs) - student_outputs = self.student.forward(*inputs) - return get_loss(teacher_outputs, student_outputs) - - self._model = DistillModel() - return self._model - - def _learn(self, loss: torch.Tensor): - """Learn after forward - - Parameters - ------- - loss: torch.Tensor - The loss after forward - """ - - loss.backward() - self._optimizer.step() - return loss - - def _distill(self) -> Dict[str, Any]: - """Distill the knowledge - - Returns - ------- - weights: dict - The distilled weights. - """ - - state_dict = self._model.student.state_dict() - return { - n: state_dict.get(self.find_tensor(n).alias, d) - for n, d in self._weights.items() - } - - @classmethod - def framework(cls): - return MSCFramework.TORCH - - return Distiller - - -factory = TorchDistillerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/method.py b/python/tvm/contrib/msc/framework/torch/tools/distill/method.py deleted file mode 100644 index 3c7f43afa9ed..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/distill/method.py +++ /dev/null @@ -1,115 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.framework.torch.tools.distill.method""" - -from typing import List - -import torch - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.distill import BaseDistiller, DistillMethod -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -@msc_utils.register_tool_method -class TorchDistillMethod(DistillMethod): - """Default quantize method for torch""" - - @classmethod - def loss_kl_divergence( - cls, - distiller: BaseDistiller, - t_outputs: List[torch.Tensor], - s_outputs: List[torch.Tensor], - temperature: int = 5, - softmax_dim: int = -1, - ): - """Calculate loss with mse - - Parameters - ---------- - distiller: BaseDistiller - The distiller - t_outputs: list - The teacher outputs. - s_outputs: list - The student outputs. - temperature: int - The temperature factor. - softmax_dim: int - If >=0, use softmax_dim for softmax loss - - Returns - ------- - loss: float - The loss. - """ - - kd_loss, loss = torch.nn.KLDivLoss(), 0 - if softmax_dim >= 0: - log_softmax = torch.nn.LogSoftmax(dim=softmax_dim) - softmax = torch.nn.Softmax(dim=softmax_dim) - - def _distill_loss(t_out, s_out): - if softmax_dim >= 0: - return ( - temperature - * temperature - * kd_loss(log_softmax(s_out / temperature), softmax(t_out / temperature)) - ) - return kd_loss(s_out / temperature, t_out / temperature) - - for t_out, s_out in zip(t_outputs, s_outputs): - loss += _distill_loss(t_out, s_out) - return loss - - @classmethod - def loss_lp_norm( - cls, - distiller: BaseDistiller, - t_outputs: List[torch.Tensor], - s_outputs: List[torch.Tensor], - power: int = 2, - ): - """Calculate loss with mse - - Parameters - ---------- - distiller: BaseDistiller - The distiller - t_outputs: list - The teacher outputs. - s_outputs: list - The student outputs. - power: int - The power factor. - - Returns - ------- - loss: float - The loss. - """ - - loss = 0 - for t_out, s_out in zip(t_outputs, s_outputs): - loss += torch.pow((t_out - s_out).abs(), power).mean() - return loss - - @classmethod - def framework(cls): - return MSCFramework.TORCH diff --git a/python/tvm/contrib/msc/framework/torch/tools/prune/__init__.py b/python/tvm/contrib/msc/framework/torch/tools/prune/__init__.py deleted file mode 100644 index 6364e14945aa..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/prune/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.tools.prune""" - -from .pruner import * diff --git a/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py deleted file mode 100644 index 1ced1399f3e5..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/prune/pruner.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.tools.prune.pruner""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.prune import BasePruner -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TorchPrunerFactory: - """Pruner factory for torch""" - - def create(self, base_cls: BasePruner) -> BasePruner: - """Create adaptive pruner - - Parameters - ---------- - base_cls: BasePruner - The base pruner class - - Returns - ------- - pruner_cls: BasePruner - The pruner class. - """ - - @msc_utils.register_tool - class Pruner(base_cls): - """Adaptive pruner for torch""" - - @classmethod - def framework(cls): - return MSCFramework.TORCH - - return Pruner - - -factory = TorchPrunerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/torch/tools/quantize/__init__.py b/python/tvm/contrib/msc/framework/torch/tools/quantize/__init__.py deleted file mode 100644 index 8687626dfefc..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/quantize/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.tools.quantize""" - -from .quantizer import * -from .method import * diff --git a/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py b/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py deleted file mode 100644 index 8c1397c6d1f6..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/quantize/method.py +++ /dev/null @@ -1,268 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument, arguments-differ -"""tvm.contrib.msc.framework.torch.tools.quantize.method""" - -from functools import wraps - -import numpy as np -import torch -from torch.autograd import Function - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.quantize import BaseQuantizer, QuantizeMethod -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -def fake_quantize(func): - """Fake quantize without backward""" - - @wraps(func) - def wrapper( - cls, quantizer: BaseQuantizer, data: torch.Tensor, name: str, consumer: str, *args, **kwargs - ): - func_name = "quantize_func." + func.__name__ - quantize_func = quantizer._get_tensor_cache(name, consumer, func_name) - if quantize_func is None: - - class FakeQuantize(Function): - """Fake quantize func for torch""" - - @staticmethod - def forward(ctx, data): - return func(cls, quantizer, data, name, consumer, *args, **kwargs) - - @staticmethod - def backward(ctx, grad_outputs): - return grad_outputs - - quantize_func = quantizer._save_tensor_cache(name, consumer, func_name, FakeQuantize) - return quantize_func.apply(data) - - return wrapper - - -@msc_utils.register_tool_method -class TorchQuantizeMethod(QuantizeMethod): - """Default quantize method for torch""" - - @classmethod - def amplify_data( - cls, - data: torch.Tensor, - scale: float, - min_val: float, - max_val: float, - rounding: str = "round", - ) -> torch.Tensor: - """Amplify the data - - Parameters - ---------- - data: torch.Tensor - The source data. - scale: float - The scale factor - min_val: float - The min. - max_val: float - The max. - rounding: str - The round method - - Returns - ------- - data: torch.Tensor - The processed data. - """ - - if rounding == "null": - return torch.clamp(data * scale, min_val, max_val) - if rounding == "floor": - return torch.clamp(torch.floor(data * scale), min_val, max_val) - if rounding == "ceil": - return torch.clamp(torch.ceil(data * scale), min_val, max_val) - if rounding == "round": - return torch.clamp(torch.round(data * scale), min_val, max_val) - if rounding == "trunc": - return torch.clamp(torch.trunc(data * scale), min_val, max_val) - if rounding == "logic_round": - data = torch.clamp(data * scale, min_val, max_val) - negative_ceil = torch.where( - torch.logical_and(data < 0, (data - torch.floor(data)) == 0.5), torch.ceil(data), 0 - ) - data = torch.where( - torch.logical_and(data < 0, (data - torch.floor(data)) == 0.5), 0, data - ) - data = torch.where((data - torch.floor(data)) >= 0.5, torch.ceil(data), data) - data = torch.where((data - torch.floor(data)) < 0.5, torch.floor(data), data) - return data + negative_ceil - raise TypeError("Unexpected rounding " + str(rounding)) - - @classmethod - def gather_maxmin( - cls, - quantizer: BaseQuantizer, - data: torch.Tensor, - name: str, - consumer: str, - plan: dict, - nbits: int = 8, - ) -> dict: - """Gather the data by max/min - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: np.ndarray - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - plan: dict - The pre-calibrated plan. - nbits: int - The number bits for quantize. - - Returns - ------- - plan: dict - The plan of the tensor. - """ - - abs_max_list = plan.get("abs_max_list", []) - abs_max_list.append(float(torch.abs(data).max())) - max_list = plan.get("max_list", []) - max_list.append(float(data.max())) - min_list = plan.get("min_list", []) - min_list.append(float(data.min())) - return { - "abs_max_list": abs_max_list, - "max_list": max_list, - "min_list": min_list, - "calibrated": False, - } - - @classmethod - def gather_max_per_channel( - cls, - quantizer: BaseQuantizer, - data: torch.Tensor, - name: str, - consumer: str, - plan: dict, - nbits: int = 8, - channel: str = "O", - auto_unsign: bool = False, - ) -> dict: - """Gather the data by max_per_channel - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: np.ndarray - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - plan: dict - The pre-calibrated plan. - nbits: int - The number bits for quantize. - channel: str - The channel reference. - auto_unsign: bool - Whether to use auto unsign. - - Returns - ------- - plan: dict - The plan of the tensor. - """ - - weight = quantizer.find_tensor(name) - axis = weight.layout_of(channel) - channel_max = [torch.abs(d).max() for d in torch.chunk(data, data.shape[axis], dim=axis)] - sign = data.min() < 0 if auto_unsign else True - valid_range = 2 ** (nbits - int(sign)) - 1 - scale = [valid_range / float(m) for m in channel_max] - return {"scale": scale, "sign": sign, "axis": axis, "calibrated": True} - - @classmethod - @fake_quantize - def quantize_normal( - cls, - quantizer: BaseQuantizer, - data: torch.Tensor, - name: str, - consumer: str, - scale: float, - nbits: int = 8, - axis: int = -1, - sign: bool = True, - rounding: str = "round", - epsilon: float = 1.0 / (1 << 24), - ) -> torch.Tensor: - """Calibrate the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: torch.Tensor - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scale: float - The scale factor - nbits: int - The number bits for quantize. - axis: int - The axis. - sign: bool - Whether to use sign. - rounding str - The rounding method. - epsilon: float - The epsilon for get scale. - - Returns - ------- - data: torch.Tensor - The processed tensor. - """ - - valid_range = 2 ** (nbits - int(sign)) - 1 - min_val = -valid_range if sign else 0 - scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor") - if scale_tensor is None: - scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon) - if isinstance(scale_tensor, np.ndarray): - scale_tensor = torch.from_numpy(scale_tensor).to(data.device) - quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor) - data = cls.amplify_data(data, scale_tensor, min_val, valid_range, rounding) - return data / scale_tensor - - @classmethod - def framework(cls): - return MSCFramework.TORCH diff --git a/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py deleted file mode 100644 index 0c93a444b389..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/quantize/quantizer.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.tools.quantize.quantizer""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.quantize import BaseQuantizer -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TorchQuantizerFactory: - """Quantizer factory for torch""" - - def create(self, base_cls: BaseQuantizer) -> BaseQuantizer: - """Create adaptive quantizer - - Parameters - ---------- - base_cls: BaseQuantizer - The base quantizer class - - Returns - ------- - quantizer_cls: BaseQuantizer - The quantizer class. - """ - - @msc_utils.register_tool - class Quantizer(base_cls): - """Adaptive quantizer for torch""" - - @classmethod - def framework(cls): - return MSCFramework.TORCH - - return Quantizer - - -factory = TorchQuantizerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/torch/tools/track/__init__.py b/python/tvm/contrib/msc/framework/torch/tools/track/__init__.py deleted file mode 100644 index 55951fe3a97c..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/track/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.tools.track""" - -from .tracker import * diff --git a/python/tvm/contrib/msc/framework/torch/tools/track/tracker.py b/python/tvm/contrib/msc/framework/torch/tools/track/tracker.py deleted file mode 100644 index 8a245e04eeba..000000000000 --- a/python/tvm/contrib/msc/framework/torch/tools/track/tracker.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.torch.tools.track.tracker""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.tools.track import BaseTracker -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TorchTrackerFactory: - """Tracker factory for torch""" - - def create(self, base_cls: BaseTracker) -> BaseTracker: - """Create adaptive tracker - - Parameters - ---------- - base_cls: BaseTracker - The base tracker class - - Returns - ------- - tracker_cls: BaseTracker - The tracker class. - """ - - @msc_utils.register_tool - class Tracker(base_cls): - """Adaptive tracker for torch""" - - @classmethod - def framework(cls): - return MSCFramework.TORCH - - return Tracker - - -factory = TorchTrackerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tvm/__init__.py b/python/tvm/contrib/msc/framework/tvm/__init__.py deleted file mode 100644 index fcdf0c886c24..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm""" diff --git a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py deleted file mode 100644 index c9f63e21eaef..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm._ffi_api""" - -import tvm_ffi - -tvm_ffi.init_ffi_api("msc.framework.tvm", __name__) diff --git a/python/tvm/contrib/msc/framework/tvm/codegen/__init__.py b/python/tvm/contrib/msc/framework/tvm/codegen/__init__.py deleted file mode 100644 index ca656cde89db..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/codegen/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm.codegen""" - -from .codegen import * diff --git a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py deleted file mode 100644 index e5ed5d381a39..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm.codegen.codegen""" - -from typing import Any, Dict, Optional - -import tvm -from tvm.contrib.msc.core import codegen as msc_codegen -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.ir import MSCGraph - - -def to_relax( - graph: MSCGraph, - weights: Optional[Dict[str, tvm.runtime.Tensor]] = None, - codegen_config: Optional[Dict[str, str]] = None, - print_config: Optional[Dict[str, str]] = None, - build_folder: msc_utils.MSCDirectory = None, - plugin: Any = None, -) -> tvm.IRModule: - """Change MSCGraph to IRModule. - - Parameters - ---------- - graph: tvm.contrib.msc.core.ir.MSCGraph - The translated graph. - weights: dict of - The parameters of the IRModule. - codegen_config: dict - The config for codegen. - print_config: dict - The config for print. - build_folder: MSCDirectory - The folder for saving scripts and datas. - plugin: PluginManager - The plugin manager. - - Returns - ------- - mod: IRModule - The IRModule of relax. - """ - - return msc_codegen.to_relax(graph, weights, codegen_config, print_config, build_folder, plugin) diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/__init__.py b/python/tvm/contrib/msc/framework/tvm/runtime/__init__.py deleted file mode 100644 index 73ea8bb06c3b..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/runtime/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm.runtime""" - -from .runner import * diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py deleted file mode 100644 index cb242d2cbd12..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py +++ /dev/null @@ -1,329 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-import -# ruff: noqa: F401 -"""tvm.contrib.msc.framework.runtime.tvm.runner""" - -import os -import time -from typing import Any, Dict, List, Tuple, Union - -import numpy as np - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.runtime import ModelRunner -from tvm.contrib.msc.core.tools import execute_step -from tvm.contrib.msc.core.utils.message import MSCStage -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.framework.tvm import tools -from tvm.contrib.msc.framework.tvm.codegen import to_relax - - -class WrapRunnable: - """Wrapped runnable for tools - - Parameters - ------- - runner: ModelRunner - The runner context - runnable: tvm.relax.VirtualMachine - The virtual machine. - entry: str - The entry funcname. - """ - - def __init__(self, runnable: tvm.relax.VirtualMachine, entry: str = "main"): - self._runnable = runnable - self._entry = entry - - def __call__(self, *inputs) -> List[tvm.runtime.Tensor]: - execute_step("before_forward", *inputs) - output = self._runnable[self._entry](*inputs) - return execute_step("after_forward", output) - - -class TVMRunner(ModelRunner): - """Runner of Relax""" - - def setup(self) -> dict: - """Setup the runner - - Returns - ------- - info: dict - The setup info. - """ - - self._executable = None - return super().setup() - - def _build_runnable(self, model: Any) -> Any: - """Build runnable object - - Parameters - ------- - model: Any - The meta model. - - Returns - ------- - runnable: Any - The runnable - """ - - if self._training: - model = tvm.relax.transform.DecomposeOpsForTraining()(model) - else: - model = tvm.relax.transform.DecomposeOpsForInference()(model) - if "builder" in self._generate_config: - builder, build_config = self._generate_config["builder"] - runnable = builder(model, **build_config) - self._logger.info( - f"Model({self.framework}) processed by customize builder {builder}({build_config})" - ) - else: - model = tvm.relax.transform.LegalizeOps()(model) - if self._device.startswith("cpu"): - target = tvm.target.Target("llvm") - with tvm.transform.PassContext(opt_level=3): - self._executable = tvm.compile(model, target) - runnable = tvm.relax.VirtualMachine(self._executable, tvm.cpu()) - elif self._device.startswith("cuda"): - target = tvm.target.Target("cuda") - with target: - model = tvm.s_tir.transform.DefaultGPUSchedule()(model) - with tvm.transform.PassContext(opt_level=3): - self._executable = tvm.compile(model, target) - runnable = tvm.relax.VirtualMachine(self._executable, tvm.cuda()) - else: - raise NotImplementedError("Unsupported device " + str(self._device)) - return WrapRunnable(runnable) - - def _call_runnable( - self, runnable: WrapRunnable, inputs: Dict[str, np.ndarray], device: str - ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]: - """Call the runnable to get outputs - - Parameters - ------- - runnable: tvm.relax.VirtualMachine - The virtual machine. - inputs: dict - The inputs in dict. - device: str - The device. - - Returns - ------- - outputs: list - The outputs in list. - """ - - input_names = [i["name"] for i in self.get_inputs()] - tvm_inputs = [ - msc_utils.cast_array(inputs[i], MSCFramework.TVM, device) for i in input_names - ] - return runnable(*tvm_inputs) - - def export_runnable(self, folder: msc_utils.MSCDirectory) -> dict: - """Export the runnable - - Parameters - ------- - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The runnable info. - """ - - export_lib = folder.relpath("lib.so") - self._executable.export_library(export_lib) - return { - "lib": export_lib, - "device": self.device, - "model_type": self.framework, - "abstract": self.model_info, - } - - @property - def codegen_func(self): - return to_relax - - @property - def framework(self): - return MSCFramework.TVM - - @classmethod - def load_native(cls, model: Any, config: dict) -> Tuple[tvm.IRModule, str, bool]: - """Load the native model - - Parameters - ------- - model: - The native model. - config: dict - The config for pipeline. - - Returns - ------- - model: tvm.IRModule - The loaded native model. - device: str - The device of the model. - training: bool - Whether the model is for training. - """ - - if isinstance(model, str) and os.path.isfile(model): - with open(model) as f: - native_model = tvm.ir.load_json(f.read()) - elif isinstance(model, tvm.IRModule): - native_model = model - else: - raise NotImplementedError( - f"Load native model {model} with type {type(model)} is not supported" - ) - if tvm.cuda().exist: - device = "cuda" - else: - device = "cpu" - return native_model, device, False - - @classmethod - def run_native( - cls, - model: tvm.IRModule, - inputs: Dict[str, np.ndarray], - input_names: List[str], - output_names: List[str], - warm_up: int = 10, - repeat: int = 0, - ) -> Tuple[Dict[str, np.ndarray], float]: - """Run the datas and get outputs - - Parameters - ------- - model: tvm.IRModule - The runnable model. - inputs: dict - The inputs in dict. - input_names: list - The input names. - output_names: list - The outut names. - warm_up: int - The warm_up num for profile. - repeat: int - The repeat num for profile. - - Returns - ------- - outputs: dict - The outputs in dict. - avg_time: float - The average time. - """ - - model = tvm.relax.transform.LegalizeOps()(model) - if tvm.cuda().exist: - target = tvm.target.Target("cuda") - with target: - model = tvm.s_tir.transform.DefaultGPUSchedule()(model) - with tvm.transform.PassContext(opt_level=3): - relax_exec = tvm.compile(model, target) - runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cuda()) - tvm_inputs = [tvm.runtime.tensor(inputs[i], device=tvm.cuda()) for i in input_names] - else: - target = tvm.target.Target("llvm") - with tvm.transform.PassContext(opt_level=3): - relax_exec = tvm.compile(model, target) - runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu()) - tvm_inputs = [tvm.runtime.tensor(inputs[i]) for i in input_names] - - def _run_once(): - return runnable["main"](*tvm_inputs) - - if repeat > 0: - for _ in range(warm_up): - _run_once() - start = time.time() - for _ in range(repeat): - outputs = _run_once() - avg_time = (time.time() - start) * 1000 / repeat - else: - outputs = _run_once() - avg_time = -1 - if isinstance(outputs, tvm.runtime.Tensor): - outputs = [outputs] - assert len(output_names) == len(outputs), ( - f"Outputs mismatch, {output_names} with {len(outputs)}" - ) - outputs = { - o_name: msc_utils.cast_array(o_data) for o_name, o_data in zip(output_names, outputs) - } - return outputs, avg_time - - @classmethod - def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: - """Update the config for parse - - Parameters - ------- - stage: str - The stage to be updated - config: dict - The config for pipeline. - model: - The native model. - - Returns - ------- - config: dict - The updated config. - """ - - config = ModelRunner.update_config(stage, config, model) - if stage not in config: - return config - if stage == MSCStage.PARSE: - # pylint: disable=unused-argument - def passby(mod, *args, **kwargs): - return mod, None - - config["parse"]["parser"] = passby - return config - - @classmethod - def support_device(cls, device: str) -> bool: - """Check if the device is enabled - - Returns - ------- - enabled: bool - Whether the device is enabled. - """ - - if device == "cpu": - return True - if device.startswith("cuda"): - dev_id = int(device.split(":")[1]) if ":" in device else 0 - return tvm.cuda(dev_id).exist - return False diff --git a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py b/python/tvm/contrib/msc/framework/tvm/tools/__init__.py deleted file mode 100644 index f904d19ec8ad..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/tools/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm.tools""" - -from .prune import * -from .quantize import * -from .distill import * -from .track import * diff --git a/python/tvm/contrib/msc/framework/tvm/tools/distill/__init__.py b/python/tvm/contrib/msc/framework/tvm/tools/distill/__init__.py deleted file mode 100644 index 8d4b7dfb6158..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/tools/distill/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm.tools.distill""" - -from .distiller import * diff --git a/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py b/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py deleted file mode 100644 index 238584a4db05..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/tools/distill/distiller.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm.tools.distill.distiller""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.distill import BaseDistiller -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TVMDistillerFactory: - """Distiller factory for tvm""" - - def create(self, base_cls: BaseDistiller) -> BaseDistiller: - """Create adaptive distiller - - Parameters - ---------- - base_cls: BaseDistiller - The base distiller class - - Returns - ------- - distiller_cls: BaseDistiller - The distiller class. - """ - - @msc_utils.register_tool - class Distiller(base_cls): - """Adaptive distiller for tvm""" - - @classmethod - def framework(cls): - return MSCFramework.TVM - - return Distiller - - -factory = TVMDistillerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.DISTILLER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tvm/tools/prune/__init__.py b/python/tvm/contrib/msc/framework/tvm/tools/prune/__init__.py deleted file mode 100644 index 47c1478611e8..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/tools/prune/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm.tools.prune""" - -from .pruner import * diff --git a/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py b/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py deleted file mode 100644 index d393e5f9e6a7..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/tools/prune/pruner.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm.tools.prune.pruner""" - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.prune import BasePruner -from tvm.contrib.msc.core.tools.tool import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TVMPrunerFactory: - """Pruner factory for tvm""" - - def create(self, base_cls: BasePruner) -> BasePruner: - """Create adaptive pruner - - Parameters - ---------- - base_cls: BasePruner - The base pruner class - - Returns - ------- - pruner_cls: BasePruner - The pruner class. - """ - - @msc_utils.register_tool - class Pruner(base_cls): - """Adaptive pruner for tvm""" - - @classmethod - def framework(cls): - return MSCFramework.TVM - - return Pruner - - -factory = TVMPrunerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.PRUNER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/__init__.py b/python/tvm/contrib/msc/framework/tvm/tools/quantize/__init__.py deleted file mode 100644 index dac2c82ed75d..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/tools/quantize/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm.tools.quantize""" - -from .quantizer import * -from .method import * diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py deleted file mode 100644 index f111573843bb..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py +++ /dev/null @@ -1,204 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.framework.tvm.tools.quantize.method""" - -from typing import Tuple - -import numpy as np - -import tvm -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.quantize import BaseQuantizer, QuantizeMethod -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.relax import op as relax_op - - -@msc_utils.register_tool_method -class TVMQuantizeMethod(QuantizeMethod): - """Default quantize method for tvm""" - - @classmethod - def get_quantize_cache( - cls, - quantizer: BaseQuantizer, - data: tvm.relax.Var, - name: str, - consumer: str, - scale: float, - axis: int = -1, - epsilon: float = 1.0 / (1 << 24), - ) -> Tuple[tvm.relax.Constant, tvm.relax.Constant]: - """Calibrate the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: tvm.relax.Var - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scale: float - The scale factor - axis: int - The axis. - epsilon: float - The epsilon for get scale. - - Returns - ------- - scale_tensor: tvm.relax.Constant - The scale_tensor. - zero_point: tvm.relax.Constant - The zero_point. - """ - - name_prefix = name if quantizer._cache_processed else quantizer.to_tensor_id(name, consumer) - scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor") - zero_point = quantizer._get_tensor_cache(name, consumer, "zero_point") - if scale_tensor is None: - scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon, expand_dims=False) - scale_tensor = 1 / scale_tensor - if isinstance(scale_tensor, float): - scale_tensor = np.array(scale_tensor) - scale_tensor = scale_tensor.astype(quantizer.find_tensor(name).dtype_name) - zero_point = np.zeros_like(scale_tensor).astype("int8") - scale_span = _ffi_api.SpanCreateWithAttr("name", name_prefix + "_scale") - scale_tensor = tvm.relax.Constant(tvm.runtime.tensor(scale_tensor), span=scale_span) - zp_span = _ffi_api.SpanCreateWithAttr("name", name_prefix + "_zero_point") - zero_point = tvm.relax.Constant(tvm.runtime.tensor(zero_point), span=zp_span) - quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor) - quantizer._save_tensor_cache(name, consumer, "zero_point", zero_point) - return scale_tensor, zero_point - - @classmethod - def quantize_normal( - cls, - quantizer: BaseQuantizer, - data: tvm.relax.Var, - name: str, - consumer: str, - scale: float, - nbits: int = 8, - axis: int = -1, - sign: bool = True, - rounding: str = "round", - epsilon: float = 1.0 / (1 << 24), - ) -> tvm.relax.Var: - """Calibrate the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: tvm.relax.Var - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scale: float - The scale factor - nbits: int - The number bits for quantize. - axis: int - The axis. - sign: bool - Whether to use sign. - rounding str - The rounding method. - epsilon: float - The epsilon for get scale. - - Returns - ------- - data: tvm.relax.Var - The processed tensor. - """ - - if nbits == 8: - dtype = "int8" - else: - raise TypeError("Unexpected nbits " + str(nbits)) - name_prefix = name if quantizer._cache_processed else quantizer.to_tensor_id(name, consumer) - scale_tensor, zero_point = cls.get_quantize_cache( - quantizer, data, name, consumer, scale, axis, epsilon - ) - expr = relax_op.quantize(data, scale_tensor, zero_point, axis, dtype) - return quantizer._block_builder.emit(expr, name_hint=name_prefix + "_quantize") - - @classmethod - def dequantize_normal( - cls, - quantizer: BaseQuantizer, - data: tvm.relax.Var, - name: str, - consumer: str, - scale: float = -1.0, - nbits: int = 8, - axis: int = -1, - sign: bool = True, - rounding: str = "round", - epsilon: float = 1.0 / (1 << 24), - ) -> tvm.relax.Var: - """Calibrate the data by kl_divergence - - Parameters - ---------- - quantizer: BaseQuantizer - The quantizer - data: np.ndarray - The source data. - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scale: float - The scale factor - nbits: int - The number bits for quantize. - axis: int - The axis. - sign: bool - Whether to use sign. - rounding str - The rounding method. - epsilon: float - The epsilon for get scale. - - Returns - ------- - data: array like - The processed tensor. - """ - - name_prefix = name if quantizer._cache_processed else quantizer.to_tensor_id(name, consumer) - scale_tensor, zero_point = cls.get_quantize_cache( - quantizer, data, name, consumer, scale, axis, epsilon - ) - expr = relax_op.dequantize( - data, scale_tensor, zero_point, axis, quantizer.find_tensor(name).dtype - ) - return quantizer._block_builder.emit(expr, name_hint=name_prefix + "_dequantize") - - @classmethod - def framework(cls): - return MSCFramework.TVM diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py deleted file mode 100644 index 703d8efb4b2b..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/tools/quantize/quantizer.py +++ /dev/null @@ -1,169 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -# ruff: noqa: RUF005 -"""tvm.contrib.msc.framework.tvm.tools.quantize.quantizer""" - -from typing import List, Union - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.quantize import BaseQuantizer -from tvm.contrib.msc.core.tools.tool import ToolStrategy, ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TVMQuantizerFactory: - """Quantizer factory for tvm""" - - def create(self, base_cls: BaseQuantizer) -> BaseQuantizer: - """Create adaptive quantizer - - Parameters - ---------- - base_cls: BaseQuantizer - The base quantizer class - - Returns - ------- - quantizer_cls: BaseQuantizer - The quantizer class. - """ - - @msc_utils.register_tool - class Quantizer(base_cls): - """Adaptive quantizer for tvm""" - - def _execute_before_build(self, block_builder: tvm.relax.BlockBuilder): - """Execute before model build - - Parameters - ---------- - block_builder: tvm.relax.BlockBuilder - The block builder. - """ - - self._block_builder = block_builder - self._gather_tensors, self._gather_names = {}, [] - super()._execute_before_build(block_builder) - - def _execute_after_build( - self, output: Union[tvm.relax.Var, List[tvm.relax.DataflowVar]] - ) -> List[tvm.relax.Var]: - """Execute after model build - - Parameters - ---------- - output: var or list - The output var of the model. - - Returns - ------- - outputs: list - The modified outputs var. - """ - - if self._calibrated: - return super()._execute_after_build(output) - self._gather_names = list(sorted(self._gather_tensors.keys())) - gather_tensors = [self._gather_tensors[o]["tensor"] for o in self._gather_names] - if isinstance(output, tvm.relax.Var): - return super()._execute_after_build([output] + gather_tensors) - return super()._execute_after_build(output + gather_tensors) - - def _execute_after_forward( - self, outputs: List[tvm.runtime.Tensor] - ) -> Union[tvm.runtime.Tensor, List[tvm.runtime.Tensor]]: - """Execute after model forward - - Parameters - ---------- - outputs: list - The output datas. - - Returns - ------- - output: np.ndarray or list - The modified output ndarray. - """ - - if self._calibrated: - return super()._execute_after_forward(outputs) - output_num = len(outputs) - len(self._gather_names) - for data, name in zip(outputs[output_num:], self._gather_names): - info = self._gather_tensors[name] - for consumer in info["consumers"]: - strategys = self._get_tensor_strategys(name, consumer) - self._gather_tensor(data, name, consumer, strategys) - if output_num == 1: - return super()._execute_after_forward(outputs[0]) - return super()._execute_after_forward(outputs[:output_num]) - - def _process_tensor( - self, - tensor: tvm.relax.DataflowVar, - name: str, - consumer: str, - scope: str, - strategys: List[ToolStrategy], - ) -> tvm.relax.DataflowVar: - """Process tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - if not self._calibrated: - if self.is_weight(name): - return self._gather_tensor(self.get_data(name), name, consumer, strategys) - if name not in self._gather_tensors: - self._gather_tensors[name] = { - "consumers": [consumer], - "tensor": tensor, - } - self._gather_names.append(name) - else: - self._gather_tensors[name]["consumers"].append(consumer) - return tensor - return self._quantize_tensor(tensor, name, consumer, strategys) - - @classmethod - def framework(cls): - return MSCFramework.TVM - - return Quantizer - - -factory = TVMQuantizerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/framework/tvm/tools/track/__init__.py b/python/tvm/contrib/msc/framework/tvm/tools/track/__init__.py deleted file mode 100644 index f99f09767dbf..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/tools/track/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.framework.tvm.tools.track""" - -from .tracker import * diff --git a/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py b/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py deleted file mode 100644 index 0e29ddcb93b3..000000000000 --- a/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py +++ /dev/null @@ -1,160 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -# ruff: noqa: RUF005 -"""tvm.contrib.msc.framework.tvm.tools.track.tracker""" - -from typing import List, Union - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools.tool import ToolStrategy, ToolType -from tvm.contrib.msc.core.tools.track import BaseTracker -from tvm.contrib.msc.core.utils.namespace import MSCFramework - - -class TVMTrackerFactory: - """Tracker factory for tvm""" - - def create(self, base_cls: BaseTracker) -> BaseTracker: - """Create adaptive tracker - - Parameters - ---------- - base_cls: BaseTracker - The base tracker class - - Returns - ------- - tracker_cls: BaseTracker - The tracker class. - """ - - @msc_utils.register_tool - class Tracker(base_cls): - """Adaptive tracker for tvm""" - - def _execute_before_build(self, block_builder: tvm.relax.BlockBuilder): - """Execute before model build - - Parameters - ---------- - block_builder: tvm.relax.BlockBuilder - The block builder. - """ - - self._block_builder = block_builder - self._track_tensors, self._track_names = {}, [] - super()._execute_before_build(block_builder) - - def _execute_after_build( - self, output: Union[tvm.relax.Var, List[tvm.relax.DataflowVar]] - ) -> List[tvm.relax.Var]: - """Execute after model build - - Parameters - ---------- - output: var or list - The output var of the model. - - Returns - ------- - outputs: list - The modified outputs var. - """ - - self._track_names = list(sorted(self._track_tensors.keys())) - track_tensors = [self._track_tensors[o]["tensor"] for o in self._track_names] - if isinstance(output, tvm.relax.Var): - return super()._execute_after_build([output] + track_tensors) - return super()._execute_after_build(output + track_tensors) - - def _execute_after_forward( - self, outputs: List[tvm.runtime.Tensor] - ) -> Union[tvm.runtime.Tensor, List[tvm.runtime.Tensor]]: - """Execute after model forward - - Parameters - ---------- - outputs: list - The output datas. - - Returns - ------- - output: np.ndarray or list - The modified output ndarray. - """ - - output_num = len(outputs) - len(self._track_names) - for data, name in zip(outputs[output_num:], self._track_names): - consumer = self._track_tensors[name]["consumer"] - strategys = self._get_tensor_strategys(name, consumer) - producer = self.find_producer(name) - if producer == "nn.batch_norm": - data = data[0] - self._track_tensor(data, name, consumer, strategys) - if output_num == 1: - return super()._execute_after_forward(outputs[0]) - return super()._execute_after_forward(outputs[:output_num]) - - def _process_tensor( - self, - tensor: tvm.relax.DataflowVar, - name: str, - consumer: str, - scope: str, - strategys: List[ToolStrategy], - ) -> tvm.relax.DataflowVar: - """Process tensor - - Parameters - ------- - tensor: Any - Tensor in framework - name: str - The name of the tensor. - consumer: str - The name of the consumer. - scope: str - The scope mark teacher| student| null. - strategys: list - The strategys for the tensor. - - Returns - ------- - tensor: Any - The processed tensor. - """ - - if self.is_weight(name): - self._track_tensor(self.get_data(name), name, consumer, strategys) - if name not in self._track_tensors: - self._track_tensors[name] = {"consumer": consumer, "tensor": tensor} - self._track_names.append(name) - return tensor - - @classmethod - def framework(cls): - return MSCFramework.TVM - - return Tracker - - -factory = TVMTrackerFactory() -tools = msc_utils.get_registered_tool(MSCFramework.MSC, ToolType.TRACKER, tool_style="all") -for tool in tools.values(): - factory.create(tool) diff --git a/python/tvm/contrib/msc/pipeline/__init__.py b/python/tvm/contrib/msc/pipeline/__init__.py deleted file mode 100644 index b27b09d5d764..000000000000 --- a/python/tvm/contrib/msc/pipeline/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.pipeline""" - -from .manager import * -from .wrapper import * diff --git a/python/tvm/contrib/msc/pipeline/dynamic.py b/python/tvm/contrib/msc/pipeline/dynamic.py deleted file mode 100644 index 9fe066de074b..000000000000 --- a/python/tvm/contrib/msc/pipeline/dynamic.py +++ /dev/null @@ -1,493 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.pipeline.dynamic""" - -from typing import Any, List, Optional, Tuple - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.runtime import BaseJIT -from tvm.contrib.msc.core.utils.message import MSCStage - -from .pipeline import BasePipeline -from .worker import MSCPipeWorker - - -class MSCDynamic(BasePipeline): - """Dynamic of Pipeline, process dynamic model""" - - def setup(self) -> dict: - """Setup the pipeline - - Returns - ------- - info: dict - The setup info. - """ - - self._jit, self._jit_caches = None, {} - self._worker_ctxs = {} - return super().setup() - - def change_stage(self, stage: str, log_stage: bool = True) -> str: - """Change stage - - Parameters - ---------- - stage: str - The stage name. - log_stage: bool - Whether to log the stage. - - Returns - ------- - stage: str - The stage name. - """ - - self._jit_caches = {} - return super().change_stage(stage, log_stage) - - def _prepare(self, data_loader: Any) -> Tuple[dict, dict]: - """Prepare datas for the pipeline. - - Parameters - ---------- - data_loader: - The data loader. - - Returns - ------- - info: dict - The info of prepare. - report: dict - The report of prepare. - """ - - hooks = {"pre_forward": [self.pre_forward], "post_forward": [self.post_forward]} - if isinstance(self._model, dict) and "model" in self._model: - worker_models = self._model["worker_models"] - self._model, device, training = self.jit_cls.load_native( - self._model["model"], self._config - ) - else: - worker_models = {} - self._model, device, training = self.jit_cls.load_native(self._model, self._config) - self._jit = self.jit_cls( - self._model, - inputs=[i[0] for i in self._config["inputs"]], - outputs=self._config["outputs"], - device=device, - training=training, - hooks=hooks, - logger=self._logger, - ) - self._jit.build() - assert MSCStage.PREPARE in self._config["dataset"], "prepare dataset is needed" - cnt, max_golden = 0, self._config["dataset"][MSCStage.PREPARE].get("max_golden", 5) - for inputs in data_loader(): - if cnt >= max_golden > 0: - break - self._jit.run(inputs) - cnt += 1 - - # create workers - def _get_worker_config(name: str, cache: dict): - saver = cache.get("saver") - assert saver, "Failed to record datas for " + name - saver.finalize() - - def _to_input(i_name): - i_info = saver.info["inputs"][i_name] - return (i_name, i_info["shape"], i_info["dtype"]) - - w_config = msc_utils.copy_dict(self._config) - w_config.update( - { - "inputs": [_to_input(i) for i in saver.info["input_names"]], - "outputs": saver.info["output_names"], - } - ) - w_config["dataset"]["golden"] = {"loader": saver.folder} - for tool in w_config.get("tools", []): - worker_config = tool.get("worker_configs", {}).get(name) - if worker_config: - tool["tool_config"] = msc_utils.update_dict(tool["tool_config"], worker_config) - return w_config - - info, report = {}, {} - for name, cache in self._jit_caches.items(): - runner_ctx = self._jit.get_runner_ctx(name) - w_model = worker_models.get(name, runner_ctx["model"]) - self._worker_ctxs[name] = { - "worker": self.create_worker(w_model, name, _get_worker_config(name, cache)), - "workspace": self._workspace.create_dir(name), - } - with msc_utils.change_workspace(self._worker_ctxs[name]["workspace"]): - info[name], report[name] = self._worker_ctxs[name]["worker"].prepare() - return info, report - - def _parse(self) -> Tuple[dict, dict]: - """Parse relax module for the pipeline. - - Returns - ------- - info: dict - The info of parse. - report: dict - The report of parse. - """ - - info, report = {}, {} - for name, w_ctx in self._worker_ctxs.items(): - with msc_utils.change_workspace(w_ctx["workspace"]): - info[name], report[name] = w_ctx["worker"].parse() - return info, report - - def _tool_applied(self, tool_type: str) -> bool: - """Check if the tool is applied - - Parameters - ---------- - tool_type: str - The tool type. - - Returns - ------- - applied: bool - Whether the tool is applied. - """ - - return all(w["worker"].tool_applied(tool_type) for w in self._worker_ctxs.values()) - - def _apply_tool( - self, tool_type: str, knowledge: Optional[dict] = None, data_loader: Any = None - ) -> Tuple[dict, dict]: - """Apply tool with runner - - Parameters - ---------- - tool_type: str - The tool type to apply. - knowledge: dict - The pre knowledge. - data_loader: - The data loader. - - Returns - ------- - info: dict - The info of apply tool. - report: dict - The report of apply tool. - """ - - if knowledge: - raise NotImplementedError("Apply tool with knowledge is not supported") - - self._jit.make_plan(tool_type, data_loader) - info, report = {}, {} - for name, w_ctx in self._worker_ctxs.items(): - with msc_utils.change_workspace(w_ctx["workspace"]): - info[name], report[name] = w_ctx["worker"].apply_tool(tool_type) - return info, report - - def _create_runtime( - self, - stage: str, - tools: Optional[List[str]] = None, - run_type: Optional[str] = None, - run_config: Optional[dict] = None, - visualize: bool = True, - profile: bool = True, - use_cache: bool = True, - ) -> Tuple[dict, dict]: - """Create runtime. - - Parameters - ---------- - stage: str - The pipeline stage. - tools: list - The tools to apply. - run_type: str - The type of runner. - run_config: dict - The config of runner. - visualize: bool - Whether to visualize the runner - profile: bool - Whether to profile the runner. - use_cache: bool - Whether to use cache. - - Returns - ------- - info: dict - The info of stage. - report: dict - The report of stage. - """ - - info, report = {}, {} - for name, w_ctx in self._worker_ctxs.items(): - with msc_utils.change_workspace(w_ctx["workspace"]): - info[name], report[name] = w_ctx["worker"].create_runner( - stage, tools, run_type, run_config, visualize, profile, use_cache - ) - self._jit.set_runner(name, w_ctx["worker"].runner) - return info, report - - def _export_model(self, stage: str, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: - """Export the model - - Parameters - ---------- - stage: str - The pipeline stage. - folder: MSCDirectory - The export folder. - dump: bool - Whether to dump info. - - Returns - ------- - exported: - The exported model. - """ - - if dump: - model = self.jit_cls.dump_nativate(self._model, folder, self._config[MSCStage.EXPORT]) - else: - model = self._model - worker_models = { - n: w["worker"].export_model(stage, folder.create_dir(n), dump) - for n, w in self._worker_ctxs.items() - } - return {"model": model, "worker_models": worker_models} - - def _export_tool(self, tool_type: str, folder: msc_utils.MSCDirectory) -> dict: - """Export the tool - - Parameters - ---------- - tool_type: str - The tool type. - folder: MSCDirectory - The export folder. - - Returns - ------- - configs: dict - The exported tool configs. - """ - - configs = {} - for name, w_ctx in self._worker_ctxs.items(): - with msc_utils.change_workspace(w_ctx["workspace"]): - configs[name] = w_ctx["worker"].export_tool(tool_type, folder.create_dir(name)) - assert tool_type in self._tools_config, "Can not find tool_type " + str(tool_type) - return msc_utils.update_dict(self._tools_config[tool_type], {"worker_configs": configs}) - - def _export_info(self, stage: str, folder: msc_utils.MSCDirectory) -> dict: - """Export the info of pipeline - - Parameters - ---------- - stage: str - The pipeline stage. - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The info. - """ - - info = super()._export_info(stage, folder) - if stage in (MSCStage.OPTIMIZE, MSCStage.COMPILE): - info["worker_infos"] = {} - for name, w_ctx in self._worker_ctxs.items(): - with msc_utils.change_workspace(w_ctx["workspace"]): - info["worker_infos"][name] = w_ctx["worker"].export_info( - stage, folder.create_dir(name) - ) - return info - - def _destory(self): - """Destory the pipeline""" - - for w_ctx in self._worker_ctxs.values(): - w_ctx["worker"].destory() - - def get_runtime(self, ret_type: str = "runner") -> Any: - """Get the runtime of pipeline - - Parameters - ---------- - ret_type: str - The return type runner| runnable| model. - - Returns - ------- - runnable: - The runnable object. - """ - - if ret_type == "runner": - return self._jit - if ret_type in ("model", "runnable"): - return self._jit.jit_model - raise TypeError("Unexpect return type " + str(ret_type)) - - def pre_forward(self, runner_name: str, inputs: List[Tuple[str, Any]]) -> Any: - """pre forward hook for jit model - - Parameters - ---------- - runner_name: str - The runner name. - inputs: - The msc format inputs. - """ - - if self._current_stage == MSCStage.PREPARE: - cache = self._jit_caches.setdefault(runner_name, {}) - cache["inputs"] = inputs - self._pre_forward(runner_name, inputs) - - def _pre_forward(self, runner_name: str, inputs: List[Tuple[str, Any]]) -> Any: - """pre forward hook for jit model - - Parameters - ---------- - runner_name: str - The runner name. - inputs: - The msc format inputs. - """ - - return None - - def post_forward( - self, runner_name: str, outputs: List[Tuple[str, Any]] - ) -> List[Tuple[str, Any]]: - """pre forward hook for jit model - - Parameters - ---------- - runner_name: str - The runner name. - outputs: - The outputs. - - Returns - ------- - outputs: - The outputs. - """ - - if self._current_stage == MSCStage.PREPARE: - cache = self._jit_caches[runner_name] - assert "inputs" in cache, "Failed to record inputs" - if "saver" not in cache: - golden = ( - msc_utils.get_dataset_dir().create_dir(runner_name).relpath("Golden", False) - ) - saver_options = { - "input_names": [i[0] for i in cache["inputs"]], - "output_names": [o[0] for o in outputs], - } - cache["saver"] = msc_utils.IODataSaver(golden, saver_options) - cache["saver"].save_batch([i[1] for i in cache["inputs"]], [o[1] for o in outputs]) - return self._post_forward(runner_name, outputs) - - def _post_forward( - self, runner_name: str, outputs: List[Tuple[str, Any]] - ) -> List[Tuple[str, Any]]: - """pre forward hook for jit model - - Parameters - ---------- - runner_name: str - The runner name. - outputs: - The outputs. - - Returns - ------- - outputs: - The outputs. - """ - - return outputs - - def _record_stage(self, stage: str, info: Optional[dict] = None, report: Optional[dict] = None): - """Record the stage - - Parameters - ------- - stage: str - The compile stage - info: dict - The info of stage. - report: dict - The report of stage. - """ - - stage_report = {} - for name, w_report in report.items(): - for k, v in w_report.items(): - stage_report.setdefault(k, {})[name] = v - info = {k: v for k, v in info.items() if v} - super()._record_stage(stage, info, stage_report) - - def pipe_mark(self, msg: Any) -> str: - """Mark the message with pipeline info - - Parameters - ------- - msg: str - The message - - Returns - ------- - msg: str - The message with mark. - """ - - return "DYNAMIC " + str(msg) - - @property - def jit_cls(self): - return BaseJIT - - @property - def worker_cls(self): - return MSCPipeWorker - - -class TorchDynamic(MSCDynamic): - """Dynamic of Pipeline, process torch dynamo""" - - @property - def jit_cls(self): - # pylint: disable=import-outside-toplevel - from tvm.contrib.msc.framework.torch.runtime import TorchJIT - - return TorchJIT diff --git a/python/tvm/contrib/msc/pipeline/manager.py b/python/tvm/contrib/msc/pipeline/manager.py deleted file mode 100644 index b997f124b246..000000000000 --- a/python/tvm/contrib/msc/pipeline/manager.py +++ /dev/null @@ -1,288 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.pipeline.manager""" - -from typing import Any, List, Optional, Tuple - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.gym.control import create_controller -from tvm.contrib.msc.core.utils.message import MSCStage - -from .pipeline import BasePipeline -from .worker import MSCPipeWorker - - -class MSCManager(BasePipeline): - """Manager of Pipeline, process static model""" - - def setup(self) -> dict: - """Setup the pipeline - - Returns - ------- - info: dict - The setup info. - """ - - self._worker = self.create_worker(self._model, "main") - self._config = self._worker._config - return super().setup() - - def _prepare(self, data_loader: Any) -> Tuple[dict, dict]: - """Prepare datas for the pipeline. - - Parameters - ---------- - data_loader: - The data loader. - - Returns - ------- - info: dict - The info of prepare. - report: dict - The report of prepare. - """ - - return self._worker.prepare(data_loader) - - def _parse(self) -> Tuple[dict, dict]: - """Parse relax module for the pipeline. - - Returns - ------- - info: dict - The info of parse. - report: dict - The report of parse. - """ - - return self._worker.parse() - - def _tool_applied(self, tool_type: str) -> bool: - """Check if the tool is applied - - Parameters - ---------- - tool_type: str - The tool type. - - Returns - ------- - applied: bool - Whether the tool is applied. - """ - - return self._worker.tool_applied(tool_type) - - def _apply_tool( - self, tool_type: str, knowledge: Optional[dict] = None, data_loader: Any = None - ) -> Tuple[dict, dict]: - """Apply tool with runner - - Parameters - ---------- - tool_type: str - The tool type to apply. - knowledge: dict - The pre knowledge. - data_loader: - The data loader. - - Returns - ------- - info: dict - The info of apply tool. - report: dict - The report of apply tool. - """ - - return self._worker.apply_tool(tool_type, knowledge, data_loader) - - def _create_runtime( - self, - stage: str, - tools: Optional[List[str]] = None, - run_type: Optional[str] = None, - run_config: Optional[dict] = None, - visualize: bool = True, - profile: bool = True, - use_cache: bool = True, - ) -> Tuple[dict, dict]: - """Create runtime. - - Parameters - ---------- - stage: str - The pipeline stage. - tools: list - The tools to apply. - run_type: str - The type of runner. - run_config: dict - The config of runner. - visualize: bool - Whether to visualize the runner - profile: bool - Whether to profile the runner. - use_cache: bool - Whether to use cache. - - Returns - ------- - info: dict - The info of stage. - report: dict - The report of stage. - """ - - return self._worker.create_runner( - stage, tools, run_type, run_config, visualize, profile, use_cache - ) - - def _run_gym(self, stage: str, config: dict, knowledge: dict, data_loader: Any) -> dict: - """Run gym. - - Parameters - ---------- - stage: str - The pipeline stage. - config: dict - The gym config. - knowledge: dict - The pre knowledge. - data_loader: - The data loader. - - Returns - ------- - knowledge: dict - The learned knowledge. - """ - - extra_config = { - "env": { - "runner": self._worker.runner, - "data_loader": data_loader, - "knowledge": knowledge, - }, - "verbose": self._verbose, - } - controller = create_controller(stage, config, extra_config) - return controller.run() - - def _export_model(self, stage: str, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: - """Export the model - - Parameters - ---------- - stage: str - The pipeline stage. - folder: MSCDirectory - The export folder. - dump: bool - Whether to dump info. - - Returns - ------- - exported: - The exported model. - """ - - return self._worker.export_model(stage, folder, dump) - - def _export_tool(self, tool_type: str, folder: msc_utils.MSCDirectory) -> dict: - """Export the tool - - Parameters - ---------- - tool_type: str - The tool type. - folder: MSCDirectory - The export folder. - - Returns - ------- - config: dict - The exported tool config. - """ - - assert tool_type in self._tools_config, "Can not find tool_type " + str(tool_type) - exp_config = {"tool_config": self._worker.export_tool(tool_type, folder)} - return msc_utils.update_dict(self._tools_config[tool_type], exp_config) - - def _export_info(self, stage: str, folder: msc_utils.MSCDirectory) -> dict: - """Export the info of pipeline - - Parameters - ---------- - stage: str - The pipeline stage. - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The info. - """ - - info = super()._export_info(stage, folder) - if stage in (MSCStage.OPTIMIZE, MSCStage.COMPILE): - info.update(self._worker.export_info(stage, folder)) - return info - - def _destory(self): - """Destory the pipeline""" - - self._worker.destory() - - def get_runtime(self, ret_type: str = "runner") -> Any: - """Get the runtime of pipeline - - Parameters - ---------- - ret_type: str - The return type runner| runnable| model. - - Returns - ------- - runnable: - The runnable object. - """ - - return self._worker.get_runnable(ret_type) - - def pipe_mark(self, msg: Any) -> str: - """Mark the message with pipeline info - - Parameters - ------- - msg: str - The message - - Returns - ------- - msg: str - The message with mark. - """ - - return "MANAGER " + str(msg) - - @property - def worker_cls(self): - return MSCPipeWorker diff --git a/python/tvm/contrib/msc/pipeline/pipeline.py b/python/tvm/contrib/msc/pipeline/pipeline.py deleted file mode 100644 index 132a2d51d75d..000000000000 --- a/python/tvm/contrib/msc/pipeline/pipeline.py +++ /dev/null @@ -1,854 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -"""tvm.contrib.msc.pipeline.pipeline""" - -import json -import os -import traceback -from typing import Any, List, Optional, Tuple, Union - -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools import BaseTool, get_tool_cls -from tvm.contrib.msc.core.utils.message import MSCStage -from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCKey, MSCMap -from tvm.contrib.msc.plugin.utils import export_plugins, load_plugins - -from .utils import get_tool_stage, map_tools, support_tool -from .worker import BasePipeWorker - - -class BasePipeline: - """Base Pipeline of MSC - - Parameters - ---------- - model: Any - The raw model in framwork. - config: dict - The config for pipeline. - plugins: dict - The plugins for pipeline. - run_optimize: bool - Whether to run optimize. - run_compile: bool - Whether to run compile. - root: str - The root path for files. - """ - - def __init__( - self, - model: Any, - config: dict, - plugins: Optional[dict] = None, - run_optimize: bool = True, - run_compile: bool = True, - root: Optional[str] = None, - ): - # change path to root path - if root: - - def _from_root_mark(val): - if isinstance(val, str) and MSCKey.ROOT_MARK in val: - return val.replace(MSCKey.ROOT_MARK, root) - return val - - if isinstance(model, dict): - model = msc_utils.map_dict(model, _from_root_mark) - elif isinstance(model, str): - model = _from_root_mark(model) - config = msc_utils.map_dict(config, _from_root_mark) - plugins = msc_utils.map_dict(plugins, _from_root_mark) - - MSCMap.reset() - self._model, self._meta_config = model, config - self._config = msc_utils.copy_dict(config) - if not run_optimize and MSCStage.OPTIMIZE in self._config: - self._config.pop(MSCStage.OPTIMIZE) - if not run_compile and MSCStage.COMPILE in self._config: - self._config.pop(MSCStage.COMPILE) - for stage in [MSCStage.PREPARE, MSCStage.PARSE, MSCStage.EXPORT]: - self._config.setdefault(stage, {}) - self._verbose = self._config.get("verbose", "info") - use_cache = self._config.get("use_cache", True) - if "workspace" in self._config: - self._workspace = msc_utils.set_workspace(self._config.pop("workspace"), use_cache) - else: - self._workspace = msc_utils.set_workspace("msc_workspace", use_cache) - if "logger" in self._config: - self._logger = self._config.pop("logger") - MSCMap.set(MSCKey.GLOBALE_LOGGER, self._logger) - else: - if "log_file" in self._config: - log_file = self._config.pop("log_file") - else: - log_file = self._workspace.relpath("MSC_LOG", keep_history=False) - self._logger = msc_utils.set_global_logger(self._verbose, log_file) - self._plugins = load_plugins(plugins) if plugins else {} - self.change_stage(MSCStage.SETUP) - self._logger.info(msc_utils.msg_block(self.pipe_mark("SETUP"), self.setup())) - - def setup(self) -> dict: - """Setup the pipeline - - Returns - ------- - info: dict - The setup info. - """ - - # define run type - self._model_type = self._config["model_type"] - self._optimize_type = self._config.get(MSCStage.OPTIMIZE, {}).get( - "run_type", self._model_type - ) - self._compile_type = self._config.get(MSCStage.COMPILE, {}).get( - "run_type", self._model_type - ) - self._optimized, self._compiled = False, False - - # map tools - self._tools_config = map_tools(self._config.get("tools", [])) - - # register plugins - if self._plugins: - for t in [self._model_type, self._optimize_type, self._compile_type]: - assert t in self._plugins, f"Missing plugin for {t}" - for name, plugin in self._plugins[self._model_type].get_ops_info().items(): - _ffi_api.RegisterPlugin(name, msc_utils.dump_dict(plugin)) - - # status - self._current_stage = None - self._report = { - "success": False, - "info": {}, - "duration": {}, - } - return { - "workspace": self._workspace.path, - "log_file": msc_utils.get_log_file(self._logger), - "verbose": self._verbose, - "plugins": self._plugins, - "config": self._config, - } - - def run_pipe(self) -> dict: - """Run the pipeline and return object. - - Returns - ------- - report: - The pipeline report. - """ - - err_msg, err_info = None, None - try: - self.prepare() - self.parse() - if MSCStage.BASELINE in self._config: - self.baseline() - if MSCStage.OPTIMIZE in self._config: - self.optimize() - if MSCStage.COMPILE in self._config: - self.compile() - except Exception as exc: # pylint: disable=broad-exception-caught - err_msg = "Pipeline failed: " + str(exc) - err_info = traceback.format_exc() - self.summary(err_msg, err_info) - self._logger.info(msc_utils.msg_block(self.pipe_mark("SUMMARY"), self._report, 0)) - self._workspace.finalize() - return self._report - - def change_stage(self, stage: str, log_stage: bool = True) -> str: - """Change stage - - Parameters - ---------- - stage: str - The stage name. - log_stage: bool - Whether to log the stage. - - Returns - ------- - stage: str - The stage name. - """ - - self._current_stage = stage - msc_utils.time_stamp(stage, log_stage) - return stage - - def prepare(self): - """Prepare datas for the pipeline.""" - - self.change_stage(MSCStage.PREPARE) - info, report = self._prepare(self._get_loader(MSCStage.PREPARE)) - self._record_stage(MSCStage.PREPARE, info, report) - - def _prepare(self, data_loader: Any) -> Tuple[dict, dict]: - """Prepare datas for the pipeline. - - Parameters - ---------- - data_loader: - The data loader. - - Returns - ------- - info: dict - The info of prepare. - report: dict - The report of prepare. - """ - - raise NotImplementedError("_prepare is not implemented in " + str(self.__class__)) - - def parse(self): - """Parse relax module for the pipeline.""" - - self.change_stage(MSCStage.PARSE) - info, report = self._parse() - self._record_stage(MSCStage.PARSE, info, report) - - def _parse(self) -> Tuple[dict, dict]: - """Parse relax module for the pipeline. - - Returns - ------- - info: dict - The info of parse. - report: dict - The report of parse. - """ - - raise NotImplementedError("_parse is not implemented in " + str(self.__class__)) - - def baseline(self): - """Run the baseline.""" - - self._run_stage(MSCStage.BASELINE) - - def optimize(self) -> Tuple[dict, dict]: - """Run the optimize. - - Returns - ------- - info: dict - The info of stage. - report: dict - The report of stage. - """ - - self._run_stage(MSCStage.OPTIMIZE) - self._optimized = True - - def compile(self) -> Tuple[dict, dict]: - """Run the compile. - - Returns - ------- - info: dict - The info of stage. - report: dict - The report of stage. - """ - - self._run_stage(MSCStage.COMPILE) - self._compiled = True - - def _run_stage(self, stage: str) -> Tuple[dict, dict]: - """Run the stage. - - Parameters - ---------- - stage: str - The pipeline stage. - - Returns - ------- - info: dict - The info of stage. - report: dict - The report of stage. - """ - - self.change_stage(stage) - tools = [] - for tool in self._config.get("tools", []): - run_type = tool.get("run_type", self._config[stage]["run_type"]) - if not support_tool(tool, stage, run_type): - continue - tools.append(tool["tool_type"]) - tool_cls, tool_stage = ( - self.get_tool_cls(tool, run_type), - get_tool_stage(tool["tool_type"]), - ) - t_stage = self.change_stage(stage + "." + tool_stage) - if self._tool_applied(tool["tool_type"]): - if tool_cls.apply_once(): - msg = "Remove apply once tool " + str(tool["tool_type"]) - self._logger.info(self.pipe_mark(msg)) - tools = tools[:-1] - else: - self._logger.info(self.pipe_mark("Apply planed tool " + str(tool["tool_type"]))) - continue - self.change_stage(t_stage + ".build", False) - info, report = self._create_runtime( - t_stage, tools, run_type=run_type, visualize=False, profile=False, use_cache=False - ) - self._record_stage(t_stage, info, report) - knowledge, loader = None, self._get_loader(tool_stage) - if "gym_configs" in tool: - for idx, config in enumerate(tool["gym_configs"]): - knowledge_file = self._workspace.create_dir("Gym").relpath( - f"knowledge_{idx}.json" - ) - gym_mark = "GYM[{}/{}]({} @ {}) ".format( - idx, len(tool["gym_configs"]), self._config[stage]["run_type"], tool_stage - ) - if os.path.isfile(knowledge_file): - knowledge = knowledge_file - msg = f"{gym_mark}Load from {knowledge}" - self._logger.info(self.pipe_mark(msg)) - else: - self.change_stage(tool_stage + f".gym_{idx}") - self._logger.info(self.pipe_mark(gym_mark + "Start search")) - knowledge = self._run_gym(tool_stage, config, knowledge, loader) - msc_utils.save_dict(knowledge, knowledge_file) - knowledge = msc_utils.load_dict(knowledge) - self.change_stage(t_stage + ".apply", False) - info, report = self._apply_tool(tool["tool_type"], knowledge, loader) - self._record_stage(t_stage, info, report) - if tool_cls.apply_once(): - msg = "Remove apply once tool " + str(tool["tool_type"]) - self._logger.info(self.pipe_mark(msg)) - tools = tools[:-1] - self.change_stage(stage + ".build", False) - info, report = self._create_runtime(stage, tools) - self._record_stage(stage, info, report) - - def _tool_applied(self, tool_type: str) -> bool: - """Check if the tool is applied - - Parameters - ---------- - tool_type: str - The tool type. - - Returns - ------- - applied: bool - Whether the tool is applied. - """ - - return False - - def _apply_tool( - self, tool_type: str, knowledge: Optional[dict] = None, data_loader: Any = None - ) -> Tuple[dict, dict]: - """Apply tool with runner - - Parameters - ---------- - tool_type: str - The tool type to apply. - knowledge: dict - The pre knowledge. - data_loader: - The data loader. - - Returns - ------- - info: dict - The info of apply tool. - report: dict - The report of apply tool. - """ - - raise NotImplementedError("_apply_tool is not implemented in " + str(self.__class__)) - - def _create_runtime( - self, - stage: str, - tools: Optional[List[str]] = None, - run_type: Optional[str] = None, - run_config: Optional[dict] = None, - visualize: bool = True, - profile: bool = True, - use_cache: bool = True, - ) -> Tuple[dict, dict]: - """Create runtime. - - Parameters - ---------- - stage: str - The pipeline stage. - tools: list - The tools to apply. - run_type: str - The type of runner. - run_config: dict - The config of runner. - visualize: bool - Whether to visualize the runner - profile: bool - Whether to profile the runner. - use_cache: bool - Whether to use cache. - - Returns - ------- - info: dict - The info of stage. - report: dict - The report of stage. - """ - - raise NotImplementedError("_create_runtime is not implemented in " + str(self.__class__)) - - def _run_gym(self, stage: str, config: dict, knowledge: dict, data_loader: Any) -> dict: - """Run gym. - - Parameters - ---------- - stage: str - The pipeline stage. - config: dict - The gym config. - knowledge: dict - The pre knowledge. - data_loader: - The data loader. - - Returns - ------- - knowledge: dict - The learned knowledge. - """ - - raise NotImplementedError("_run_gym is not implemented in " + str(self.__class__)) - - def summary(self, err_msg: Optional[str] = None, err_info: Optional[str] = None) -> dict: - """Summary the pipeline. - - Parameters - ---------- - err_msg: str - The error message. - err_info: str - The error info. - - Returns - ------- - report: dict - The report of the pipeline. - """ - - self.change_stage(MSCStage.SUMMARY, False) - if err_msg: - self._report.update({"success": False, "err_msg": err_msg, "err_info": err_info}) - else: - self._report["success"] = True - self._report["duration"] = msc_utils.get_duration() - return self._report - - def export(self, path: Optional[str] = None, dump: bool = True) -> Union[str, dict]: - """Export the pipeline - - Parameters - ---------- - path: str - The export path. - dump: bool - Whether to dump the info. - - Returns - ------- - export_path/pipeline: str/dict - The exported path/pipeline info. - """ - - path = path or "msc_export" - if path.endswith(".tar.gz"): - folder, dump = msc_utils.msc_dir(path.replace(".tar.gz", ""), keep_history=False), True - else: - folder = msc_utils.msc_dir(path, keep_history=False) - - if self._compiled: - stage = MSCStage.COMPILE - elif self._optimized: - stage = MSCStage.OPTIMIZE - else: - stage = MSCStage.SETUP - - def _to_root_mark(val): - if isinstance(val, str) and folder.path != val and folder.path in val: - return val.replace(folder.path, MSCKey.ROOT_MARK) - return val - - def _export_plugins(folder: msc_utils.MSCDirectory): - if self._compiled: - if dump and self.compile_type in self._plugins: - return self._plugins[self.compile_type].copy_libs(folder) - return self._plugins.get(self.compile_type) - if dump: - return export_plugins(self._plugins, folder) - return self._plugins - - export = { - "logger": folder.copy(msc_utils.get_log_file(self._logger)), - "report": self._report, - "info": self._export_info(stage, folder.create_dir("info")), - "model": self._export_model(stage, folder.create_dir("model"), dump), - "plugins": _export_plugins(folder.create_dir("plugins")), - } - if self._compiled: - # save golden - num_golden = self._config[MSCStage.EXPORT].get("num_golden", 5) - if num_golden > 0: - saver_options = { - "input_names": [i[0] for i in self._config["inputs"]], - "output_names": self._config["outputs"], - } - batch_cnt, export["golden"] = 0, folder.create_dir("golden").path - with msc_utils.IODataSaver(export["golden"], saver_options) as saver: - for inputs in self._get_loader()(): - if batch_cnt >= num_golden: - break - batch_cnt = saver.save_batch(inputs, self.get_runtime().run(inputs)) - else: - export["config"] = self.export_config(folder, dump) - export = msc_utils.map_dict(export, _to_root_mark) - if not dump: - return export - with open(folder.relpath("export.json"), "w") as f: - f.write(json.dumps(export, indent=2)) - folder.finalize() - if path.endswith(".tar.gz"): - msc_utils.pack_folder(path.replace(".tar.gz", ""), "tar.gz") - return path - - def export_config(self, folder: msc_utils.MSCDirectory, dump: bool = True) -> dict: - """Export the config - - Parameters - ---------- - folder: MSCDirectory - The export folder. - dump: bool - Whether to dump info. - - Returns - ------- - config: dict - The updated config. - """ - - # dump the dataloader - def _export_dataset(name, info, dump: bool): - loader, max_batch = info["loader"], info.get("max_batch", -1) - data_folder = folder.create_dir("dataset") - if isinstance(loader, str) and msc_utils.is_callable(loader): - path, func_name = loader.split(":") - exp_loader = data_folder.copy(path) + ":" + func_name - elif msc_utils.is_io_dataset(loader): - exp_loader = data_folder.copy(loader, name) - elif callable(loader) and dump: - saver_options = {"input_names": [i[0] for i in self._config["inputs"]]} - batch_cnt, exp_loader = 0, data_folder.create_dir(name).path - with msc_utils.IODataSaver(exp_loader, saver_options) as saver: - for inputs in loader(): - if batch_cnt >= max_batch > 0: - break - batch_cnt = saver.save_batch(inputs) - else: - exp_loader = loader - return {"loader": exp_loader, "max_batch": max_batch} - - config = msc_utils.copy_dict(self._meta_config) - config["dataset"] = { - k: _export_dataset(k, v, dump) for k, v in self._config["dataset"].items() - } - if self._optimized: - config["model_type"] = MSCFramework.TVM - for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE]: - if stage in config: - config.pop(stage) - if "profile" in config[MSCStage.COMPILE] and self.get_runtime().trained: - config[MSCStage.COMPILE]["profile"].setdefault("check", {})["err_rate"] = -1 - config["tools"] = [] - for tool in self._config.get("tools", []): - tool_type = tool["tool_type"] - skip_msg = "Skip export tool " + tool_type - if not support_tool(tool, MSCStage.COMPILE, self._compile_type): - self._logger.info(self.pipe_mark(skip_msg + "(unsupported)")) - continue - tool_cls = self.get_tool_cls(tool, self._optimize_type) - if not tool_cls.exportable(): - self._logger.info(self.pipe_mark(skip_msg + "(unexportable)")) - continue - config["tools"].append(self._export_tool(tool_type, folder)) - # remove not serializable items - if dump: - remove_keys = {"workspace", "logger"} - config = {k: v for k, v in config.items() if k not in remove_keys} - return config - - def _export_model(self, stage: str, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: - """Export the model - - Parameters - ---------- - stage: str - The pipeline stage. - folder: MSCDirectory - The export folder. - dump: bool - Whether to dump info. - - Returns - ------- - exported: - The exported model. - """ - - raise NotImplementedError("_export_model is not implemented in " + str(self.__class__)) - - def _export_tool(self, tool_type: str, folder: msc_utils.MSCDirectory) -> dict: - """Export the tool - - Parameters - ---------- - tool_type: str - The tool type. - folder: MSCDirectory - The export folder. - - Returns - ------- - tool: dict - The exported tool. - """ - - raise NotImplementedError("_export_tool is not implemented in " + str(self.__class__)) - - def _export_info(self, stage: str, folder: msc_utils.MSCDirectory) -> dict: - """Export the info of pipeline - - Parameters - ---------- - stage: str - The pipeline stage. - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The info. - """ - - return {} - - def _get_loader(self, name: str = MSCStage.PREPARE) -> Any: - """Get the data loader""" - - config = self._config["dataset"].get(name, self._config["dataset"][MSCStage.PREPARE]) - source_loader = config.get("loader") - assert source_loader, "Dataset loader should be given for msc pipeline" - if source_loader == "from_random": - max_batch = config.get("max_batch", 5) - - def get_random(): - def _to_data(inp): - shape = [1 if isinstance(d, str) else d for d in inp[1]] - return msc_utils.random_data([shape, inp[2]]) - - for _ in range(max_batch): - yield {i[0]: _to_data(i) for i in self._config["inputs"]} - - loader, source_type = get_random, "random" - elif isinstance(source_loader, dict): - - def load_data(): - return [source_loader] - - loader, source_type = load_data, "dict" - elif msc_utils.is_io_dataset(source_loader): - max_batch = config.get("max_batch", -1) - - def load_datas(): - for inputs, _ in msc_utils.IODataLoader(source_loader, end=max_batch): - yield inputs - - loader, source_type = load_datas, "io_data" - elif callable(source_loader): - max_batch = config.get("max_batch", -1) - load_kwargs = config.get("load_kwargs", {}) - if max_batch == -1 and not load_kwargs: - loader, source_type = source_loader, "custom" - else: - - def get_source(): - for idx, inputs in enumerate(source_loader(**load_kwargs)): - if idx >= max_batch > 0: - break - yield inputs - - loader, source_type = get_source, "loaded_custom" - else: - raise TypeError(f"Unexpected source loader {source_loader}({type(source_loader)})") - msg = f"Create data loader({name}) {loader.__name__}({source_type})" - self._logger.debug(self.pipe_mark(msg)) - return loader - - def _record_stage(self, stage: str, info: Optional[dict] = None, report: Optional[dict] = None): - """Record the stage - - Parameters - ------- - stage: str - The compile stage - info: dict - The info of stage. - report: dict - The report of stage. - """ - - if info: - self._logger.info(msc_utils.msg_block(self.pipe_mark(stage.upper()), info)) - if report: - self._report["info"].setdefault(stage, {}).update(report) - - def destory(self, keep_workspace: bool = False): - """Destroy the pipeline - - Parameters - ---------- - keep_workspace: bool - Whether to keep workspace. - """ - - self._destory() - if not keep_workspace: - self._workspace.destory() - msc_utils.remove_loggers() - - def _destory(self): - """Destroy the pipeline.""" - - raise NotImplementedError("_destory is not implemented in " + str(self.__class__)) - - def get_tool_cls(self, tool: dict, framework: str) -> BaseTool: - """Get the tool class from tool config - - Parameters - ---------- - tool: dict - The tool config. - framework: str - The framework. - - Returns - ------- - tool_cls: - The tool class. - """ - - return get_tool_cls(framework, tool["tool_type"], tool["tool_config"]) - - def get_runtime(self, ret_type: str = "runner") -> Any: - """Get the runtime of pipeline - - Parameters - ---------- - ret_type: str - The return type runner| runnable| model. - - Returns - ------- - runnable: - The runnable object. - """ - - raise NotImplementedError("get_runtime is not implemented in " + str(self.__class__)) - - def create_worker(self, model: Any, name: str, config: Optional[dict] = None): - """Create pipe worker - - Parameters - ------- - model: Any - The raw model in framwork. - name: str - The name of worker. - worker_config: dict - The extra config for worker. - - Returns - ------- - worker: str - The message with mark. - """ - - return self.worker_cls( - model, - config or self._config, - self._workspace, - self._plugins, - self._logger, - name=name, - ) - - def pipe_mark(self, msg: Any) -> str: - """Mark the message with pipeline info - - Parameters - ------- - msg: str - The message - - Returns - ------- - msg: str - The message with mark. - """ - - return "PIPE " + str(msg) - - @property - def worker_cls(self): - return BasePipeWorker - - @property - def report(self): - return self._report - - @property - def model_type(self): - return self._model_type - - @property - def optimize_type(self): - return self._optimize_type - - @property - def compile_type(self): - return self._compile_type diff --git a/python/tvm/contrib/msc/pipeline/utils.py b/python/tvm/contrib/msc/pipeline/utils.py deleted file mode 100644 index e162ef89fd46..000000000000 --- a/python/tvm/contrib/msc/pipeline/utils.py +++ /dev/null @@ -1,231 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.pipeline.config""" - -import copy -from typing import Dict, List, Optional, Tuple, Union - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools import ToolType -from tvm.contrib.msc.core.utils.message import MSCStage - - -def get_tool_stage(tool_type: str) -> str: - """Map the stage according to tool_type - - Parameters - ---------- - tool_type: str - The tool type. - - Returns - ------- - stage: str - The stage. - """ - - if tool_type == ToolType.PRUNER: - return MSCStage.PRUNE - if tool_type == ToolType.QUANTIZER: - return MSCStage.QUANTIZE - if tool_type == ToolType.DISTILLER: - return MSCStage.DISTILL - if tool_type == ToolType.TRACKER: - return MSCStage.TRACK - return tool_type - - -def map_tools(tools: List[dict]) -> dict: - """Map tools from list - - Parameters - ---------- - tools: list - The tools config, - - Returns - ------- - tools: dict - The tools map. - """ - - tools_map = {t["tool_type"]: t for t in tools} - assert len(tools_map) == len(tools), "Duplicate tools: " + str([t["tool_type"] for t in tools]) - return tools_map - - -def support_tool(tool: dict, stage: str, run_type: str) -> bool: - """Check if the tool is supported - - Parameters - ---------- - tool: dict - The tool config, - stage: str - The pipeline stage. - run_type: str - The runtime type. - - Returns - ------- - supported: bool - Whether the tool is supported. - """ - - run_type = tool.get("run_type", run_type) - if stage == MSCStage.BASELINE: - return tool["tool_type"] == ToolType.TRACKER - return True - - -def config_tool(tool_type: str, raw_config: Union[dict, str]) -> dict: - """Config the tool - - Parameters - ---------- - tool_type: str - The tool type, - raw_config: str| dict - The tool config or style. - - Returns - ------- - config: dict - The config for tool. - """ - - if isinstance(raw_config, dict): - if "config_style" in raw_config: - config_style = raw_config.pop("config_style") - else: - config_style = "default" - else: - config_style, raw_config = raw_config, None - configer_cls = msc_utils.get_registered_tool_configer(tool_type, config_style) - assert configer_cls, f"Can not find configer for {tool_type}:{config_style}" - return {"tool_type": tool_type, **configer_cls().config(raw_config)} - - -def create_config( - inputs: List[dict], - outputs: List[str], - model_type: str, - baseline_type: Optional[str] = None, - optimize_type: Optional[str] = None, - compile_type: Optional[str] = None, - dataset: Optional[Dict[str, dict]] = None, - tools: Optional[List[Tuple[str, Union[dict, str]]]] = None, - dynamic: bool = False, - run_config: Optional[Dict[str, dict]] = None, - skip_config: Optional[Dict[str, str]] = None, - **extra_config, -) -> dict: - """Create config for msc pipeline - - Parameters - ---------- - inputs: list - The inputs info, - outputs: list - The output names. - model_type: str - The model type. - baseline_type: str - The baseline type. - compile_type: str - The compile type. - optimize_type: str - The optimize type. - dataset: dict - The datasets for compile pipeline. - tools: list - The tools config. - dynamic: bool - Whether to config dyanmic mode. - skip_config: dict - The skip config for compile. - extra_config: dict - The extra config. - """ - - all_stages = [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE] - baseline_type = baseline_type or model_type - optimize_type = optimize_type or baseline_type - compile_type = compile_type or optimize_type - tools = tools or [] - tools = [config_tool(t_type, t_config) for t_type, t_config in tools] - extra_config = extra_config or {} - # basic config - config = { - "model_type": model_type, - "dynamic": dynamic, - "inputs": inputs, - "outputs": outputs, - "dataset": dataset, - "tools": tools, - MSCStage.PREPARE: {"profile": {"benchmark": {"repeat": -1}}}, - MSCStage.BASELINE: { - "run_type": baseline_type, - "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, - }, - } - - # config optimize - opt_tools = [t for t in tools if support_tool(t, MSCStage.OPTIMIZE, optimize_type)] - if opt_tools: - config[MSCStage.OPTIMIZE] = { - "run_type": optimize_type, - "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, - } - - # config compile - config[MSCStage.COMPILE] = { - "run_type": compile_type, - "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}}, - } - - # update run config - if run_config: - if "all" in run_config: - all_config = run_config.pop("all") - run_config.update({s: copy.deepcopy(all_config) for s in all_stages}) - for stage, r_config in run_config.items(): - extra_config.setdefault(stage, {}).setdefault("run_config", {}).update(r_config) - - # update config - if extra_config: - config = msc_utils.update_dict(config, extra_config) - - # skip stages - if skip_config: - if "all" in run_config: - all_config = skip_config.pop("all") - skip_config.update({s: copy.deepcopy(all_config) for s in all_stages}) - for stage, s_type in skip_config.items(): - if stage not in config: - continue - if s_type == "stage": - config.pop(stage) - elif s_type == "profile": - config[stage].pop("profile") - elif s_type == "check": - config[stage]["profile"]["check"]["err_rate"] = -1 - elif s_type == "benchmark": - config[stage]["profile"].pop("benchmark") - else: - raise TypeError("Unexpected skip type " + str(s_type)) - return config diff --git a/python/tvm/contrib/msc/pipeline/worker.py b/python/tvm/contrib/msc/pipeline/worker.py deleted file mode 100644 index bdac63170626..000000000000 --- a/python/tvm/contrib/msc/pipeline/worker.py +++ /dev/null @@ -1,786 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=import-outside-toplevel, unused-argument -# ruff: noqa: E501 -"""tvm.contrib.msc.pipeline.worker""" - -import logging -import os -import time -from typing import Any, List, Optional, Tuple - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.runtime import BaseRunner -from tvm.contrib.msc.core.tools import ToolType -from tvm.contrib.msc.core.utils.message import MSCStage -from tvm.contrib.msc.core.utils.namespace import MSCFramework - -from .utils import get_tool_stage, map_tools, support_tool - - -class BasePipeWorker: - """Base Worker of MSC pipeline - - Parameters - ---------- - model: Any - The raw model in framwork. - config: dict - The config for pipeline. - workspace: MSCDirectory - The workspace. - plugins: dict - The plugins for pipeline. - run_optimize: bool - Whether to run optimize. - run_compile: bool - Whether to run compile. - logger: logging.Logger - The logger. - name: str - The name of the worker. - """ - - def __init__( - self, - model: Any, - config: dict, - workspace: msc_utils.MSCDirectory, - plugins: Optional[dict] = None, - logger: Optional[logging.Logger] = None, - name: str = "main", - ): - # check/set default stage - for key in ["inputs", "outputs", "dataset"]: - assert key in config, f"Missing {key} in config" - - self._config = msc_utils.copy_dict(config) - self._workspace = workspace - self._plugins = plugins - self._model_type = config["model_type"] - self._optimize_type = config.get(MSCStage.OPTIMIZE, {}).get("run_type", self._model_type) - self._compile_type = config.get(MSCStage.COMPILE, {}).get("run_type", self._model_type) - runner_cls = self._get_runner_cls(self._model_type) - self._model, self._device, self._training = runner_cls.load_native(model, config) - self._verbose = config.get("verbose", "info") - self._logger = logger or msc_utils.get_global_logger() - self._name = name - self._optimized, self._compiled = False, False - self.setup() - - def setup(self) -> dict: - """Setup the manager - - Returns - ------- - config: dict - The updated config. - """ - - self._debug_levels = self.update_config() - self._tools_config = map_tools(self._config.get("tools", [])) - self._relax_mod, self._sample_inputs = None, None - self._runner = None - - def update_config(self) -> dict: - """Update config - - Returns - ------- - debug_levels: dict - The debug_levels. - """ - - debug_levels = {} - self._config = self._get_runner_cls(self._model_type).update_config( - MSCStage.PARSE, self._config, self._model - ) - - # update runner config - for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: - if stage not in self._config: - continue - if "run_type" not in self._config[stage]: - self._config[stage]["run_type"] = self._model_type - runner_cls = self._get_runner_cls(self._config[stage]["run_type"]) - self._config = runner_cls.update_config(stage, self._config, self._model) - - # update tool config - if self._config.get("tools"): - self._config["tools"] = self._update_tools_config(self._config["tools"]) - - # update export config - self._config[MSCStage.EXPORT].update( - {"inputs": self._config["inputs"], "outputs": self._config["outputs"]} - ) - - def _set_debug_level(stage: str, sub_config: dict, default: Optional[int] = None) -> dict: - if "debug_level" in sub_config: - debug_levels[stage] = sub_config["debug_level"] - elif default is not None: - debug_levels[stage] = default - sub_config["debug_level"] = default - return debug_levels - - if self._verbose.startswith("debug:"): - debug_level = int(self._verbose.split(":")[1]) - else: - debug_level = 0 - for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: - if stage not in self._config: - continue - debug_levels = _set_debug_level(stage, self._config[stage]["run_config"], debug_level) - for t_config in self._config.get("tools", []): - if not support_tool(t_config, stage, self._config[stage]["run_type"]): - continue - t_stage = stage + "." + get_tool_stage(t_config["tool_type"]) - debug_levels = _set_debug_level(t_stage, t_config["tool_config"], debug_level) - ordered_keys = [ - "model_type", - "inputs", - "outputs", - "dataset", - "tools", - MSCStage.PREPARE, - MSCStage.PARSE, - MSCStage.BASELINE, - MSCStage.OPTIMIZE, - MSCStage.COMPILE, - MSCStage.EXPORT, - ] - self._config = {k: self._config[k] for k in ordered_keys if k in self._config} - return debug_levels - - def _update_tools_config(self, tools: List[dict]) -> List[dict]: - """Update tool in stage config. - - Parameters - ---------- - tools: list - The config of tools. - - Returns - ------- - tools: list - The updated config of tools. - """ - - for tool in tools: - tool_config = tool["tool_config"] - if "plan_file" not in tool_config: - tool_config["plan_file"] = "msc_{}.json".format(tool["tool_type"]) - tool_config["plan_file"] = msc_utils.to_abs_path( - tool_config["plan_file"], msc_utils.get_config_dir() - ) - return tools - - def prepare(self, data_loader: Any = None) -> Tuple[dict, dict]: - """Prepare datas for the pipeline. - - Parameters - ---------- - data_loader: - The data loader. - - Returns - ------- - info: dict - The info of prepare. - report: dict - The report of prepare. - """ - - stage_config = self._config[MSCStage.PREPARE] - use_cache = self._config.get("use_cache", True) - runner_cls = self._get_runner_cls(self._model_type) - run_func = runner_cls.run_native if hasattr(runner_cls, "run_native") else None - input_names = [i[0] for i in self._config["inputs"]] - - # create golden - if "golden" in self._config["dataset"]: - golden_folder = self._config["dataset"]["golden"]["loader"] - else: - golden_folder = msc_utils.get_dataset_dir().relpath("Golden", use_cache) - if msc_utils.is_io_dataset(golden_folder): - loader, source_type = msc_utils.IODataLoader(golden_folder), "cache" - self._sample_inputs = loader[0][0] - datas_info = loader.info - msg = f"Load {len(loader)} golden from {golden_folder}" - self._logger.debug(self.worker_mark(msg)) - elif run_func: - source_type = "native" - saver_options = {"input_names": input_names, "output_names": self._config["outputs"]} - cnt, max_golden = 0, self._config["dataset"][MSCStage.PREPARE].get("max_golden", 5) - with msc_utils.IODataSaver(golden_folder, saver_options) as saver: - for inputs in data_loader(): - if cnt >= max_golden > 0: - break - if not self._sample_inputs: - self._sample_inputs = { - k: msc_utils.cast_array(v) for k, v in inputs.items() - } - try: - outputs, _ = run_func( - self._model, inputs, input_names, self._config["outputs"] - ) - except Exception as exc: # pylint: disable=broad-exception-caught - if cnt == 0: - msg = f"Failed to test native: {exc}" - self._logger.warning(self.worker_mark(msg)) - outputs = None - cnt = saver.save_batch(inputs, outputs) - datas_info = saver.info - msg = f"Save {cnt} golden to {golden_folder}" - self._logger.debug(self.worker_mark(msg)) - else: - raise Exception("golden_folder or runner should given to save golden") - self._config["dataset"]["golden"] = {"loader": golden_folder, "max_batch": -1} - - def _to_abstract(info: dict) -> dict: - def _to_tensor_str(info): - return "{},{}".format(";".join([str(s) for s in info["shape"]]), info["dtype"]) - - return { - "num_datas": info["num_datas"], - "inputs": {n: _to_tensor_str(i) for n, i in info["inputs"].items()}, - "outputs": {n: _to_tensor_str(o) for n, o in info["outputs"].items()}, - } - - info = { - f"golden_folder({source_type})": golden_folder, - "datas_info": _to_abstract(datas_info), - "smaple_inputs": self._sample_inputs, - } - - # profile - report = {} - if "profile" in stage_config and run_func: - benchmark = stage_config["profile"].get("benchmark", {}) - benchmark["repeat"] = self._get_repeat(benchmark) - try: - _, avg_time = run_func( - self._model, - self._sample_inputs, - input_names, - self._config["outputs"], - **benchmark, - ) - latency = f"{avg_time:.2f} ms @ {self._device}" - info["latency"] = latency + " (X{})".format(benchmark["repeat"]) - report["profile"] = latency - except Exception as exc: # pylint: disable=broad-exception-caught - msg = f"Failed to profile native: {exc}" - self._logger.warning(self.worker_mark(msg)) - report["profile"] = "failed run native" - return info, report - - def parse(self) -> Tuple[dict, dict]: - """Parse the model to IRModule. - - Returns - ------- - info: dict - The info of parse. - report: dict - The report of parse. - """ - - stage_config = self._config[MSCStage.PARSE] - if self._config.get("use_cache", True): - cache_path = ( - msc_utils.get_cache_dir().create_dir(MSCStage.PARSE).relpath("parsed_relax.json") - ) - else: - cache_path = None - info = {} - if cache_path and os.path.isfile(cache_path): - with open(cache_path) as f: - self._relax_mod = tvm.ir.load_json(f.read()) - info["cache"] = cache_path - else: - info = {"parser": stage_config["parser"], "config": stage_config.get("parse_config")} - parse_config = msc_utils.copy_dict(stage_config.get("parse_config", {})) - parse_config["as_msc"] = False - if self._model_type in self._plugins: - plugin = self._plugins[self._model_type] - parse_config["custom_convert_map"] = plugin.get_convert_map() - self._relax_mod, _ = stage_config["parser"](self._model, **parse_config) - transformed = set() - for stage in [MSCStage.OPTIMIZE, MSCStage.COMPILE]: - if stage not in self._config: - continue - run_type = self._config[stage]["run_type"] - if run_type in transformed: - continue - transformed.add(run_type) - runner_cls = self._get_runner_cls(run_type) - if hasattr(runner_cls, "target_transform"): - msg = f"Transform for {run_type}({stage})" - self._logger.info(self.worker_mark(msg)) - self._relax_mod = runner_cls.target_transform(self._relax_mod) - if cache_path: - with open(cache_path, "w") as f: - f.write(tvm.ir.save_json(self._relax_mod)) - msg = "Save parsed mod to " + cache_path - self._logger.debug(self.worker_mark(msg)) - return info, {} - - def get_tool_config(self, tool_type: str, key: str = "tool_config", default: Any = None) -> Any: - """Get the tool config - - Parameters - ---------- - tool_type: str - The tool type. - key: str - The config key - - Returns - ------- - config: - The tool config or info. - """ - - assert tool_type in self._tools_config, "Can not find tool_type " + str(tool_type) - return self._tools_config[tool_type].get(key, default) - - def tool_applied(self, tool_type: str) -> bool: - """Check if the tool is applied - - Parameters - ---------- - tool_type: str - The tool type. - - Returns - ------- - applied: bool - Whether the tool is applied. - """ - - config = self.get_tool_config(tool_type) - return os.path.isfile(config["plan_file"]) - - def apply_tool( - self, tool_type: str, knowledge: Optional[dict] = None, data_loader: Any = None - ) -> Tuple[dict, dict]: - """Apply tool with runner - - Parameters - ---------- - tool_type: str - The tool type to apply. - knowledge: dict - The pre knowledge. - data_loader: - The data loader. - - Returns - ------- - info: dict - The info of apply tool. - report: dict - The report of apply tool. - """ - - plan_file = self.get_tool_config(tool_type)["plan_file"] - if knowledge: - self._logger.info("Plan by %d knowledge for %s", len(knowledge), tool_type) - msc_utils.save_dict(knowledge, plan_file) - else: - self._runner.make_plan(tool_type, data_loader) - if self.get_tool_config(tool_type, "visualize", False): - self._runner.visualize( - msc_utils.get_visual_dir().create_dir(self._runner.stage.split(".")[0]) - ) - report = {} - if os.path.isfile(plan_file): - report["plan_num"] = len(msc_utils.load_dict(plan_file)) - return {}, report - - def create_runner( - self, - stage: str, - tools: Optional[List[str]] = None, - run_type: Optional[str] = None, - run_config: Optional[dict] = None, - visualize: bool = True, - profile: bool = True, - use_cache: bool = True, - ) -> Tuple[dict, dict]: - """Create runner. - - Parameters - ---------- - stage: str - The stage name - tools: list - The tools to apply. - run_type: str - The type of runner. - run_config: dict - The config of runner. - visualize: bool - Whether to visualize the runner - profile: bool - Whether to profile the runner. - use_cache: bool - Whether to use cache. - - Returns - ------- - info: dict - The info of create runner. - report: dict - The report of create runner. - """ - - if self._runner: - self._runner.destory() - tools = tools or [] - assert all(t in self._tools_config for t in tools), "Missing some tools " + str(tools) - main_stage = stage.split(".")[0] - if not run_type: - run_type = self._config[main_stage]["run_type"] - if not run_config: - run_config = self._config[main_stage].get("run_config", {}) - runner_cls = self._get_runner_cls(run_type) - if "generate_config" not in run_config: - run_config["generate_config"] = {} - cleanup = self._debug_levels.get(stage, 0) == 0 - run_config["generate_config"]["build_folder"] = msc_utils.get_build_dir().create_dir( - stage, cleanup=cleanup - ) - if "device" not in run_config: - run_config["device"] = self._device - if "training" not in run_config: - run_config["training"] = self._training - # Build runner - runner = runner_cls( - self._relax_mod, - tools_config=[self._tools_config[t] for t in tools], - plugin=self._plugins.get(run_type), - stage=stage, - name=self._name, - logger=self._logger, - **run_config, - ) - cache_dir = msc_utils.get_cache_dir().create_dir(stage) if use_cache else None - runner.build(cache_dir=cache_dir) - if visualize: - runner.visualize(msc_utils.get_visual_dir().create_dir(main_stage)) - if use_cache: - runner.save_cache(cache_dir) - info, report = {}, {"runtime": f"{runner.framework} @ {runner.device}"} - if profile and "profile" in self._config[main_stage]: - profile_config = self._config[main_stage]["profile"] - info["profile"], report["profile"] = self._profile_runner(runner, profile_config) - self._runner = runner - return info, report - - def _profile_runner(self, runner: BaseRunner, profile_config: dict) -> Tuple[dict, str]: - """Profile the runner. - - Parameters - ---------- - runner: BaseRunner - The runner to be profiled - profile_config: dict - The config of profile. - - Returns - ------- - info: dict - The info of profile. - report: str - The report of profile. - """ - - stage = runner.stage - info, report = {}, "" - - # check accuracy - check_config = profile_config.get("check", {}) - if check_config: - loader = msc_utils.IODataLoader(self._config["dataset"]["golden"]["loader"]) - acc_info = {"passed": ""} - total, passed = 0, 0 - for idx, (inputs, outputs) in enumerate(loader): - results = runner.run(inputs) - if outputs: - iter_info = msc_utils.compare_arrays( - outputs, - results, - atol=check_config.get("atol", 1e-2), - rtol=check_config.get("rtol", 1e-2), - report_detail=runner.debug_level >= 2, - ) - else: - iter_info = { - "total": len(results), - "passed": len(results), - "info": {k: msc_utils.MSCArray(v).abstract() for k, v in results.items()}, - } - total += iter_info["total"] - passed += iter_info["passed"] - acc_info["iter_" + str(idx)] = iter_info["info"] - pass_rate = float(passed) / total - accuracy = f"{passed}/{total}({pass_rate * 100:.2f}%)" - acc_info["passed"] = f"{accuracy} {check_config}" - info["accuracy"] = acc_info if runner.debug_level >= 1 else accuracy - report = "pass " + accuracy - if runner.get_tool(ToolType.PRUNER) or runner.get_tool(ToolType.QUANTIZER): - disable_msg = f"Disable accuracy check({stage}) by tools" - self._logger.debug(self.worker_mark(disable_msg)) - else: - required_err, err_rate = check_config.get("err_rate", 0), (1 - pass_rate) - if err_rate > required_err >= 0: - self._logger.error(msc_utils.msg_block(self.worker_mark("ACCURACY"), acc_info)) - raise Exception( - f"Failed to profile the runner({stage}), err_rate {err_rate} > required {required_err}" - ) - - # benchmark model - benchmark_config = profile_config.get("benchmark", {}) - if benchmark_config: - for _ in range(benchmark_config.get("warm_up", 10)): - runner.run(self._sample_inputs) - start = time.time() - repeat = self._get_repeat(benchmark_config, runner.device) - for _ in range(repeat): - runner.run(self._sample_inputs) - avg_time = (time.time() - start) * 1000 / repeat - latency = f"{avg_time:.2f} ms @ {runner.device}" - info["latency"] = latency + f" (X{repeat})" - report += (", " if report else "") + latency - return info, report - - def export_model(self, stage: str, folder: msc_utils.MSCDirectory, dump: bool = True) -> Any: - """Export the model - - Parameters - ---------- - stage: str - The pipeline stage. - folder: MSCDirectory - The export folder. - dump: bool - Whether to dump info. - - Returns - ------- - exported: - The exported model. - """ - - if stage == MSCStage.COMPILE: - if not dump: - return self._runner.runnable - return self._runner.export_runnable(folder) - - if stage == MSCStage.OPTIMIZE: - module = self._runner.export_module(folder) - if not dump: - return module - path = folder.relpath("model.json") - with open(path, "w") as f: - f.write(tvm.ir.save_json(module)) - return path - - if not dump: - return self._model - dump_func = self._get_runner_cls(self._model_type).dump_nativate - return dump_func(self._model, folder, self._config[MSCStage.EXPORT]) - - def export_tool(self, tool_type: str, folder: msc_utils.MSCDirectory) -> dict: - """Export the tool - - Parameters - ---------- - tool_type: str - The tool type. - folder: MSCDirectory - The export folder. - - Returns - ------- - config: dict - The exported tool config. - """ - - run_tool = self._runner.get_tool(tool_type) - assert tool_type in self._tools_config, "Can not find tool_type " + str(tool_type) - return run_tool.export_config(self._tools_config[tool_type]["tool_config"], folder) - - def export_info(self, stage: str, folder: msc_utils.MSCDirectory) -> dict: - """Export the info of worker - - Parameters - ---------- - stage: str - The pipeline stage. - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The info. - """ - - return { - "visualize": msc_utils.get_visual_dir().copy_to(folder.relpath("visualize")), - "graphs": self._runner.export_graphs(folder.create_dir("graphs")), - } - - def get_runnable(self, ret_type: str = "runner") -> Any: - """Return object by type. - - Parameters - ---------- - ret_type: str - The return type runner| runnable| model. - - Returns - ------- - runnable: - The runner or model. - """ - - assert self._runner, "Failed to create runner, call run_pipe first" - if ret_type == "runner": - return self._runner - if ret_type == "runnable": - return self._runner.runnable - if ret_type == "model": - return self._runner.model - raise TypeError("Unexpect return type " + str(ret_type)) - - def _get_repeat(self, benchmark: dict, device: Optional[str] = None) -> int: - """Get the repeat number for benchmark - - Parameters - ---------- - benchmark: dict - The benchmark config. - device: str - The device name - - Returns - ------- - repeat: int - The repeat number. - """ - - device = device or self._device - repeat = benchmark.get("repeat", -1) - if repeat == -1: - repeat = 500 if device.startswith("cuda") else 10 - return repeat - - def _get_runner_cls(self, run_type: str) -> BaseRunner: - """Get the runner cls by type - - Parameters - ---------- - run_type: str - The run type. - - Returns - ------- - runner_cls: class - The runner class. - """ - - raise NotImplementedError("_get_runner_cls is not implemented in " + str(self.__class__)) - - def destory(self): - """Destroy the worker""" - - if self._runner: - self._runner.destory() - - def worker_mark(self, msg: Any) -> str: - """Mark the message with worker info - - Parameters - ------- - msg: str - The message - - Returns - ------- - msg: str - The message with mark. - """ - - return f"WORKER[{self._name}] {msg}" - - @property - def runner(self): - return self._runner - - @property - def model_type(self): - return self._model_type - - @property - def optimize_type(self): - return self._optimize_type - - @property - def compile_type(self): - return self._compile_type - - -class MSCPipeWorker(BasePipeWorker): - """Normal manager in MSC""" - - def _get_runner_cls(self, run_type: str) -> BaseRunner: - """Get the runner cls by type - - Parameters - ---------- - run_type: str - The run type. - - Returns - ------- - runner_cls: class - The runner class. - """ - - if run_type == MSCFramework.TVM: - from tvm.contrib.msc.framework.tvm.runtime import TVMRunner - - runner_cls = TVMRunner - elif run_type == MSCFramework.TORCH: - from tvm.contrib.msc.framework.torch.runtime import TorchRunner - - runner_cls = TorchRunner - elif run_type == MSCFramework.TENSORFLOW: - from tvm.contrib.msc.framework.tensorflow.runtime import TensorflowRunner - - runner_cls = TensorflowRunner - elif run_type == MSCFramework.TENSORRT: - from tvm.contrib.msc.framework.tensorrt.runtime import TensorRTRunner - - runner_cls = TensorRTRunner - else: - raise Exception("Unexpect run_type " + str(run_type)) - return runner_cls diff --git a/python/tvm/contrib/msc/pipeline/wrapper.py b/python/tvm/contrib/msc/pipeline/wrapper.py deleted file mode 100644 index 875846f70a28..000000000000 --- a/python/tvm/contrib/msc/pipeline/wrapper.py +++ /dev/null @@ -1,289 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.pipeline.wrapper""" - -import shutil -from typing import Any, List, Optional, Union - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.utils.message import MSCStage -from tvm.contrib.msc.core.utils.namespace import MSCFramework - -from .dynamic import MSCDynamic, TorchDynamic -from .manager import MSCManager -from .utils import create_config - - -class BaseWrapper: - """Base Wrapper of models - - Parameters - ---------- - model: Any - The raw model in framwork. - config: dict - The config for pipeline - plugins: dict - The plugins for pipeline. - """ - - def __init__( - self, - model: Any, - config: dict, - workspace: str = "msc_workspace", - plugins: Optional[dict] = None, - ): - self._meta_model = model - self._optimized_model, self._compiled_model = None, None - self._config = config - self._plugins = plugins - self._dynamic = self._config.get("dynamic", False) - verbose = config.get("verbose", "info") - self._debug = verbose.startswith("debug") - self._workspace = msc_utils.msc_dir(workspace, keep_history=self._debug) - log_path = self._workspace.relpath("MSC_LOG", keep_history=False) - self._config["logger"] = msc_utils.create_file_logger(verbose, log_path) - self._pipeline, self._report = None, None - self.setup() - - def __str__(self): - if self.compiled: - phase = "compiled" - elif self.optimized: - phase = "optimized" - else: - phase = "meta" - return f"({phase}) {self._get_model().__str__()}" - - def __getattr__(self, name): - if hasattr(self._get_model(), name): - return getattr(self._get_model(), name) - return self._get_model().__getattr__(name) - - def setup(self): - """Setup the wrapper""" - - return - - def optimize(self, workspace: str = "Optimize"): - """Optimize the model - - Parameters - ---------- - workspace: str - The workspace. - """ - - self.logger.info(msc_utils.split_line("Start optimize model", "*")) - config = msc_utils.copy_dict(self._config) - config["workspace"] = self._workspace.create_dir(workspace) - if MSCStage.OPTIMIZE not in config: - config[MSCStage.OPTIMIZE] = {"run_type": self.model_type()} - profile = config.get(MSCStage.BASELINE, {}).get("profile") - if profile: - config[MSCStage.OPTIMIZE]["profile"] = profile - self._pipeline = self.pipe_cls(self._meta_model, config, self._plugins, run_compile=False) - self._report = self._pipeline.run_pipe() - if self._report["success"]: - self._optimized_model = self._pipeline.get_runtime("runnable") - return self - - def compile( - self, workspace: str = "Compile", ckpt_path: str = "Checkpoint", dump: bool = False - ): - """Compile the model - - Parameters - ---------- - workspace: str - The workspace. - ckpt_path: str - The path to export checkpoint. - dump: bool - Whether to dump the info. - """ - - if self._optimized_model: - self.logger.info(msc_utils.split_line("Start compile checkpoint", "*")) - ckpt_path = self._workspace.create_dir(ckpt_path).path - export = self.export(ckpt_path, dump=dump, keep_workspace=True) - export["config"]["workspace"] = self._workspace.create_dir(workspace) - self._pipeline = self.pipe_cls( - export["model"], export["config"], export["plugins"], root=ckpt_path - ) - self._report = self._pipeline.run_pipe() - if self._report["success"]: - self._compiled_model = self._pipeline.get_runtime("runnable") - if not self._debug: - shutil.rmtree(ckpt_path) - else: - self.logger.info(msc_utils.split_line("Start compile model", "*")) - config = msc_utils.copy_dict(self._config) - config["workspace"] = self._workspace.create_dir(workspace) - self._pipeline = self.pipe_cls(self._meta_model, config, self._plugins) - self._report = self._pipeline.run_pipe() - if self._report["success"]: - self._compiled_model = self._pipeline.get_runtime("runnable") - return self - - def export( - self, path: str = "msc_export", dump: bool = True, keep_workspace: bool = False - ) -> Union[str, dict]: - """Export compile pipeline - - Parameters - ---------- - path: str - The export path. - dump: bool - Whether to dump the info. - keep_workspace: bool - Whether to keep workspace. - - Returns - ------- - export_path/pipeline: str/dict - The exported path/pipeline info. - """ - - if not self._pipeline: - self._pipeline = self.pipe_cls(self._meta_model, self._config, self._plugins) - exported = self._pipeline.export(path, dump=dump) - if not self._debug: - self._pipeline.destory() - if not keep_workspace: - self._workspace.destory() - return exported - - def _get_model(self) -> Any: - return self._compiled_model or self._optimized_model or self._meta_model - - def _get_framework(self) -> str: - return self._pipeline.get_runtime().framework if self._pipeline else self.model_type() - - @property - def pipe_cls(self): - if self._dynamic: - return MSCDynamic - return MSCManager - - @property - def optimized(self): - return self._optimized_model is not None - - @property - def compiled(self): - return self._compiled_model is not None - - @property - def device(self): - if self._pipeline: - return self._pipeline.get_runtime().device - return "cpu" - - @property - def logger(self): - return self._config["logger"] - - @property - def report(self): - return self._report - - @classmethod - def create_config( - cls, - inputs: List[dict], - outputs: List[str], - baseline_type: Optional[str] = None, - optimize_type: Optional[str] = None, - compile_type: Optional[str] = None, - **kwargs, - ) -> dict: - """Create config for msc pipeline - - Parameters - ---------- - inputs: list - The inputs info, - outputs: list - The output names. - baseline_type: str - The baseline type. - optimize_type: str - The optimize type. - compile_type: str - The compile type. - kwargs: dict - The config kwargs. - """ - - return create_config( - inputs, outputs, cls.model_type(), baseline_type, optimize_type, compile_type, **kwargs - ) - - @classmethod - def model_type(cls): - return MSCFramework.MSC - - -class TorchWrapper(BaseWrapper): - """Wrapper of torch models""" - - def __call__(self, *inputs): - return self.forward(*inputs) - - def forward(self, *inputs): - framework = self._get_framework() - if framework != MSCFramework.TORCH: - inputs = [msc_utils.cast_array(i, framework, self.device) for i in inputs] - outputs = self._get_model()(*inputs) - if framework == MSCFramework.TORCH: - return outputs - if isinstance(outputs, (tuple, list)): - return [msc_utils.cast_array(o, MSCFramework.TORCH, self.device) for o in outputs] - return msc_utils.cast_array(outputs, MSCFramework.TORCH, self.device) - - def parameters(self): - framework = self._get_framework() - if framework == MSCFramework.TORCH: - return self._get_model().parameters() - return self._pipeline.get_runtime().get_weights(MSCFramework.TORCH) - - def train(self): - if self._pipeline: - self._pipeline.get_runtime().train() - if self._get_framework() == MSCFramework.TORCH: - return self._get_model().train() - return self._get_model() - - def eval(self): - if self._pipeline: - self._pipeline.get_runtime().eval() - if self._get_framework() == MSCFramework.TORCH: - return self._get_model().eval() - return self._get_model() - - @property - def pipe_cls(self): - if self._dynamic: - return TorchDynamic - return MSCManager - - @classmethod - def model_type(cls): - return MSCFramework.TORCH diff --git a/python/tvm/contrib/msc/plugin/__init__.py b/python/tvm/contrib/msc/plugin/__init__.py deleted file mode 100644 index 53b4774db1b5..000000000000 --- a/python/tvm/contrib/msc/plugin/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.plugin""" - -from .build import * diff --git a/python/tvm/contrib/msc/plugin/_ffi_api.py b/python/tvm/contrib/msc/plugin/_ffi_api.py deleted file mode 100644 index 88f9204f3a02..000000000000 --- a/python/tvm/contrib/msc/plugin/_ffi_api.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.plugin._ffi_api""" - -import tvm_ffi - -tvm_ffi.init_ffi_api("msc.plugin", __name__) diff --git a/python/tvm/contrib/msc/plugin/build.py b/python/tvm/contrib/msc/plugin/build.py deleted file mode 100644 index b3afccb10dbf..000000000000 --- a/python/tvm/contrib/msc/plugin/build.py +++ /dev/null @@ -1,285 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.plugin.build""" - -import os -import subprocess -import sys -from typing import Any, Dict, List, Optional - -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.plugin.codegen import get_codegen - -from .register import register_plugin - - -def _build_plugins( - plugins: Dict[str, dict], - frameworks: List[str], - workspace: msc_utils.MSCDirectory = None, - codegen_config: Optional[Dict[str, str]] = None, - cpp_print_config: Optional[Dict[str, str]] = None, - py_print_config: Optional[Dict[str, str]] = None, - externs_dir: msc_utils.MSCDirectory = None, - on_debug: bool = False, -): - """Build the plugins - - Parameters - ---------- - plugins: dict - The plugins define. - frameworks: list - The frameworks for plugin. - workspace: MSCDirectory - The workspace folder. - codegen_config: dict - The config to generate code. - cpp_print_config: dict - The config to print cpp code. - py_print_config: dict - The config to print python code. - externs_dir: MSCDirectory - The extern sources folder. - on_debug: bool - Whether to debug the building. - """ - - workspace = workspace or msc_utils.msc_dir("msc_plugin") - - # register the plugins - extern_sources, extern_libs, ops_info = {}, {}, {} - for name, plugin in plugins.items(): - sources, libs, info = register_plugin(name, plugin, externs_dir) - extern_sources.update(sources) - extern_libs.update(libs) - ops_info[name] = info - # build plugins for frameworks - codegens = {} - for framework in frameworks: - codegen = get_codegen( - framework, - workspace, - codegen_config, - cpp_print_config=cpp_print_config, - py_print_config=py_print_config, - extern_sources=extern_sources, - extern_libs=extern_libs, - on_debug=on_debug, - ) - if not codegen.libs_built(): - codegen.build_libs() - if codegen.need_manager and not codegen.manager_built(): - codegen.build_manager(ops_info) - codegens[framework] = codegen - return codegens - - -def build_plugins( - plugins: Dict[str, dict], - frameworks: List[str], - workspace: msc_utils.MSCDirectory = None, - codegen_config: Optional[Dict[str, str]] = None, - cpp_print_config: Optional[Dict[str, str]] = None, - py_print_config: Optional[Dict[str, str]] = None, - externs_dir: msc_utils.MSCDirectory = None, - on_debug: bool = False, -) -> Dict[str, Any]: - """Build the plugins and load plugin manager - - Parameters - ---------- - plugins: dict - The plugins define. - frameworks: list - The frameworks for plugin. - workspace: MSCDirectory - The workspace folder. - codegen_config: dict - The config to generate code. - cpp_print_config: dict - The config to print cpp code. - py_print_config: dict - The config to print python code. - externs_dir: MSCDirectory - The extern sources folder. - on_debug: bool - Whether to debug the building. - - Returns - ------- - managers: dict - The plugin managers. - """ - - codegens = _build_plugins( - plugins, - frameworks, - workspace, - codegen_config=codegen_config, - cpp_print_config=cpp_print_config, - py_print_config=py_print_config, - externs_dir=externs_dir, - on_debug=on_debug, - ) - managers = {} - for name, codegen in codegens.items(): - manager_file = codegen.manager_folder.relpath("manager.py") - manager_cls = msc_utils.load_callable(manager_file + ":PluginManager") - managers[name] = manager_cls(codegen.output_folder.path) - return managers - - -def pack_plugins( - plugins: Dict[str, dict], - frameworks: List[str], - project_name: str = "msc_plugin", - codegen_config: Optional[Dict[str, str]] = None, - cpp_print_config: Optional[Dict[str, str]] = None, - py_print_config: Optional[Dict[str, str]] = None, - externs_dir: msc_utils.MSCDirectory = None, - setup_config: Optional[Dict[str, str]] = None, - on_debug: bool = False, -) -> str: - """Build the plugins and build to wheel - - Parameters - ---------- - plugins: dict - The plugins define. - frameworks: list - The frameworks for plugin. - project_name: str - The project name - codegen_config: dict - The config to generate code. - cpp_print_config: dict - The config to print cpp code. - py_print_config: dict - The config to print python code. - externs_dir: MSCDirectory - The extern sources folder. - setup_config: dict - The config to setup wheel. - on_debug: bool - Whether to debug the building. - - Returns - ------- - wheel_path: str - The file path of wheel. - """ - - project_dir = msc_utils.msc_dir(project_name) - workspace = project_dir.create_dir(project_name) - codegens = _build_plugins( - plugins, - frameworks, - workspace, - codegen_config=codegen_config, - cpp_print_config=cpp_print_config, - py_print_config=py_print_config, - externs_dir=externs_dir, - on_debug=on_debug, - ) - # add init files - init_code = """# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from .manager import * -""" - with open(workspace.relpath("__init__.py"), "w") as f: - f.write(init_code) - for name in codegens: - with open(workspace.create_dir(name).relpath("__init__.py"), "w") as f: - f.write(init_code) - - # add setup file - if setup_config: - setup_config_str = "\n " + "\n ".join( - [f"{k} = {v}," for k, v in setup_config.items()] - ) - else: - setup_config_str = "" - setup_code = """ -import os -import shutil - -from setuptools import find_packages, setup -from setuptools.dist import Distribution - -project_name = "{0}" -data_files = [] -for framework in [{2}]: - for folder in ["lib", "include"]: - src_path = os.path.join(project_name, framework, folder) - data_files.append( - ( - os.path.join(project_name, framework, folder), - [os.path.join(src_path, f) for f in os.listdir(src_path)], - ), - ) - -class BinaryDistribution(Distribution): - def has_ext_modules(self): - return True - - def is_pure(self): - return False - -setup( - name="{0}"{1}, - packages=find_packages(), - distclass=BinaryDistribution, - data_files=data_files -) - -shutil.rmtree("build") -shutil.rmtree("{0}.egg-info") -""".format(project_name, setup_config_str, ",".join([f'"{f}"' for f in frameworks])) - with open(project_dir.relpath("setup.py"), "w") as f: - f.write(setup_code) - - # build the wheel - with project_dir: - command = f"{sys.executable} setup.py bdist_wheel" - with open("build.log", "w") as log_f: - process = subprocess.Popen(command, stdout=log_f, stderr=log_f, shell=True) - process.wait() - assert process.returncode == 0, ( - f"Failed to build wheel under {os.getcwd()}, check build.log for detail" - ) - dist_dir = project_dir.create_dir("dist") - files = list(dist_dir.listdir()) - assert len(files) == 1 and files[0].endswith(".whl"), ( - "Failed to build wheel, no .whl found @ " + str(dist_dir.path) - ) - return dist_dir.relpath(files[0]) diff --git a/python/tvm/contrib/msc/plugin/codegen/__init__.py b/python/tvm/contrib/msc/plugin/codegen/__init__.py deleted file mode 100644 index fbc0b0fed8a0..000000000000 --- a/python/tvm/contrib/msc/plugin/codegen/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.plugin.codegen""" - -from .codegen import * diff --git a/python/tvm/contrib/msc/plugin/codegen/codegen.py b/python/tvm/contrib/msc/plugin/codegen/codegen.py deleted file mode 100644 index c2debbedd381..000000000000 --- a/python/tvm/contrib/msc/plugin/codegen/codegen.py +++ /dev/null @@ -1,317 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E741 -"""tvm.contrib.msc.core.codegen.codegen""" - -import os -import subprocess -from typing import Dict, List, Optional - -import tvm -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.plugin import _ffi_api - -from .sources import get_plugin_sources - - -class BasePluginCodeGen: - """Manager class to generate codes and build plugin - - Parameters - ---------- - workspace: MSCDirectory - The workspace folder. - codegen_config: dict - The config to generate code. - cpp_print_config: dict - The config to print cpp code. - py_print_config: dict - The config to print python code. - extern_sources: dict - The depend source files. - extern_libs: dict - The depend lib files. - on_debug: bool - Whether to debug the building. - """ - - def __init__( - self, - workspace: msc_utils.MSCDirectory, - codegen_config: Optional[Dict[str, str]] = None, - cpp_print_config: Optional[Dict[str, str]] = None, - py_print_config: Optional[Dict[str, str]] = None, - extern_sources: Optional[Dict[str, str]] = None, - extern_libs: Optional[Dict[str, str]] = None, - on_debug: bool = False, - ): - self._codegen_config = msc_utils.copy_dict(codegen_config) - self._cpp_print_config = msc_utils.dump_dict(cpp_print_config) - self._py_print_config = msc_utils.dump_dict(py_print_config) - self._build_folder = workspace.create_dir( - "source_" + self.framework, keep_history=on_debug, cleanup=not on_debug - ) - self._output_folder = workspace.create_dir(self.framework) - self._extern_sources = extern_sources or {} - self._extern_libs = extern_libs or {} - self.setup() - - def setup(self): - """Set up the codegen""" - - self._lib_folder = self._output_folder.create_dir("lib") - self._manager_folder = self._output_folder - self._libs = [os.path.basename(l) for l in self._extern_libs.values()] - self._libs.extend([os.path.basename(l) for l in self._lib_folder.listdir()]) - self._project_name = f"msc_{self.framework}_plugin" - self._codegen_config.update( - { - "install_dir": self._output_folder.path, - "project_name": self._project_name, - "version": msc_utils.get_version(self.framework), - } - ) - - def libs_built(self) -> bool: - """Check if the libs are built - - Returns - ------- - libs_built: bool - Whether libs are built. - """ - - return any(self._project_name in f for f in self._lib_folder.listdir()) - - def build_libs(self) -> List[str]: - """Generate source and build the lib - - Returns - ------- - paths: list - The lib file paths. - """ - - codegen_config = msc_utils.dump_dict(self._codegen_config) - sources = self.source_getter(codegen_config, self._cpp_print_config, "build") - with self._build_folder as folder: - # add depends - with folder.create_dir("src") as src_folder: - for name, file in self._extern_sources.items(): - src_folder.copy(file, name) - for name, source in get_plugin_sources().items(): - src_folder.add_file(name, source) - for name, source in sources.items(): - if name == "CMakeLists.txt": - folder.add_file(name, source) - else: - src_folder.add_file(name, source) - with folder.create_dir("build"): - command = "cmake ../ && make" - with open("codegen.log", "w") as log_f: - process = subprocess.Popen(command, stdout=log_f, stderr=log_f, shell=True) - process.wait() - assert process.returncode == 0, ( - f"Failed to build plugin under {os.getcwd()}, check codegen.log for detail" - ) - self._libs.extend([os.path.basename(l) for l in self._lib_folder.listdir()]) - return self._lib_folder.listdir(as_abs=True) - - def manager_built(self) -> bool: - """Check if the manager are built - - Returns - ------- - manager_built: bool - Whether manager is built. - """ - - return os.path.isfile(self._manager_folder.relpath("manager.py")) - - def build_manager(self, ops_info: dict) -> List[str]: - """Generate manager source for plugin - - Parameters - ---------- - ops_info: dict - The info of ops. - - Returns - ------- - paths: list - The manager file paths. - """ - - self._codegen_config["libs"] = self._libs - self._codegen_config["ops_info"] = {n: msc_utils.dump_dict(i) for n, i in ops_info.items()} - codegen_config = msc_utils.dump_dict(self._codegen_config) - sources = self.source_getter(codegen_config, self._py_print_config, "manager") - manager_files = [] - with self._manager_folder as folder: - for name, source in sources.items(): - manager_files.append(folder.add_file(name, source)) - return manager_files - - @property - def source_getter(self): - raise NotImplementedError("source_getter is not supported for Base codegen") - - @property - def need_manager(self): - return True - - @property - def framework(self): - return MSCFramework.MSC - - @property - def output_folder(self): - return self._output_folder - - @property - def lib_folder(self): - return self._lib_folder - - @property - def manager_folder(self): - return self._manager_folder - - -class TVMPluginCodegen(BasePluginCodeGen): - """Plugin codegen for tvm""" - - def setup(self): - """Set up the codegen""" - - super().setup() - tvm_root = os.path.dirname(os.path.dirname(tvm.__path__[0])) - self._codegen_config.update( - {"need_convert": False, "with_runtime": True, "tvm_root": tvm_root} - ) - - @property - def source_getter(self): - return _ffi_api.GetTVMPluginSources - - @property - def framework(self): - return MSCFramework.TVM - - -class TorchPluginCodegen(BasePluginCodeGen): - """Plugin codegen for torch""" - - def setup(self): - """Set up the codegen""" - # pylint: disable=import-outside-toplevel - import torch.utils - - super().setup() - self._codegen_config.update( - { - "need_convert": True, - "with_runtime": False, - "torch_prefix": torch.utils.cmake_prefix_path, - } - ) - - @property - def source_getter(self): - return _ffi_api.GetTorchPluginSources - - @property - def framework(self): - return MSCFramework.TORCH - - -class TensorRTPluginCodegen(BasePluginCodeGen): - """Plugin codegen for tensorrt""" - - def setup(self): - """Set up the codegen""" - # pylint: disable=import-outside-toplevel - from tvm.contrib.msc.framework.tensorrt import _ffi_api as _trt_api - - super().setup() - self._codegen_config.update( - { - "need_convert": False, - "with_runtime": False, - "tensorrt_root": _trt_api.GetTensorRTRoot(), - } - ) - - @property - def source_getter(self): - return _ffi_api.GetTensorRTPluginSources - - @property - def framework(self): - return MSCFramework.TENSORRT - - -def get_codegen( - framework: str, - workspace: msc_utils.MSCDirectory, - codegen_config: Optional[Dict[str, str]] = None, - cpp_print_config: Optional[Dict[str, str]] = None, - py_print_config: Optional[Dict[str, str]] = None, - extern_sources: Optional[Dict[str, str]] = None, - extern_libs: Optional[Dict[str, str]] = None, - on_debug: bool = False, -): - """Create codegen for framework - - Parameters - ---------- - framework: str - THe framework for the plugin. - workspace: MSCDirectory - The workspace folder. - codegen_config: dict - The config to generate code. - cpp_print_config: dict - The config to print cpp code. - py_print_config: dict - The config to print python code. - extern_sources: dict - The depend source files. - extern_libs: dict - The depend lib files. - on_debug: bool - Whether to debug the building. - """ - - codegen_cls = None - if framework == MSCFramework.TVM: - codegen_cls = TVMPluginCodegen - elif framework == MSCFramework.TORCH: - codegen_cls = TorchPluginCodegen - elif framework == MSCFramework.TENSORRT: - codegen_cls = TensorRTPluginCodegen - else: - raise NotImplementedError(f"framework {framework} is not support for plugin codegen") - return codegen_cls( - workspace, - codegen_config=codegen_config, - cpp_print_config=cpp_print_config, - py_print_config=py_print_config, - extern_sources=extern_sources, - extern_libs=extern_libs, - on_debug=on_debug, - ) diff --git a/python/tvm/contrib/msc/plugin/codegen/sources.py b/python/tvm/contrib/msc/plugin/codegen/sources.py deleted file mode 100644 index ee2c5e992234..000000000000 --- a/python/tvm/contrib/msc/plugin/codegen/sources.py +++ /dev/null @@ -1,1169 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E501 -"""tvm.contrib.msc.plugin.codegen.sources""" - -from typing import Dict - - -def get_plugin_base_h_code() -> str: - """Create plugin base header file codes - - Returns - ------- - source: str - The plugin base header source. - """ - - return """#ifndef TVM_CONTRIB_MSC_UTILS_PLUGIN_BASE_H_ -#define TVM_CONTRIB_MSC_UTILS_PLUGIN_BASE_H_ - -#include -#include -#include -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { -namespace plugin { - -typedef enum { - kUINT8 = 0, - kINT8 = 1, - kINT16 = 2, - kINT32 = 3, - kINT64 = 4, - kFLOAT16 = 5, - kFLOAT32 = 6, - kFLOAT64 = 7, - kUNKNOWN = 8, -} MetaDataType; - -class MetaShape { - public: - MetaShape() { shape_.resize(0); } - - MetaShape(const std::vector& shape) { - for (auto d : shape) { - shape_.push_back(d); - } - } - - template - void SetShape(const std::vector& shape) { - for (auto d : shape) { - shape_.push_back(static_cast(d)); - } - } - - template - void SetDim(int index, T dim) { - int valid_index = index < 0 ? shape_.size() + index : index; - if (valid_index >= shape_.size()) { - std::string err = - std::to_string(index) + " out of dims size " + std::to_string(shape_.size()); - throw std::runtime_error(err); - } - shape_[valid_index] = dim; - } - - template - const std::vector GetShape() const { - std::vector shape; - for (auto d : shape_) { - shape.push_back(d); - } - return shape; - } - - inline int64_t DimAt(int index) const { - int valid_index = index < 0 ? shape_.size() + index : index; - if (valid_index >= shape_.size()) { - std::string err = - std::to_string(index) + " out of dims size " + std::to_string(shape_.size()); - throw std::runtime_error(err); - } - return shape_[valid_index]; - } - - inline size_t ndim() const { return shape_.size(); } - - inline const std::vector shape() const { return shape_; } - - inline size_t size() const { - size_t size = 1; - for (auto d : shape_) { - assert(d > 0 && "Can not compute static size with unknow dim"); - size *= d; - } - return size; - } - - inline int64_t operator[](int index) const { return DimAt(index); } - - friend std::ostream& operator<<(std::ostream& out, const MetaShape& shape) { - for (size_t i = 0; i < shape.ndim(); i++) { - out << shape.DimAt(i) << (1 < shape.ndim() ? "" : ","); - } - return out; - } - - private: - std::vector shape_; -}; - -class MetaLayoutAxis { - public: - MetaLayoutAxis(const char name, size_t factor = 0) : factor_(factor) { - name_ = (factor == 0 ? "" : std::to_string(factor)) + std::string(1, name); - } - - MetaLayoutAxis(const std::string& name) { - if (name.size() == 1) { - factor_ = 0; - name_ = name; - } else { - factor_ = std::stoi(name.substr(1)); - name_ = name.substr(0, 1); - } - } - - inline const std::string name() const { return name_; } - - inline size_t factor() const { return factor_; } - - private: - std::string name_; - size_t factor_; -}; - -class MetaLayout { - public: - MetaLayout() {} - - MetaLayout(const std::string& name) : name_(name) { - int factor = 0; - for (char c : name) { - if (c >= 'A' && c <= 'Z') { - assert(factor == 0 && "Upper layout axis do not accept factor"); - MetaLayoutAxis axis(c); - axes_.push_back(axis); - } else if (c >= 'a' && c <= 'z') { - assert(factor > 0 && "Lower layout axis should has factor"); - MetaLayoutAxis axis(c, factor); - axes_.push_back(axis); - factor = 0; - } else if (c >= '0' && c <= '9') { - assert(factor >= 0 && "Factor number should between 0 and 9"); - factor = factor * 10 + c - '0'; - } else { - throw std::runtime_error("Unexpected layout axis " + name); - } - } - CheckValid(); - } - - MetaLayout(const std::vector& axes) : axes_(axes) { - name_ = ""; - for (auto a : axes_) { - name_ += (a.factor() == 0 ? "" : std::to_string(a.factor())) + a.name(); - } - CheckValid(); - }; - - void CheckValid() { - std::set recorded_axes; - for (auto a : axes_) { - auto axis_name = a.name(); - assert(!recorded_axes.count(axis_name) && ("Has duplicate layout axis in " + name_).c_str()); - recorded_axes.insert(axis_name); - } - } - - inline const MetaLayoutAxis AxisAt(int index) const { - int valid_index = index < 0 ? axes_.size() + index : index; - if (valid_index >= axes_.size()) { - std::string err = std::to_string(index) + " out of axes size " + std::to_string(axes_.size()); - throw std::runtime_error(err); - } - return axes_[valid_index]; - } - - inline MetaLayoutAxis operator[](int index) { return AxisAt(index); } - - inline size_t ndim() const { return axes_.size(); } - - inline std::string name() const { return name_; } - - friend std::ostream& operator<<(std::ostream& out, const MetaLayout& layout) { - out << layout.name(); - return out; - } - - private: - std::string name_; - std::vector axes_; -}; - -class MetaTensor { - public: - MetaTensor() {} - - MetaTensor(const MetaShape& shape, const MetaDataType& data_type, - const MetaLayout& layout = MetaLayout()) - : shape_(shape), data_type_(data_type), layout_(layout) {} - - inline const MetaShape shape() const { return shape_; } - - inline MetaDataType data_type() const { return data_type_; } - - inline const std::vector meta_shape() const { return shape_.shape(); } - - inline const MetaLayout layout() const { return layout_; } - - inline const std::string layout_name() const { return layout_.name(); } - - inline size_t ndim() const { return shape_.ndim(); } - - inline size_t size(bool count_batch = true) const { - if (count_batch) { - size_t batch_dim = 0; - for (size_t i = 0; i < layout_.ndim(); i++) { - if (layout_.AxisAt(i).name() == "N") { - batch_dim = i; - } - } - return shape_.size() / shape_.shape()[batch_dim]; - } - return shape_.size(); - } - - inline MetaLayoutAxis AxisAt(int index) const { return layout_.AxisAt(index); } - - inline int AxisOf(const std::string& axis) const { - for (size_t i = 0; i < layout_.ndim(); i++) { - if (layout_.AxisAt(i).name() == axis) { - return i; - } - } - return -1; - } - - inline int64_t DimAt(int index) const { return shape_.DimAt(index); } - - inline int64_t DimAt(const std::string& axis) const { - int idx = AxisOf(axis); - if (idx >= 0) { - return shape_.DimAt(idx); - } - throw std::runtime_error("Can not find dim for " + axis); - } - - friend std::ostream& operator<<(std::ostream& out, const MetaTensor& tensor) { - out << "tensor : <" << tensor.shape() << ">, (" << tensor.layout() << ")"; - return out; - } - - private: - MetaShape shape_; - MetaDataType data_type_; - MetaLayout layout_; -}; - -template -class DataTensor : public MetaTensor { - public: - DataTensor(const MetaShape shape, const MetaDataType& data_type, const MetaLayout layout, T* data) - : MetaTensor(shape, data_type, layout) { - data_ = data; - } - - DataTensor(const MetaShape shape, const MetaDataType& data_type, const MetaLayout layout, - const T* data) - : MetaTensor(shape, data_type, layout) { - data_ = const_cast(data); - } - - T* data() const { return data_; } - - const T* const_data() const { return data_; } - - private: - T* data_{nullptr}; -}; - -} // namespace plugin -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_UTILS_PLUGIN_BASE_H_ -""" - - -def _get_common_utils() -> str: - """Get the utils for common - - Returns - ------- - source: str - The plugin utils for common. - """ - - return """class SerializeUtils { - public: - // Helper function for serializing plugin attrs - template - static const std::string ToString(const T& value) { - return std::to_string(value); - } - - static std::string ToString(const std::string& value) { return value; } - - template - static std::string ToString(const std::vector& value) { - std::string str = std::to_string(value.size()); - for (const auto& v : value) { - str += "," + std::to_string(v); - } - return str; - } - - static void FromString(const std::string& src, std::string& target) { target = src; } - - static void FromString(const std::string& src, bool& target) { - target = std::stoi(src) > 0 ? true : false; - } - - static void FromString(const std::string& src, int& target) { target = std::stoi(src); } - - static void FromString(const std::string& src, size_t& target) { target = std::stoi(src); } - - static void FromString(const std::string& src, long& target) { target = std::stol(src); } - - static void FromString(const std::string& src, float& target) { target = std::stod(src); } - - static void FromString(const std::string& src, double& target) { target = std::stof(src); } - - template - static void FromString(const std::string& src, std::vector& target) { - std::string left_str = src; - int pos = left_str.find(","); - if (pos == std::string::npos) { - return; - } - assert(pos > 0); - size_t src_size; - FromString(left_str.substr(0, pos), src_size); - target.resize(src_size); - for (size_t i = 0; i < src_size; i++) { - pos = left_str.find(","); - left_str = left_str.substr(pos + 1); - FromString(left_str, target[i]); - } - } - - static void FromString(const std::string& src, std::vector& target) { - std::vector values; - FromString(src, values); - target.resize(values.size()); - for (size_t i = 0; i < values.size(); i++) { - target[i] = values[i] > 0 ? true : false; - } - } -}; - -class DataUtils { - public: - static MetaDataType ToMetaType(const std::string& name) { - MetaDataType dtype; - if (name == "int8") { - dtype = MetaDataType::kINT8; - } else if (name == "uint8" || name == "char") { - dtype = MetaDataType::kUINT8; - } else if (name == "int16") { - dtype = MetaDataType::kINT16; - } else if (name == "int32" || name == "int") { - dtype = MetaDataType::kINT32; - } else if (name == "int64" || name == "long") { - dtype = MetaDataType::kINT64; - } else if (name == "float16" || name == "half") { - dtype = MetaDataType::kFLOAT16; - } else if (name == "float32" || name == "float") { - dtype = MetaDataType::kFLOAT32; - } else if (name == "float64" || name == "double") { - dtype = MetaDataType::kFLOAT64; - } else { - dtype = MetaDataType::kUNKNOWN; - } - return dtype; - } - - static bool IsListType(const std::string& dtype) { - int pos = dtype.find("list("); - return pos == 0; - } - - static const std::string GetEleType(const std::string& dtype) { - int pos = dtype.find("list("); - if (pos == 0) { - return dtype.substr(pos + 5, dtype.size() - 6); - } - return ""; - } -}; -""" - - -def _get_tvm_utils() -> str: - """Get the utils for tvm - - Returns - ------- - source: str - The plugin utils for tvm. - """ - - return """ -#ifdef PLUGIN_SUPPORT_TVM -using namespace tvm::relax; -using namespace tvm::runtime; -class TVMUtils { - public: - static void AttrFromPrim(const PrimValue& expr, std::string& target) { - ICHECK(expr->IsInstance()) << "Expr is not StringImm"; - target = Downcast(expr)->value; - } - - static void AttrFromPrim(const PrimValue& expr, bool& target) { - ICHECK(expr->value->IsInstance()) << "Expr value is not IntImm"; - target = Downcast(expr->value)->value; - } - - static void AttrFromPrim(const PrimValue& expr, int& target) { - ICHECK(expr->value->IsInstance()) << "Expr value is not IntImm"; - target = Downcast(expr->value)->value; - } - - static void AttrFromPrim(const PrimValue& expr, size_t& target) { - ICHECK(expr->value->IsInstance()) << "Expr value is not IntImm"; - target = Downcast(expr->value)->value; - } - - static void AttrFromPrim(const PrimValue& expr, long& target) { - ICHECK(expr->value->IsInstance()) << "Expr value is not IntImm"; - target = Downcast(expr->value)->value; - } - - static void AttrFromPrim(const PrimValue& expr, float& target) { - ICHECK(expr->value->IsInstance()) << "Expr value is not FloatImm"; - target = Downcast(expr->value)->value; - } - - static void AttrFromPrim(const PrimValue& expr, double& target) { - ICHECK(expr->value->IsInstance()) << "Expr value is not FloatImm"; - target = Downcast(expr->value)->value; - } - - template - static void AttrFromPrims(const Tuple& tuple, std::vector& target) { - for (size_t i = 0; i < tuple->fields.size(); i++) { - ICHECK(tuple->fields[i]->IsInstance()) << "Field is not PrimValue"; - AttrFromPrim(Downcast(tuple->fields[i]), target[i]); - } - } - - static void AttrFromArg(const ffi::AnyView& arg, std::string& target) { - target = arg.cast(); - } - - static void AttrFromArg(const ffi::AnyView& arg, bool& target) { target = arg; } - - static void AttrFromArg(const ffi::AnyView& arg, int& target) { target = arg; } - - static void AttrFromArg(const ffi::AnyView& arg, size_t& target) { target = int(arg); } - - static void AttrFromArg(const ffi::AnyView& arg, long& target) { target = int64_t(arg); } - - static void AttrFromArg(const ffi::AnyView& arg, float& target) { target = double(arg); } - - static void AttrFromArg(const ffi::AnyView& arg, double& target) { target = arg; } - - template - static void AttrFromArgs(const ffi::PackedArgs& args, size_t start, size_t num, std::vector& target) { - for (size_t i = 0; i < num; i++) { - AttrFromArg(args[start + i], target[i]); - } - } - - static MetaDataType ToMetaType(const DataType& dtype) { - MetaDataType meta_type; - if (dtype.code() == 0 && dtype.bits() == 8) { - meta_type = MetaDataType::kINT8; - } else if (dtype.code() == 0 && dtype.bits() == 16) { - meta_type = MetaDataType::kINT16; - } else if (dtype.code() == 0 && dtype.bits() == 32) { - meta_type = MetaDataType::kINT32; - } else if (dtype.code() == 0 && dtype.bits() == 64) { - meta_type = MetaDataType::kINT64; - } else if (dtype.code() == 1 && dtype.bits() == 8) { - meta_type = MetaDataType::kUINT8; - } else if (dtype.code() == 2 && dtype.bits() == 16) { - meta_type = MetaDataType::kFLOAT16; - } else if (dtype.code() == 2 && dtype.bits() == 32) { - meta_type = MetaDataType::kFLOAT32; - } else if (dtype.code() == 2 && dtype.bits() == 64) { - meta_type = MetaDataType::kFLOAT64; - } else { - meta_type = MetaDataType::kUNKNOWN; - } - return meta_type; - } - - static MetaDataType ToMetaType(const DLDataType& dtype) { - MetaDataType meta_type; - if (dtype.code == 0U && dtype.bits == 8) { - meta_type = MetaDataType::kINT8; - } else if (dtype.code == 0U && dtype.bits == 16) { - meta_type = MetaDataType::kINT16; - } else if (dtype.code == 0U && dtype.bits == 32) { - meta_type = MetaDataType::kINT32; - } else if (dtype.code == 0U && dtype.bits == 64) { - meta_type = MetaDataType::kINT64; - } else if (dtype.code == 1U && dtype.bits == 8) { - meta_type = MetaDataType::kUINT8; - } else if (dtype.code == 2U && dtype.bits == 16) { - meta_type = MetaDataType::kFLOAT16; - } else if (dtype.code == 2U && dtype.bits == 32) { - meta_type = MetaDataType::kFLOAT32; - } else if (dtype.code == 2U && dtype.bits == 64) { - meta_type = MetaDataType::kFLOAT64; - } else { - meta_type = MetaDataType::kUNKNOWN; - } - return meta_type; - } - - static MetaShape ToMetaShape(const Optional>& tvm_shape) { - if (tvm_shape.defined()) { - std::vector shape_data; - for (auto s : tvm_shape.value()) { - if (s->IsInstance()) { - shape_data.push_back(Downcast(s)->value); - } else { - shape_data.push_back(-1); - } - } - return MetaShape(shape_data); - } - return MetaShape(); - } - - static MetaShape ToMetaShape(DLTensor* tensor, bool as_data = true) { - std::vector dims; - if (as_data) { - assert(tensor->ndim == 1); - assert(TVMUtils::ToMetaType(tensor->dtype) == MetaDataType::kINT64); - int64_t* data_ptr = (int64_t*)tensor->data; - for (size_t i = 0; i < tensor->shape[0]; i++) { - dims.push_back(data_ptr[i]); - } - } else { - for (size_t i = 0; i < tensor->ndim; i++) { - dims.push_back(tensor->shape[i]); - } - } - return MetaShape(dims); - } - - static MetaTensor ToMetaTensor(const Expr& expr, - const LayoutDecision& layout_dec = LayoutDecision()) { - const auto* sinfo = GetStructInfoAs(expr); - if (layout_dec.defined() && layout_dec->layout.defined()) { - const auto& layout = MetaLayout(layout_dec->layout.name()); - return MetaTensor(ToMetaShape(sinfo->GetShape()), ToMetaType(sinfo->dtype), layout); - } - const auto& layout = MetaLayout(SpanUtils::GetAttr(expr->span, "layout")); - return MetaTensor(ToMetaShape(sinfo->GetShape()), ToMetaType(sinfo->dtype), layout); - } - - template - static DataTensor ToDataTensor(DLTensor* tensor, bool read_only) { - if (read_only) { - return DataTensor(ToMetaShape(tensor, false), ToMetaType(tensor->dtype), MetaLayout(), - (const T*)(tensor->data)); - } else { - return DataTensor(ToMetaShape(tensor, false), ToMetaType(tensor->dtype), MetaLayout(), - (T*)(tensor->data)); - } - } - - static DataType ToTVMType(const MetaDataType& dtype) { - DataType tvm_type; - if (dtype == MetaDataType::kINT8) { - tvm_type = DataType::Int(8); - } else if (dtype == MetaDataType::kINT16) { - tvm_type = DataType::Int(16); - } else if (dtype == MetaDataType::kINT32) { - tvm_type = DataType::Int(32); - } else if (dtype == MetaDataType::kINT64) { - tvm_type = DataType::Int(64); - } else if (dtype == MetaDataType::kFLOAT16) { - tvm_type = DataType::Float(16); - } else if (dtype == MetaDataType::kFLOAT32) { - tvm_type = DataType::Float(32); - } else if (dtype == MetaDataType::kFLOAT64) { - tvm_type = DataType::Float(64); - } else { - throw std::runtime_error("Unsupported type"); - } - return tvm_type; - } - - static DataType ToTVMType(const std::string& dtype) { - return ToTVMType(DataUtils::ToMetaType(dtype)); - } - - static Array ToTVMShape(const MetaShape& meta_shape) { - Array tvm_shape; - for (size_t i = 0; i < meta_shape.ndim(); i++) { - auto dim = meta_shape.DimAt(i); - tvm_shape.push_back(Integer(dim)); - } - return tvm_shape; - } - - static void FillDLShape(const MetaShape& shape, DLTensor* data) { - auto shape_data = static_cast(data->data); - for (size_t i = 0; i < shape.ndim(); i++) { - shape_data[i] = shape.DimAt(i); - } - } - - static TensorStructInfo ToTensorStructInfo(const MetaTensor& tensor, - const Optional& device) { - const auto& t_shape = ToTVMShape(tensor.shape()); - const auto& t_type = ToTVMType(tensor.data_type()); - return TensorStructInfo(ShapeExpr(t_shape), t_type, device); - } - - static TensorStructInfo ToTensorStructInfo(const MetaTensor& tensor, const Expr& expr) { - const auto* sinfo = GetStructInfoAs(expr); - return ToTensorStructInfo(tensor, sinfo->vdevice); - } - - static bool OnDevice(DLTensor* tensor, DLDeviceType device) { - return tensor->device.device_type == device; - } - - static void CheckDevice(DLTensor* tensor, DLDeviceType device) { - ICHECK_EQ(tensor->device.device_type, device); - } - - static Device DefaultCPU() { - Device cpu_dev{kDLCPU, 0}; - return cpu_dev; - } - - static Device DefaultCUDA() { - Device cuda_dev{kDLCUDA, 0}; - return cuda_dev; - } -}; - -#define TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF(FuncName, Body) \ - TVM_FFI_STATIC_INIT_BLOCK() { \ - tvm::ffi::reflection::GlobalDef().def(FuncName, Body); \ - } - -#define TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF_PACKED(FuncName, Body) \ - TVM_FFI_STATIC_INIT_BLOCK() { \ - tvm::ffi::reflection::GlobalDef().def_packed(FuncName, Body); \ - } - -#endif // PLUGIN_SUPPORT_TVM -""" - - -def _get_torch_utils() -> str: - """Get the utils for torch - - Returns - ------- - source: str - The plugin utils for torch. - """ - - return """ -#ifdef PLUGIN_SUPPORT_TORCH -class TorchUtils { - public: - static MetaDataType ToMetaType(const torch::ScalarType& dtype) { - MetaDataType meta_type; - if (dtype == torch::kChar) { - meta_type = MetaDataType::kINT8; - } else if (dtype == torch::kInt) { - meta_type = MetaDataType::kINT32; - } else if (dtype == torch::kInt64) { - meta_type = MetaDataType::kINT64; - } else if (dtype == torch::kLong) { - meta_type = MetaDataType::kINT64; - } else if (dtype == torch::kFloat16) { - meta_type = MetaDataType::kFLOAT16; - } else if (dtype == torch::kFloat) { - meta_type = MetaDataType::kFLOAT32; - } else if (dtype == torch::kDouble) { - meta_type = MetaDataType::kFLOAT64; - } else { - meta_type = MetaDataType::kUNKNOWN; - } - return meta_type; - } - - static MetaShape ToMetaShape(const torch::Tensor& tensor) { - std::vector shape_data; - for (size_t idx = 0; idx < tensor.dim(); idx++) { - shape_data.push_back(tensor.size(idx)); - } - return MetaShape(shape_data); - } - - static MetaTensor ToMetaTensor(const torch::Tensor& tensor, - const MetaLayout& layout = MetaLayout()) { - return MetaTensor(ToMetaShape(tensor), ToMetaType(tensor.scalar_type()), layout); - } - - template - static DataTensor ToDataTensor(const torch::Tensor& tensor, const MetaTensor& meta, - bool read_only) { - if (read_only) { - return DataTensor(meta.shape(), meta.data_type(), meta.layout(), - (const T*)(tensor.data_ptr())); - } else { - return DataTensor(meta.shape(), meta.data_type(), meta.layout(), (T*)(tensor.data_ptr())); - } - } - - static torch::ScalarType ToTorchType(const MetaDataType& dtype) { - torch::ScalarType torch_type; - if (dtype == MetaDataType::kINT8) { - torch_type = torch::kChar; - } else if (dtype == MetaDataType::kINT32) { - torch_type = torch::kInt; - } else if (dtype == MetaDataType::kINT64) { - torch_type = torch::kInt64; - } else if (dtype == MetaDataType::kFLOAT16) { - torch_type = torch::kFloat16; - } else if (dtype == MetaDataType::kFLOAT32) { - torch_type = torch::kFloat; - } else if (dtype == MetaDataType::kFLOAT64) { - torch_type = torch::kDouble; - } else { - throw std::runtime_error("Unsupported type"); - } - return torch_type; - } - - static torch::ScalarType ToTorchType(const std::string& dtype) { - return ToTorchType(DataUtils::ToMetaType(dtype)); - } - - static torch::Device ToTorchDevice(const std::string& device) { - if (device == "cpu") { - return torch::Device(torch::kCPU); - } - if (device == "cuda") { - return torch::Device(torch::kCUDA); - } - return torch::Device(torch::kCPU); - } - - static torch::Tensor MallocTorchTensor(const MetaTensor& tensor, const torch::Device& device) { - auto t_type = ToTorchType(tensor.data_type()); - auto opt = torch::TensorOptions().dtype(t_type).device(device); - return torch::zeros(tensor.meta_shape(), opt); - } -}; -#endif // PLUGIN_SUPPORT_TORCH -""" - - -def _get_tensorrt_utils() -> str: - """Get the utils for tensorrt - - Returns - ------- - source: str - The plugin utils for tensorrt. - """ - - return """ -#ifdef PLUGIN_SUPPORT_TENSORRT -using namespace nvinfer1; - -#ifndef TRT_VERSION_GE -#define TRT_VERSION_GE(major, minor, patch) \\ - ((TRT_MAJOR > major) || (TRT_MAJOR == major && TRT_MINOR > minor) || \\ - (TRT_MAJOR == major && TRT_MINOR == minor && TRT_PATCH >= patch)) -#endif - -class TRTUtils { - public: - template - static void ValToBuffer(char*& buffer, const T& val) { - *reinterpret_cast(buffer) = val; - buffer += sizeof(T); - } - - static void ValToBuffer(char*& buffer, const std::string& val) { - *reinterpret_cast(buffer) = val.size(); - buffer += sizeof(size_t); - val.copy(buffer, val.size()); - buffer += sizeof(char) * val.size(); - } - - template - static void ValToBuffer(char*& buffer, const std::vector& val) { - ValToBuffer(buffer, val.size()); - for (auto e : val) { - ValToBuffer(buffer, e); - } - } - - template - static void ValFromBuffer(const char*& buffer, T& val) { - val = *reinterpret_cast(buffer); - buffer += sizeof(T); - } - - static void ValFromBuffer(const char*& buffer, std::string& val) { - auto size = *reinterpret_cast(buffer); - buffer += sizeof(size_t); - val = std::string(reinterpret_cast(buffer), size); - buffer += sizeof(char) * size; - } - - template - static void ValFromBuffer(const char*& buffer, std::vector& val) { - size_t size; - ValFromBuffer(buffer, size); - val.resize(size); - for (size_t i = 0; i < size; i++) { - ValFromBuffer(buffer, val[i]); - } - } - - static PluginFieldType ToFieldType(const std::string& dtype) { - PluginFieldType field_type; - if (dtype == "char" || dtype == "uint8" || dtype == "string") { - field_type = PluginFieldType::kCHAR; - } else if (dtype == "int8") { - field_type = PluginFieldType::kINT8; - } else if (dtype == "int16") { - field_type = PluginFieldType::kINT16; - } else if (dtype == "int" || dtype == "int32") { - field_type = PluginFieldType::kINT32; - } else if (dtype == "float16" || dtype == "half") { - field_type = PluginFieldType::kFLOAT16; - } else if (dtype == "float32" || dtype == "float") { - field_type = PluginFieldType::kFLOAT32; - } else if (dtype == "float64" || dtype == "double") { - field_type = PluginFieldType::kFLOAT64; - } else { - field_type = PluginFieldType::kUNKNOWN; - } - return field_type; - } - - static const PluginField ToField(const std::string& name, const std::string& dtype) { - const auto& ele_type = DataUtils::GetEleType(dtype); - if (ele_type.size() == 0) { - return PluginField(name.c_str(), nullptr, ToFieldType(dtype), 1); - } - return PluginField(name.c_str(), nullptr, ToFieldType(ele_type), 11); - } - - static void FromField(const PluginField& field, std::string& val) { - assert(field.type == PluginFieldType::kCHAR); - const char* data = static_cast(field.data); - val = data; - } - - static void FromField(const PluginField& field, bool& val) { - assert(field.type == PluginFieldType::kINT32); - int int_val = *(static_cast(field.data)); - val = int_val == 0 ? false : true; - } - - static void FromField(const PluginField& field, int& val) { - assert(field.type == PluginFieldType::kINT32); - val = *(static_cast(field.data)); - } - - static void FromField(const PluginField& field, size_t& val) { - assert(field.type == PluginFieldType::kINT32); - val = *(static_cast(field.data)); - } - - static void FromField(const PluginField& field, long& val) { - assert(field.type == PluginFieldType::kINT32); - val = *(static_cast(field.data)); - } - - static void FromField(const PluginField& field, float& val) { - assert(field.type == PluginFieldType::kFLOAT32); - val = *(static_cast(field.data)); - } - - static void FromField(const PluginField& field, double& val) { - assert(field.type == PluginFieldType::kFLOAT64); - val = *(static_cast(field.data)); - } - - static MetaDataType ToMetaType(const DataType& dtype) { - MetaDataType meta_type; - if (dtype == DataType::kINT8) { - meta_type = MetaDataType::kINT8; - } else if (dtype == DataType::kINT32) { - meta_type = MetaDataType::kINT32; - } else if (dtype == DataType::kHALF) { - meta_type = MetaDataType::kFLOAT16; - } else if (dtype == DataType::kFLOAT) { - meta_type = MetaDataType::kFLOAT32; - } else { - meta_type = MetaDataType::kUNKNOWN; - } - return meta_type; - } - - static MetaShape ToMetaShape(const Dims& trt_dims, bool dynamic = false) { - std::vector dims; - if (!dynamic) { - dims.push_back(1); - } - for (size_t idx = 0; idx < trt_dims.nbDims; idx++) { - dims.push_back(trt_dims.d[idx]); - } - return MetaShape(dims); - } - - static MetaShape ToMetaShape(const DimsExprs& trt_dims) { - std::vector dims; - for (size_t idx = 0; idx < trt_dims.nbDims; idx++) { - assert(trt_dims.d[idx]->isConstant()); - dims.push_back(trt_dims.d[idx]->getConstantValue()); - } - return MetaShape(dims); - } - - static MetaShape ToMetaShape(const PluginTensorDesc& desc) { - return ToMetaShape(desc.dims, true); - } - - static MetaShape ToMetaShape(const DynamicPluginTensorDesc& desc) { - return ToMetaShape(desc.desc); - } - - static MetaTensor ToMetaTensor(const Dims& dims, const DataType& dtype, const std::string& layout, - bool dynamic = false) { - return MetaTensor(ToMetaShape(dims, dynamic), ToMetaType(dtype), MetaLayout(layout)); - } - - static MetaTensor ToMetaTensor(const DimsExprs& dims, const DataType& dtype, - const std::string& layout) { - return MetaTensor(ToMetaShape(dims), ToMetaType(dtype), MetaLayout(layout)); - } - - static MetaTensor ToMetaTensor(const PluginTensorDesc& desc, const std::string& layout) { - return ToMetaTensor(desc.dims, desc.type, layout, true); - } - - static MetaTensor ToMetaTensor(const DynamicPluginTensorDesc& desc, const std::string& layout) { - return ToMetaTensor(desc.desc, layout); - } - - static DataType ToDataType(const MetaDataType& dtype) { - DataType data_type; - if (dtype == MetaDataType::kINT8) { - data_type = DataType::kINT8; - } else if (dtype == MetaDataType::kINT32) { - data_type = DataType::kINT32; - } else if (dtype == MetaDataType::kFLOAT16) { - data_type = DataType::kHALF; - } else if (dtype == MetaDataType::kFLOAT32) { - data_type = DataType::kFLOAT; - } else { - data_type = DataType::kFLOAT; - } - return data_type; - } - - static DataType ToDataType(const std::string& dtype) { - return ToDataType(DataUtils::ToMetaType(dtype)); - } - - static Dims ToDims(const MetaShape& meta_shape, bool dynamic = false) { - std::vector int_dims; - if (dynamic) { - int_dims.push_back(meta_shape.DimAt(0)); - } - for (size_t i = 1; i < meta_shape.ndim(); i++) { - int_dims.push_back(meta_shape.DimAt(i)); - } - Dims dims{int(int_dims.size())}; - for (size_t i = 0; i < int_dims.size(); i++) { - dims.d[i] = int_dims[i]; - } - return dims; - } - - static DimsExprs ToDimsExprs(const MetaShape& meta_shape, IExprBuilder& builder) { - std::vector int_dims; - for (size_t i = 0; i < meta_shape.ndim(); i++) { - int_dims.push_back(meta_shape.DimAt(i)); - } - DimsExprs dims{int(int_dims.size())}; - for (size_t i = 0; i < int_dims.size(); i++) { - dims.d[i] = builder.constant(int_dims[i]); - } - return dims; - } - - static const MetaShape SetBatch(const MetaTensor& tensor, int batch_size) { - MetaShape shape = tensor.shape(); - int batch = tensor.AxisOf("N"); - if (batch < 0) { - batch = 0; - } - shape.SetDim(batch, batch_size); - return shape; - } - - template - static DataTensor ToDataTensor(const MetaTensor& tensor, int batch_size, const void* data) { - const auto& shape = SetBatch(tensor, batch_size); - return DataTensor(shape, tensor.data_type(), tensor.layout(), (const T*)(data)); - } - - template - static DataTensor ToDataTensor(const MetaTensor& tensor, int batch_size, void* data) { - const auto& shape = SetBatch(tensor, batch_size); - return DataTensor(shape, tensor.data_type(), tensor.layout(), (const T*)(data)); - } - - template - static DataTensor ToDataTensor(const MetaTensor& tensor, const PluginTensorDesc& desc, - const void* data) { - return DataTensor(ToMetaShape(desc), ToMetaType(desc.type), tensor.layout(), - (const T*)(data)); - } - - template - static DataTensor ToDataTensor(const MetaTensor& tensor, const PluginTensorDesc& desc, - void* data) { - return DataTensor(ToMetaShape(desc), ToMetaType(desc.type), tensor.layout(), (T*)(data)); - } -}; -#endif // PLUGIN_SUPPORT_TENSORRT -""" - - -def get_plugin_utils_h_code() -> str: - """Create plugin utils header file codes - - Returns - ------- - source: str - The plugin utils header source. - """ - - code = """#ifndef TVM_CONTRIB_MSC_UTILS_PLUGIN_UTILS_H_ -#define TVM_CONTRIB_MSC_UTILS_PLUGIN_UTILS_H_ - -#include -#include - -#include -#include -#include -#include -#include - -#include "plugin_base.h" - -#ifdef PLUGIN_ENABLE_CUDA -#include -#include -#endif // PLUGIN_ENABLE_CUDA - -#ifdef PLUGIN_SUPPORT_TVM -#include -#include - -#include "tvm/../../src/contrib/msc/core/transform/layout_utils.h" -#include "tvm/../../src/contrib/msc/core/utils.h" -#ifdef PLUGIN_ENABLE_CUDA -#include "tvm/../../src/runtime/cuda/cuda_common.h" -#endif // PLUGIN_ENABLE_CUDA -#endif // PLUGIN_SUPPORT_TVM - -#ifdef PLUGIN_SUPPORT_TORCH -#include -#include -#ifdef PLUGIN_ENABLE_CUDA -#include -#endif // PLUGIN_ENABLE_CUDA -#endif // PLUGIN_SUPPORT_TORCH - -#ifdef PLUGIN_SUPPORT_TENSORRT -#include "NvInfer.h" -#endif // PLUGIN_SUPPORT_TENSORRT - -namespace tvm { -namespace contrib { -namespace msc { -namespace plugin { - -""" - code += _get_common_utils() - code += _get_tvm_utils() - code += _get_torch_utils() - code += _get_tensorrt_utils() - code += """ -} // namespace plugin -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_UTILS_PLUGIN_UTILS_H_ -""" - return code - - -def get_plugin_sources() -> Dict[str, str]: - """Create base sources for plugin codegen - - Returns - ------- - sources: dict - The base utils sources. - """ - - return { - "plugin_base.h": get_plugin_base_h_code(), - "plugin_utils.h": get_plugin_utils_h_code(), - } diff --git a/python/tvm/contrib/msc/plugin/op/__init__.py b/python/tvm/contrib/msc/plugin/op/__init__.py deleted file mode 100644 index 6b306c8c1f5b..000000000000 --- a/python/tvm/contrib/msc/plugin/op/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.plugin.op""" diff --git a/python/tvm/contrib/msc/plugin/op/_ffi_api.py b/python/tvm/contrib/msc/plugin/op/_ffi_api.py deleted file mode 100644 index 8ca5071cdaf6..000000000000 --- a/python/tvm/contrib/msc/plugin/op/_ffi_api.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.plugin.op._ffi_api""" - -import tvm_ffi - -tvm_ffi.init_ffi_api("msc.plugin.op", __name__) diff --git a/python/tvm/contrib/msc/plugin/register.py b/python/tvm/contrib/msc/plugin/register.py deleted file mode 100644 index 782988c0f162..000000000000 --- a/python/tvm/contrib/msc/plugin/register.py +++ /dev/null @@ -1,85 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.plugin.register""" - -import os -from typing import Dict - -import tvm -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import utils as msc_utils - - -def register_plugin( - name: str, plugin: dict, externs_dir: msc_utils.MSCDirectory = None -) -> Dict[str, str]: - """Register a plugin - - Parameters - ---------- - name: str - The name of the plugin. - plugin: dict - The define of a plugin. - externs_dir: MSCDirectory - The extern sources folder. - - Returns - ------- - depend_files: dict - The depend file paths. - """ - - plugin = {"name": name, **msc_utils.load_dict(plugin)} - assert "externs" in plugin, "externs are needed to build plugin" - # check device compute - remove_externs = set() - for extern in plugin["externs"]: - if extern == "cuda_compute" and not tvm.cuda().exist: - remove_externs.add(extern) - if remove_externs: - plugin["externs"] = {k: v for k, v in plugin["externs"].items() if k not in remove_externs} - externs = plugin["externs"] - - def _check_file(info: dict, key: str) -> str: - if key not in info: - return None - file_path = info[key] - if os.path.abspath(file_path) != file_path: - assert externs_dir, "externs_dir is need to find file " + str(file_path) - file_path = externs_dir.relpath(file_path) - assert os.path.isfile(file_path), "Can not find externs file " + str(file_path) - info[key] = os.path.basename(file_path) - return file_path - - # find depend files - extern_sources, extern_libs = {}, {} - for info in externs.values(): - for key in ["header", "source"]: - file_path = _check_file(info, key) - if file_path: - extern_sources[os.path.basename(file_path)] = file_path - file_path = _check_file(info, "lib") - if file_path: - extern_libs[os.path.basename(file_path)] = file_path - _ffi_api.RegisterPlugin(name, msc_utils.dump_dict(plugin)) - # remove needless keys - for key in ["support_dtypes", "externs"]: - plugin.pop(key) - plugin["inputs"] = [{"name": i["name"]} for i in plugin["inputs"]] - plugin["outputs"] = [{"name": o["name"]} for o in plugin["outputs"]] - return extern_sources, extern_libs, plugin diff --git a/python/tvm/contrib/msc/plugin/utils.py b/python/tvm/contrib/msc/plugin/utils.py deleted file mode 100644 index c933d54b9d66..000000000000 --- a/python/tvm/contrib/msc/plugin/utils.py +++ /dev/null @@ -1,108 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""tvm.contrib.msc.plugin.utils""" - -import os -from typing import Any - -from tvm import relax, tir -from tvm.contrib.msc.core import utils as msc_utils - - -def to_expr(value: Any) -> relax.Expr: - """Change value to expr - - Parameters - ---------- - value: - The value with python type. - - Returns - ------- - expr: relax.Expr - The relax Expr. - """ - - if isinstance(value, (bool, int)): - value = tir.IntImm("int64", value) - expr = relax.PrimValue(value) - elif isinstance(value, float): - value = tir.FloatImm("float64", value) - expr = relax.PrimValue(value) - elif isinstance(value, str): - expr = relax.StringImm(value) - elif isinstance(value, (list, tuple)): - expr = relax.Tuple([to_expr(v) for v in value]) - else: - raise TypeError(f"Unsupported input type: {type(value)}") - return expr - - -def export_plugins(plugins: dict, folder: msc_utils.MSCDirectory) -> dict: - """Export the plugins - - Parameters - ---------- - plugins: dict - The plugins. - folder: MSCDirectory - The export folder. - - Returns - ------- - info: dict - The loadable plugins info. - """ - - if not plugins: - return {} - info = {} - for name, plugin in plugins.items(): - with folder.create_dir(name) as sub_folder: - info[name] = sub_folder.path - plugin.export(info[name]) - return info - - -def load_plugins(info: dict) -> dict: - """Load the plugins - - Parameters - ---------- - info: dict - The plugins info. - - Returns - ------- - plugins: dict - The plugins. - """ - - if not info: - return {} - plugins = {} - for name, plugin in info.items(): - if isinstance(plugin, str): - manager_file = os.path.join(plugin, "manager.py") - assert os.path.isfile(manager_file), "Can not find manager file for plugin: " + str( - manager_file - ) - manager_cls = msc_utils.load_callable(manager_file + ":PluginManager") - plugins[name] = manager_cls(plugin) - else: - plugins[name] = plugin - return plugins diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h deleted file mode 100644 index bb5b5be058c6..000000000000 --- a/src/contrib/msc/core/codegen/base_codegen.h +++ /dev/null @@ -1,263 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/codegen/base_codegen.h - * \brief Basic CodeGen for MSCGraph and MSCJoint. - */ -#ifndef TVM_CONTRIB_MSC_CORE_CODEGEN_BASE_CODEGEN_H_ -#define TVM_CONTRIB_MSC_CORE_CODEGEN_BASE_CODEGEN_H_ - -#include -#include - -#include -#include -#include -#include - -#include "../ir/graph.h" -#include "../ir/plugin.h" -#include "code_stack.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -/*! - * \brief CodeGen for MSCJoint op - */ -template -class BaseOpCode { - public: - /*! - * \brief The constructor of BaseOpCode - * \param func_name the function name for the node. - */ - explicit BaseOpCode(const ffi::String& func_name) : func_name_(func_name) {} - - virtual ~BaseOpCode() = default; - - /*! \brief Config the BaseOpCode*/ - void Config(const MSCJoint& node, const std::shared_ptr config, - const ffi::Map& prims) { - node_ = node; - config_ = config; - prims_ = prims; - } - - /*! \brief Get docs for the node*/ - virtual const ffi::Array GetDocs() = 0; - - /*! \brief Get return describe for default node*/ - virtual const ffi::String IdxNode() { return IdxNodeBase(node_); } - - /*! \brief Get describe for default node input*/ - const ffi::String IdxInput(int idx = 0, bool process = true) { - return IdxInputBase(node_, idx, process); - } - - /*! \brief Get describe for default node output*/ - const ffi::String IdxOutput(int idx = 0) { return IdxOutputBase(node_, idx); } - - /*! \brief Get describe for default node weight*/ - const ffi::String IdxWeight(const ffi::String& wtype, bool process = true) { - return IdxWeightBase(node_, wtype, process); - } - - /*! \brief Get the node attr as doc*/ - const ExprDoc GetAttrDoc(const ffi::String& key, const ffi::String& type) { - if (StringUtils::StartsWith(type, "list")) { - const ffi::String& ele_type = - StringUtils::Replace(StringUtils::Replace(type, "list(", ""), ")", ""); - if (ele_type == "bool") { - return DocUtils::ToList(node_->GetTypeArrayAttr(key)); - } else if (ele_type == "int" || ele_type == "int32") { - return DocUtils::ToList(node_->GetTypeArrayAttr(key)); - } else if (ele_type == "long" || ele_type == "int64") { - return DocUtils::ToList(node_->GetTypeArrayAttr(key)); - } else if (ele_type == "float" || ele_type == "float32") { - return DocUtils::ToList(node_->GetTypeArrayAttr(key)); - } else if (ele_type == "string") { - return DocUtils::ToStrList(node_->GetTypeArrayAttr(key)); - } - } else if (type == "bool") { - return DocUtils::ToDoc(node_->GetTypeAttr(key)); - } else if (type == "int" || type == "int32") { - return DocUtils::ToDoc(node_->GetTypeAttr(key)); - } else if (type == "long" || type == "int64") { - return DocUtils::ToDoc(node_->GetTypeAttr(key)); - } else if (type == "float" || type == "float32") { - return DocUtils::ToDoc(node_->GetTypeAttr(key)); - } else if (type == "string") { - return DocUtils::ToStr(node_->GetTypeAttr(key)); - } - return DocUtils::ToDoc(node_->GetTypeAttr(key)); - } - - /*! \brief Get comment for default node*/ - const ffi::String Comment() { return Comment(node_); } - - /*! \brief Get func_name for the default node*/ - const ffi::String func_name() { return func_name_; } - - /*! \brief Get valid func name for the default node*/ - virtual const ffi::String callee_name() { return func_name(); } - - /*! \brief Get valid return name for the default node*/ - virtual const ffi::String ret_name() { return IdxNode(); } - - /*! \brief Get the default node*/ - const MSCJoint node() { return node_; } - - CODEGEN_MEMBERS; - - private: - ffi::String func_name_; - MSCJoint node_; -}; - -/*! - * \brief CodeGen for MSCGraph - */ -template -class BaseCodeGen { - public: - /*! - * \brief The constructor of BaseCodeGen - * \param graph the graph to be generated. - * \param config the options for codegen. - */ - explicit BaseCodeGen(const MSCGraph& graph, const std::string& config = "") { - graph_ = graph; - config_.reset(new ConfigType()); - if (config.size() > 0) { - namespace json = ::tvm::ffi::json; - config_->Load(json::Parse(config).cast()); - } - while (!scopes_.empty()) { - scopes_.pop(); - } - } - - virtual void Init() { - // define prims - for (const auto& p_name : this->graph()->prim_names) { - prims_.Set(p_name, this->DescribePrim(this->graph()->FindPrim(p_name))); - } - } - - virtual ~BaseCodeGen() = default; - - /*! \brief Get sources*/ - virtual const ffi::Map GetSources( - const std::string& print_options = "") = 0; - - CODEGEN_MEMBERS; - - protected: - /*! - * \brief Compare node scope with current scope - * 0 for same scope, 1 for increase scope, -1 for decrease scope - */ - int CompareScope(const MSCJoint& node) { - if (node->scope.size() == 0) { - return 0; - } - if (scopes_.size() == 0) { - scopes_.push(node->scope); - return 1; - } - if (node->scope.size() == scopes_.top().size()) { - TVM_FFI_ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top())) - << "Scope mismatch, node " << node->scope << " compare to current " << scopes_.top(); - return 0; - } else if (node->scope.size() == scopes_.top().size() + 1) { - TVM_FFI_ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size())) - << "Scope increase mismatch, node " << node->scope << " compare to current " - << scopes_.top(); - scopes_.push(node->scope); - return 1; - } else if (node->scope.size() == scopes_.top().size() - 1) { - TVM_FFI_ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size())) - << "Scope decrease mismatch, node " << node->scope << " compare to current " - << scopes_.top(); - scopes_.pop(); - return -1; - } else { - TVM_FFI_THROW(InternalError) - << "Unexpected node scope " << node->scope << " with current scope " << scopes_.top(); - } - } - - /*! \brief Get the optype for op codegen*/ - const ffi::String GetOpType(const MSCJoint& node) { - if (config_->use_plugin && IsPlugin(node->optype)) { - return "plugin"; - } - return node->optype; - } - - /*! \brief Get the docs for the op*/ - virtual const ffi::Array GetOpCodes(const MSCJoint& node) = 0; - - /*! \brief Describe the prim*/ - virtual const ffi::String DescribePrim(const MSCPrim& prim) { - if (prim->optype == "Int") { - return prim->GetTypeAttr("value"); - } - if (prim->optype == "shape") { - const auto& producer = this->graph()->FindNode(prim->GetTypeAttr("producer")); - int out_idx = prim->GetTypeAttr("out_idx"); - const auto& dim = prim->GetTypeAttr("dim"); - return this->IdxOutputBase(producer, out_idx) + ".shape[" + dim + "]"; - } - // binary ops - DESCRIBE_PRIM_BINARY("Add", "+", false) - DESCRIBE_PRIM_BINARY("Sub", "-", false) - DESCRIBE_PRIM_BINARY("Mul", "*", false) - DESCRIBE_PRIM_BINARY("Divide", "/", false) - DESCRIBE_PRIM_BINARY("LT", "<", false) - DESCRIBE_PRIM_BINARY("LE", "<=", false) - DESCRIBE_PRIM_BINARY("GT", ">", false) - DESCRIBE_PRIM_BINARY("GE", ">=", false) - LOG_FATAL << "Unexpected prim " << prim; - } - - /*! \brief Get the graph*/ - const MSCGraph graph() const { return graph_; } - - /*! \brief Get the scopes*/ - const std::stack> scopes() const { return scopes_; } - - /*! \brief The stack of codes*/ - CodeStack stack_; - - private: - MSCGraph graph_; - std::stack> scopes_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_CODEGEN_BASE_CODEGEN_H_ diff --git a/src/contrib/msc/core/codegen/code_stack.cc b/src/contrib/msc/core/codegen/code_stack.cc deleted file mode 100644 index e31726f647af..000000000000 --- a/src/contrib/msc/core/codegen/code_stack.cc +++ /dev/null @@ -1,479 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/codegen/code_stack.cc - */ - -#include "code_stack.h" - -namespace tvm { -namespace contrib { -namespace msc { - -const ffi::Array BaseStack::GetDocs() const { - TVM_FFI_ICHECK(blocks_.size() == 1) << "Has incomplete blocks, please check"; - return TopBlock(); -} - -void BaseStack::Line(const Doc& doc) { PushDoc(doc); } - -void BaseStack::Line(const ffi::String& line) { Line(IdDoc(line)); } - -void BaseStack::Comment(const ffi::String& comment, bool attach) { - if (attach) { - const auto& doc = TopDoc(); - TVM_FFI_ICHECK(doc->IsInstance()) << "Only stmt doc support attach comments"; - const auto& stmt = Downcast(doc); - stmt->comment = comment; - } else { - PushDoc(CommentDoc(comment)); - } -} - -void BaseStack::Declare(const ffi::String& type, const ffi::String& variable, size_t len, - bool use_constructor) { - PushDoc(DocUtils::ToDeclare(type, variable, len, use_constructor)); -} - -void BaseStack::DeclareArgBase(const ExprDoc& value) { - const auto& declare = PopCheckedDoc(); - ffi::Array init_args = declare->init_args; - init_args.push_back(value); - PushDoc(DeclareDoc(declare->type, declare->variable, init_args, declare->use_constructor)); -} - -void BaseStack::FuncDef(const ffi::String& func_name, const ffi::String& ret_type) { - if (ret_type.size() > 0) { - PushDoc(FunctionDoc(IdDoc(func_name), ffi::Array(), ffi::Array(), - IdDoc(ret_type), ffi::Array())); - } else { - PushDoc(FunctionDoc(IdDoc(func_name), ffi::Array(), ffi::Array(), - std::nullopt, ffi::Array())); - } -} - -void BaseStack::FuncArg(const ffi::String& arg, const ffi::String& annotation, - const ffi::String& value) { - const auto& func = PopCheckedDoc(); - ffi::Array args = func->args; - args.push_back(DocUtils::ToAssign(arg, value, annotation)); - PushDoc(FunctionDoc(func->name, args, func->decorators, func->return_type, func->body)); -} - -void BaseStack::FuncDecorator(const ffi::String& decorator) { - const auto& func = PopCheckedDoc(); - ffi::Array decorators = func->decorators; - decorators.push_back(IdDoc(decorator)); - PushDoc(FunctionDoc(func->name, func->args, decorators, func->return_type, func->body)); -} - -void BaseStack::FuncStart() { - TVM_FFI_ICHECK(TopDoc()->IsInstance()) << "FunctionDoc is not saved"; - BlockStart(); -} - -void BaseStack::FuncEnd() { - const auto& block = PopBlock(); - const auto& func = PopCheckedDoc(); - const auto& body = DocUtils::ToStmts(block); - PushDoc(FunctionDoc(func->name, func->args, func->decorators, func->return_type, body)); -} - -void BaseStack::ClassDef(const ffi::String& class_name) { - PushDoc(ClassDoc(IdDoc(class_name), ffi::Array(), ffi::Array())); -} - -void BaseStack::ClassDecorator(const ffi::String& decorator) { - const auto& class_doc = PopCheckedDoc(); - ffi::Array decorators = class_doc->decorators; - decorators.push_back(IdDoc(decorator)); - PushDoc(ClassDoc(class_doc->name, decorators, class_doc->body)); -} - -void BaseStack::ClassStart() { - TVM_FFI_ICHECK(TopDoc()->IsInstance()) << "ClassDoc is not saved"; - BlockStart(); -} - -void BaseStack::ClassEnd() { - const auto& block = PopBlock(); - const auto& class_doc = PopCheckedDoc(); - const auto& body = DocUtils::ToStmts(block); - PushDoc(ClassDoc(class_doc->name, class_doc->decorators, body)); -} - -void BaseStack::StructStart(const ffi::String& struct_name) { - PushDoc(StructDoc(IdDoc(struct_name), ffi::Array(), ffi::Array())); - BlockStart(); -} - -void BaseStack::StructEnd() { - const auto& block = PopBlock(); - const auto& struct_doc = PopCheckedDoc(); - const auto& body = DocUtils::ToStmts(block); - PushDoc(StructDoc(struct_doc->name, struct_doc->decorators, body)); -} - -void BaseStack::ConstructorDef(const ffi::String& constructor_name) { - PushDoc(ConstructorDoc(IdDoc(constructor_name), ffi::Array(), ffi::Array())); -} - -void BaseStack::ConstructorArg(const ffi::String& arg, const ffi::String& annotation, - const ffi::String& value) { - const auto& func = PopCheckedDoc(); - ffi::Array args = func->args; - args.push_back(DocUtils::ToAssign(arg, value, annotation)); - PushDoc(ConstructorDoc(func->name, args, func->body)); -} - -void BaseStack::ConstructorStart() { - TVM_FFI_ICHECK(TopDoc()->IsInstance()) << "ConstructorDoc is not saved"; - BlockStart(); -} - -void BaseStack::ConstructorEnd() { - const auto& block = PopBlock(); - const auto& func = PopCheckedDoc(); - const auto& body = DocUtils::ToStmts(block); - PushDoc(ConstructorDoc(func->name, func->args, body)); -} - -void BaseStack::LambdaDef(const ffi::String& lambda_name) { - PushDoc(LambdaDoc(IdDoc(lambda_name), ffi::Array(), ffi::Array(), - ffi::Array())); -} - -void BaseStack::LambdaArg(const ffi::String& arg, const ffi::String& annotation, - const ffi::String& value) { - const auto& lambda = PopCheckedDoc(); - ffi::Array args = lambda->args; - args.push_back(DocUtils::ToAssign(arg, value, annotation)); - PushDoc(LambdaDoc(lambda->name, args, lambda->refs, lambda->body)); -} - -void BaseStack::LambdaRef(const ffi::String& ref) { - const auto& lambda = PopCheckedDoc(); - ffi::Array refs = lambda->refs; - refs.push_back(IdDoc(ref)); - PushDoc(LambdaDoc(lambda->name, lambda->args, refs, lambda->body)); -} - -void BaseStack::LambdaStart() { - TVM_FFI_ICHECK(TopDoc()->IsInstance()) << "LambdaDoc is not saved"; - BlockStart(); -} - -void BaseStack::LambdaEnd(const ffi::String& ret_val) { - if (ret_val.size() > 0) { - PushDoc(ReturnDoc(IdDoc(ret_val))); - } - const auto& block = PopBlock(); - const auto& lambda = PopCheckedDoc(); - const auto& body = DocUtils::ToStmts(block); - PushDoc(LambdaDoc(lambda->name, lambda->args, lambda->refs, body)); -} - -void BaseStack::LambdaEnd(const ExprDoc& ret_val) { - PushDoc(ReturnDoc(ret_val)); - LambdaEnd(""); -} - -void BaseStack::FuncCall(const ffi::String& callee, ffi::Optional assign_to, - ffi::Optional caller) { - if (!caller.defined()) { - PushDoc(CallDoc(IdDoc(callee), ffi::Array(), ffi::Array(), - ffi::Array())); - } else { - const auto& new_access = AttrAccessDoc(caller.value(), callee); - PushDoc(CallDoc(new_access, ffi::Array(), ffi::Array(), - ffi::Array())); - } - if (assign_to.defined()) { - const auto& last_call = PopCheckedDoc(); - if (assign_to.value()->IsInstance()) { - const auto& declare = Downcast(assign_to.value()); - PushDoc(AssignDoc(declare->variable, last_call, declare->type)); - } else { - const auto& declare = DocUtils::ToDeclare("", assign_to.value()); - PushDoc(AssignDoc(declare->variable, last_call, declare->type)); - } - } -} - -void BaseStack::FuncCall(const ffi::String& callee, const ffi::String& assign_to, - const ffi::String& caller) { - ffi::Optional assign_doc; - if (assign_to.size() == 0) { - assign_doc = std::nullopt; - } else { - assign_doc = IdDoc(assign_to); - } - ffi::Optional caller_doc; - if (caller.size() == 0) { - caller_doc = std::nullopt; - } else { - caller_doc = IdDoc(caller); - } - FuncCall(callee, assign_doc, caller_doc); -} - -void BaseStack::MethodCall(const ffi::String& callee, bool new_line) { - const auto& host = PopDoc(); - if (host->IsInstance()) { - const auto& v_callee = callee + (new_line ? DocSymbol::NextLine() : ""); - FuncCall(v_callee, std::nullopt, Downcast(host)); - } else if (const auto* a_node = host.as()) { - TVM_FFI_ICHECK(a_node->rhs.defined()) << "Can not find rhs for inplace host"; - FuncCall(callee, DeclareDoc(a_node->annotation, a_node->lhs, ffi::Array(), true), - a_node->rhs); - } else { - TVM_FFI_THROW(InternalError) << "Unexpected host type for inplace " << host->GetTypeKey(); - } -} - -void BaseStack::InplaceStart(const ffi::String& callee, ffi::Optional assign_to, - ffi::Optional caller) { - FuncCall(callee, assign_to, caller); -} - -void BaseStack::InplaceStart(const ffi::String& callee, const ffi::String& assign_to, - const ffi::String& caller) { - FuncCall(callee, assign_to, caller); -} - -void BaseStack::InplaceEnd() { - const auto& last = PopDoc(); - // get args and kwargs - if (last->IsInstance()) { - CallArgBase(Downcast(last)); - } else if (const auto* assign = last.as()) { - const auto& call = Downcast(assign->rhs); - TVM_FFI_ICHECK(assign->lhs->IsInstance()) - << "assign lhs should be IdDoc, get " << assign->lhs->GetTypeKey(); - const auto& key = Downcast(assign->lhs)->name; - CallArgBase(call, key); - } else { - TVM_FFI_THROW(InternalError) << "Unexpected last type for call arg " << last->GetTypeKey(); - } -} - -void BaseStack::PopNest(const ffi::String& key) { - const auto& last = PopDoc(); - if (last->IsInstance()) { - CallArgBase(Downcast(last), key); - } else { - TVM_FFI_THROW(InternalError) << "Unexpected nest type " << last->GetTypeKey(); - } -} - -void BaseStack::CallArgBase(const ExprDoc& value, const ffi::String& key) { - const auto& last = PopDoc(); - ffi::Array args; - ffi::Array kwargs_keys; - ffi::Array kwargs_values; - // get args and kwargs - if (const auto* call = last.as()) { - args = call->args; - kwargs_keys = call->kwargs_keys; - kwargs_values = call->kwargs_values; - } else if (const auto* assign = last.as()) { - const auto& call = Downcast(assign->rhs); - args = call->args; - kwargs_keys = call->kwargs_keys; - kwargs_values = call->kwargs_values; - } else { - TVM_FFI_THROW(InternalError) << "Unexpected last type for call arg " << last->GetTypeKey(); - } - // push args or kwargs - if (key.size() == 0) { - TVM_FFI_ICHECK(kwargs_keys.size() == 0) << "kwargs followed by args " << value; - args.push_back(value); - } else { - kwargs_keys.push_back(key); - kwargs_values.push_back(value); - } - // push doc - if (const auto* call = last.as()) { - PushDoc(CallDoc(call->callee, args, kwargs_keys, kwargs_values)); - } else if (const auto* assign = last.as()) { - const auto& call = Downcast(assign->rhs); - const auto& new_call = CallDoc(call->callee, args, kwargs_keys, kwargs_values); - PushDoc(AssignDoc(assign->lhs, new_call, assign->annotation)); - } else { - TVM_FFI_THROW(InternalError) << "Unexpected last type for call arg " << last->GetTypeKey(); - } -} - -void BaseStack::ConditionIf(const ffi::String& predicate) { - ffi::Array else_branch{ExprStmtDoc(IdDoc("pass"))}; - PushDoc(IfDoc(IdDoc(predicate), ffi::Array(), else_branch)); - BlockStart(); -} - -void BaseStack::ConditionElse() { - const auto& block = PopBlock(); - const auto& if_doc = PopCheckedDoc(); - PushDoc(IfDoc(if_doc->predicate, DocUtils::ToStmts(block), ffi::Array())); - BlockStart(); -} - -void BaseStack::ConditionEnd() { - const auto& block = PopBlock(); - const auto& if_doc = PopCheckedDoc(); - const auto& branch = DocUtils::ToStmts(block); - if (if_doc->then_branch.size() == 0) { - PushDoc(IfDoc(if_doc->predicate, branch, ffi::Array())); - } else { - PushDoc(IfDoc(if_doc->predicate, if_doc->then_branch, branch)); - } -} - -void BaseStack::ForEnd() { - const auto& block = PopBlock(); - const auto& for_doc = PopCheckedDoc(); - const auto& body = DocUtils::ToStmts(block); - PushDoc(ForDoc(for_doc->lhs, for_doc->rhs, body)); -} - -void BaseStack::WhileStart(const ffi::String& predicate) { - PushDoc(WhileDoc(IdDoc(predicate), ffi::Array())); - BlockStart(); -} - -void BaseStack::WhileEnd() { - const auto& block = PopBlock(); - const auto& while_doc = PopCheckedDoc(); - const auto& body = DocUtils::ToStmts(block); - PushDoc(WhileDoc(while_doc->predicate, body)); -} - -void BaseStack::SwitchStart(const ffi::String& predicate) { - ffi::Array predicates; - predicates.push_back(IdDoc(predicate)); - PushDoc(SwitchDoc(predicates, ffi::Array>(), ffi::Array())); - BlockStart(); -} - -void BaseStack::SwitchCase(const ffi::String& predicate) { - const auto& block = PopBlock(); - const auto& switch_doc = PopCheckedDoc(); - auto branchs = switch_doc->branchs; - branchs.push_back(DocUtils::ToStmts(block)); - if (predicate.size() == 0) { - ffi::Array default_branch{ExprStmtDoc(IdDoc("pass"))}; - PushDoc(SwitchDoc(switch_doc->predicates, branchs, default_branch)); - } else { - auto predicates = switch_doc->predicates; - predicates.push_back(IdDoc(predicate)); - PushDoc(SwitchDoc(predicates, branchs, switch_doc->default_branch)); - } - BlockStart(); -} - -void BaseStack::SwitchEnd() { - const auto& block = PopBlock(); - const auto& switch_doc = PopCheckedDoc(); - if (switch_doc->default_branch.size() > 0) { - PushDoc(SwitchDoc(switch_doc->predicates, switch_doc->branchs, DocUtils::ToStmts(block))); - } else { - auto branchs = switch_doc->branchs; - branchs.push_back(DocUtils::ToStmts(block)); - PushDoc(SwitchDoc(switch_doc->predicates, branchs, switch_doc->default_branch)); - } -} - -void BaseStack::BlockStart() { - ffi::Array block; - blocks_.push(block); -} - -void BaseStack::BlockEnd(bool block_docs) { - const auto& docs = PopBlock(); - if (block_docs) { - PushDoc(DocUtils::ToStmtBlock(docs)); - } else { - for (const auto& d : docs) { - PushDoc(d); - } - } -} - -void BaseStack::ScopeStart(const ffi::String& scope_def, const ffi::String& scope_ref) { - if (scope_ref.size() > 0) { - PushDoc(ScopeDoc(IdDoc(scope_ref), IdDoc(scope_def), ffi::Array())); - } else { - PushDoc(ScopeDoc(std::nullopt, IdDoc(scope_def), ffi::Array())); - } - BlockStart(); -} - -void BaseStack::ScopeEnd() { - const auto& block = PopBlock(); - const auto& scope = PopCheckedDoc(); - PushDoc(ScopeDoc(scope->lhs, scope->rhs, DocUtils::ToStmts(block))); -} - -bool BaseStack::HasBlock() const { return blocks_.size() > 0; } - -const ffi::Array BaseStack::TopBlock() const { - TVM_FFI_ICHECK(HasBlock()) << "No block found"; - return blocks_.top(); -} - -const ffi::Array BaseStack::PopBlock() { - const auto& block = TopBlock(); - blocks_.pop(); - return block; -} - -bool BaseStack::HasDoc() { - if (!HasBlock()) { - return false; - } - return TopBlock().size() > 0; -} - -const Doc BaseStack::TopDoc() { - TVM_FFI_ICHECK(HasDoc()) << "No doc or block found"; - return TopBlock().back(); -} - -const Doc BaseStack::PopDoc() { - const auto& doc = TopDoc(); - blocks_.top().pop_back(); - return doc; -} - -template -const TDoc BaseStack::PopCheckedDoc() { - TVM_FFI_ICHECK(HasDoc() && TopDoc()->IsInstance()) - << "Last doc(" << TopDoc()->GetTypeKey() << ") is not expected type " - << TDocNode::TypeIndex2Key(TDocNode::RuntimeTypeIndex()); - return Downcast(PopDoc()); -} - -void BaseStack::PushDoc(const Doc& doc) { - TVM_FFI_ICHECK(HasBlock()) << "No block found"; - blocks_.top().push_back(doc); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/codegen/code_stack.h b/src/contrib/msc/core/codegen/code_stack.h deleted file mode 100644 index d588c3cf4f31..000000000000 --- a/src/contrib/msc/core/codegen/code_stack.h +++ /dev/null @@ -1,652 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/codegen/code_stack.h - * \brief CodeStack for doc printer. - */ -#ifndef TVM_CONTRIB_MSC_CORE_CODEGEN_CODE_STACK_H_ -#define TVM_CONTRIB_MSC_CORE_CODEGEN_CODE_STACK_H_ - -#include - -#include -#include -#include - -#include "../printer/msc_doc.h" -#include "../printer/print_utils.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -/*! - * \brief Inner class for doc stack - */ -class BaseStack { - public: - /*! - * \brief The constructor of CodeStack - */ - BaseStack() { Reset(); } - - /*! \brief Cleanup blocks*/ - void Reset() { - while (!blocks_.empty()) { - blocks_.pop(); - } - BlockStart(); - } - - /*! \brief Get the docs*/ - const ffi::Array GetDocs() const; - - protected: - /*! \brief Push Id Doc*/ - void Line(const Doc& doc); - void Line(const ffi::String& line = ""); - - /*! \brief Push Comment Doc*/ - void Comment(const ffi::String& comment, bool attach = false); - - /*! \brief Push Assign Doc*/ - template - inline void Assign(const LT& lhs, const RT& rhs, const ffi::String& annotation = "") { - PushDoc(DocUtils::ToAssign(lhs, rhs, annotation)); - } - - /*! \brief Push declare Doc*/ - void Declare(const ffi::String& type, const ffi::String& variable, size_t len = 0, - bool use_constructor = true); - - /*! \brief Cache declare argument*/ - void DeclareArgBase(const ExprDoc& value); - - /*! \brief Cache declare typed argument*/ - template - inline void DeclareArg(const T& value) { - DeclareArgBase(DocUtils::ToDoc(value)); - } - - /*! \brief Cache class Doc*/ - void ClassDef(const ffi::String& class_name); - - /*! \brief Cache class decorator*/ - void ClassDecorator(const ffi::String& decorator); - - /*! \brief Start class body block*/ - void ClassStart(); - - /*! \brief End class body block*/ - void ClassEnd(); - - /*! \brief Start struct body block*/ - void StructStart(const ffi::String& struct_name); - - /*! \brief End struct body block*/ - void StructEnd(); - - /*! \brief Cache function Doc*/ - void FuncDef(const ffi::String& func_name, const ffi::String& ret_type = ""); - - /*! \brief Cache function argument*/ - void FuncArg(const ffi::String& arg, const ffi::String& annotation = "", - const ffi::String& value = ""); - - /*! \brief Cache function decorator*/ - void FuncDecorator(const ffi::String& decorator); - - /*! \brief Start function body block*/ - void FuncStart(); - - /*! \brief End function body block*/ - void FuncEnd(); - - template - void FuncEnd(const T& ret_val) { - PushDoc(ReturnDoc(DocUtils::ToDoc(ret_val))); - FuncEnd(); - } - - /*! \brief Cache constructor Doc*/ - void ConstructorDef(const ffi::String& constructor_name); - - /*! \brief Cache constructor argument*/ - void ConstructorArg(const ffi::String& arg, const ffi::String& annotation = "", - const ffi::String& value = ""); - - /*! \brief Start constructor body block*/ - void ConstructorStart(); - - /*! \brief End constructor body block*/ - void ConstructorEnd(); - - /*! \brief Cache lambda Doc*/ - void LambdaDef(const ffi::String& lambda_name); - - /*! \brief Cache lambda argument*/ - void LambdaArg(const ffi::String& arg, const ffi::String& annotation = "", - const ffi::String& value = ""); - - /*! \brief Cache lambda reference*/ - void LambdaRef(const ffi::String& ref); - - /*! \brief Start lambda body block*/ - void LambdaStart(); - - /*! \brief End lambda body block*/ - void LambdaEnd(const ffi::String& ret_val = ""); - void LambdaEnd(const ExprDoc& ret_val); - - /*! \brief Push call and maybe assign Doc*/ - void FuncCall(const ffi::String& callee, ffi::Optional assign_to, - ffi::Optional caller = std::nullopt); - void FuncCall(const ffi::String& callee, const ffi::String& assign_to = "", - const ffi::String& caller = ""); - - /*! \brief Push method call Doc*/ - void MethodCall(const ffi::String& callee, bool new_line = false); - - /*! \brief Push inplace call and maybe assign Doc*/ - void InplaceStart(const ffi::String& callee, ffi::Optional assign_to, - ffi::Optional caller = std::nullopt); - void InplaceStart(const ffi::String& callee, const ffi::String& assign_to = "", - const ffi::String& caller = ""); - - /*! \brief End inplace call*/ - void InplaceEnd(); - - /*! \brief Push nested expr to last Doc*/ - void PopNest(const ffi::String& key = ""); - - /*! \brief Cache call typed argument*/ - void CallArgBase(const ExprDoc& value, const ffi::String& key = ""); - - /*! \brief Cache call normal argument*/ - template - inline void CallArg(T value, const ffi::String& key = "") { - const auto& doc_value = DocUtils::ToDoc(value); - if (doc_value.defined()) { - CallArgBase(doc_value, key); - } - } - inline void CallArg(const ffi::Array& values) { - for (const auto& v : values) { - if (v.defined()) { - CallArgBase(v); - } - } - } - - /*! \brief Push if to cache and start if block*/ - void ConditionIf(const ffi::String& predicate); - - /*! \brief Push then branch to cached and start block*/ - void ConditionElse(); - - /*! \brief Push else branch to cached*/ - void ConditionEnd(); - - /*! \brief Push for to cache and start for block*/ - template - void ForStart(const LT& lhs, const RT& rhs) { - PushDoc(ForDoc(DocUtils::ToDoc(lhs), DocUtils::ToDoc(rhs), ffi::Array())); - BlockStart(); - } - - /*! \brief Push for range to cache and start for block*/ - template - void ForStart(const ffi::String& lhs, const ST& start, const ET& end) { - ffi::Array range{DocUtils::ToDoc(start), DocUtils::ToDoc(end)}; - PushDoc(ForDoc(IdDoc(lhs), TupleDoc(range), ffi::Array())); - BlockStart(); - } - - /*! \brief End a for block*/ - void ForEnd(); - - /*! \brief Push while to cache and start while block*/ - void WhileStart(const ffi::String& predicate); - - /*! \brief End a while block*/ - void WhileEnd(); - - /*! \brief Push switch to cache and start switch block*/ - void SwitchStart(const ffi::String& predicate); - - /*! \brief Add new case to switch*/ - void SwitchCase(const ffi::String& predicate = ""); - - /*! \brief Push switch to cached*/ - void SwitchEnd(); - - /*! \brief Start a new block*/ - void BlockStart(); - - /*! \brief End a block*/ - void BlockEnd(bool block_docs = true); - - /*! \brief Start a new scope*/ - void ScopeStart(const ffi::String& scope_def = "", const ffi::String& scope_ref = ""); - - /*! \brief End a scope*/ - void ScopeEnd(); - - private: - /*! \brief Check if has block left*/ - bool HasBlock() const; - - /*! \brief Get the last the block*/ - const ffi::Array TopBlock() const; - - /*! \brief Pop last the block*/ - const ffi::Array PopBlock(); - - /*! \brief Check if doc left*/ - bool HasDoc(); - - /*! \brief Get the last doc*/ - const Doc TopDoc(); - - /*! \brief Pop last doc*/ - const Doc PopDoc(); - - /*! \brief Pop last doc with type checked*/ - template - const TDoc PopCheckedDoc(); - - /*! \brief Push doc*/ - void PushDoc(const Doc& doc); - - /*! \brief The blocks, each has docs array*/ - std::stack> blocks_; -}; - -#define COMMON_WRAPPERS(Stack) \ - Stack& line(const Doc& doc) { \ - Line(doc); \ - return *this; \ - } \ - Stack& line(const ffi::String& line = "") { \ - Line(line); \ - return *this; \ - } \ - Stack& comment(const ffi::String& comment, bool attach = false) { \ - Comment(comment, attach); \ - return *this; \ - } \ - template \ - Stack& assign(const LT& lhs, const RT& rhs, const ffi::String& annotation = "") { \ - Assign(lhs, rhs, annotation); \ - return *this; \ - } \ - Stack& declare(const ffi::String& type, const ffi::String& variable, size_t len = 0, \ - bool use_constructor = true) { \ - Declare(type, variable, len, use_constructor); \ - return *this; \ - } \ - template \ - Stack& declare_arg(const T& value) { \ - DeclareArg(value); \ - return *this; \ - } \ - Stack& class_def(const ffi::String& class_name) { \ - ClassDef(class_name); \ - return *this; \ - } \ - Stack& class_decorator(const ffi::String& decorator) { \ - ClassDecorator(decorator); \ - return *this; \ - } \ - Stack& class_start() { \ - ClassStart(); \ - return *this; \ - } \ - Stack& class_end() { \ - ClassEnd(); \ - return *this; \ - } \ - Stack& struct_start(const ffi::String& struct_name) { \ - StructStart(struct_name); \ - return *this; \ - } \ - Stack& struct_end() { \ - StructEnd(); \ - return *this; \ - } \ - Stack& func_def(const ffi::String& func_name, const ffi::String& ret_type = "") { \ - FuncDef(func_name, ret_type); \ - return *this; \ - } \ - Stack& func_arg(const ffi::String& arg, const ffi::String& annotation = "", \ - const ffi::String& value = "") { \ - FuncArg(arg, annotation, value); \ - return *this; \ - } \ - Stack& func_decorator(const ffi::String& decorator) { \ - FuncDecorator(decorator); \ - return *this; \ - } \ - Stack& func_start() { \ - FuncStart(); \ - return *this; \ - } \ - Stack& func_end() { \ - FuncEnd(); \ - return *this; \ - } \ - template \ - Stack& func_end(const T& ret_val) { \ - FuncEnd(ret_val); \ - return *this; \ - } \ - Stack& func_call(const ffi::String& callee, ffi::Optional assign_to, \ - ffi::Optional caller = std::nullopt) { \ - FuncCall(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& func_call(const ffi::String& callee, const ffi::String& assign_to = "", \ - const ffi::String& caller = "") { \ - FuncCall(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& method_call(const ffi::String& callee, bool new_line = false) { \ - MethodCall(callee, new_line); \ - return *this; \ - } \ - Stack& inplace_start(const ffi::String& callee, ffi::Optional assign_to, \ - ffi::Optional caller = std::nullopt) { \ - InplaceStart(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& inplace_start(const ffi::String& callee, const ffi::String& assign_to = "", \ - const ffi::String& caller = "") { \ - InplaceStart(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& inplace_end() { \ - InplaceEnd(); \ - return *this; \ - } \ - Stack& constructor_def(const ffi::String& func_name) { \ - ConstructorDef(func_name); \ - return *this; \ - } \ - Stack& constructor_arg(const ffi::String& arg, const ffi::String& annotation = "", \ - const ffi::String& value = "") { \ - ConstructorArg(arg, annotation, value); \ - return *this; \ - } \ - Stack& constructor_start() { \ - ConstructorStart(); \ - return *this; \ - } \ - Stack& constructor_end() { \ - ConstructorEnd(); \ - return *this; \ - } \ - Stack& lambda_def(const ffi::String& lambda_name) { \ - LambdaDef(lambda_name); \ - return *this; \ - } \ - Stack& lambda_arg(const ffi::String& arg, const ffi::String& annotation = "", \ - const ffi::String& value = "") { \ - LambdaArg(arg, annotation, value); \ - return *this; \ - } \ - Stack& lambda_ref(const ffi::String& ref) { \ - LambdaRef(ref); \ - return *this; \ - } \ - Stack& lambda_start() { \ - LambdaStart(); \ - return *this; \ - } \ - Stack& lambda_end(const ffi::String& ret_val = "") { \ - LambdaEnd(ret_val); \ - return *this; \ - } \ - Stack& lambda_end(const ExprDoc& ret_val) { \ - LambdaEnd(ret_val); \ - return *this; \ - } \ - Stack& pop_nest(const ffi::String& key = "") { \ - PopNest(key); \ - return *this; \ - } \ - template \ - Stack& call_arg(T value, const ffi::String& key = "") { \ - CallArg(value, key); \ - return *this; \ - } \ - Stack& call_arg(const ExprDoc& value, const ffi::String& key = "") { \ - CallArg(value, key); \ - return *this; \ - } \ - Stack& call_arg(const ffi::Array& values) { \ - CallArg(values); \ - return *this; \ - } \ - Stack& cond_if(const ffi::String& predicate) { \ - ConditionIf(predicate); \ - return *this; \ - } \ - Stack& cond_else() { \ - ConditionElse(); \ - return *this; \ - } \ - Stack& cond_end() { \ - ConditionEnd(); \ - return *this; \ - } \ - template \ - Stack& for_start(const LT& lhs, const RT& rhs) { \ - ForStart(lhs, rhs); \ - return *this; \ - } \ - template \ - Stack& for_start(const ffi::String& lhs, const ST& start, const ET& end) { \ - ForStart(lhs, start, end); \ - return *this; \ - } \ - Stack& for_start(const ffi::String& lhs, const ffi::String& start, const ffi::String& end) { \ - ForStart(lhs, start, end); \ - return *this; \ - } \ - Stack& for_end() { \ - ForEnd(); \ - return *this; \ - } \ - Stack& while_start(const ffi::String& predicate) { \ - WhileStart(predicate); \ - return *this; \ - } \ - Stack& while_end() { \ - WhileEnd(); \ - return *this; \ - } \ - Stack& switch_start(const ffi::String& predicate) { \ - SwitchStart(predicate); \ - return *this; \ - } \ - Stack& switch_case(const ffi::String& predicate = "") { \ - SwitchCase(predicate); \ - return *this; \ - } \ - Stack& switch_end() { \ - SwitchEnd(); \ - return *this; \ - } \ - Stack& block_start() { \ - BlockStart(); \ - return *this; \ - } \ - Stack& block_end(bool block_docs = true) { \ - BlockEnd(block_docs); \ - return *this; \ - } \ - Stack& scope_start(const ffi::String& scope_def = "", const ffi::String& scope_ref = "") { \ - ScopeStart(scope_def, scope_ref); \ - return *this; \ - } \ - Stack& scope_end() { \ - ScopeEnd(); \ - return *this; \ - } - -/*! - * \brief Stack Doc for common codegen - */ -class CodeStack : public BaseStack { - public: - /*! - * \brief The constructor of CodeStack - */ - CodeStack() : BaseStack() {} - - COMMON_WRAPPERS(CodeStack) -}; - -/*! - * \brief Stack Doc for codes - */ -template -class OpCodeStack : public BaseStack { - public: - /*! - * \brief The constructor of OpCodeStack - */ - OpCodeStack() : BaseStack() {} - - /*! \brief Set codegen*/ - void Config(OpCodeGenType* codegen, bool reset = true) { - codegen_ = codegen; - if (reset) { - Reset(); - } - } - - COMMON_WRAPPERS(OpCodeStack) - - /*! \brief Push op_call Doc*/ - OpCodeStack& op_call(const ffi::String& callee = "msc::auto", - const ffi::String& assign_to = "msc::auto") { - const ffi::String& v_callee = callee == "msc::auto" ? codegen_->callee_name() : callee; - const ffi::String& v_assign = assign_to == "msc::auto" ? codegen_->ret_name() : assign_to; - return func_call(v_callee, v_assign); - } - - /*! \brief Push op comment Doc*/ - OpCodeStack& op_comment(const ffi::String& comment_str = "msc::auto") { - const ffi::String& v_comment = (comment_str == "msc::auto" ? codegen_->Comment() : comment_str); - return comment(v_comment); - } - - /*! \brief Cache typed attribute as argument*/ - template - OpCodeStack& op_arg(const ffi::String& attr_key, - const ffi::String& key = "msc::auto") { - T attr_val; - if (codegen_->node()->GetAttr(attr_key, &attr_val)) { - const ffi::String& valid_key = key == "msc::auto" ? attr_key : key; - return call_arg(attr_val, valid_key); - } - return *this; - } - - /*! \brief Cache str attribute as argument*/ - OpCodeStack& op_str_arg(const ffi::String& attr_key, - const ffi::String& key = "msc::auto") { - std::string attr_val; - if (codegen_->node()->GetAttr(attr_key, &attr_val)) { - const ffi::String& valid_key = key == "msc::auto" ? attr_key : key; - return call_arg(DocUtils::ToStr(attr_val), valid_key); - } - return *this; - } - - /*! \brief Cache list attribute as argument*/ - template - OpCodeStack& op_list_arg(const ffi::String& attr_key, - const ffi::String& key = "msc::auto", - bool allow_empty = false) { - std::vector attr_val; - if (codegen_->node()->GetAttr(attr_key, &attr_val)) { - const ffi::String& valid_key = key == "msc::auto" ? attr_key : key; - return call_arg(DocUtils::ToList(attr_val, allow_empty), valid_key); - } - return *this; - } - - /*! \brief Cache input as argument*/ - OpCodeStack& op_input_arg(int idx = 0, const ffi::String& key = "") { - return call_arg(codegen_->IdxInput(idx, true), key); - } - - /*! \brief Cache inputs as argument*/ - OpCodeStack& op_inputs_arg(bool as_list = true, const ffi::String& key = "") { - ffi::Array inputs; - for (size_t i = 0; i < codegen_->node()->inputs.size(); i++) { - inputs.push_back(codegen_->IdxInput(i, true)); - } - if (as_list) { - return call_arg(DocUtils::ToList(inputs), key); - } else { - return call_arg(DocUtils::ToDocList(inputs)); - } - } - - /*! \brief Cache output as argument*/ - OpCodeStack& op_output_arg(int idx = 0, const ffi::String& key = "") { - return call_arg(codegen_->IdxOutput(idx), key); - } - - /*! \brief Cache weight as argument*/ - OpCodeStack& op_weight_arg(const ffi::String& wtype, const ffi::String& key = "") { - if (codegen_->node()->weights.count(wtype)) { - return call_arg(codegen_->IdxWeight(wtype, true), key); - } - return *this; - } - - /*! \brief Cache name as argument*/ - OpCodeStack& op_name_arg(const ffi::String& key = "msc::auto", - const ffi::String& name = "msc::auto") { - const ffi::String& valid_key = key == "msc::auto" ? "name" : key; - const ffi::String& valid_name = name == "msc::auto" ? codegen_->node()->name : name; - return call_arg(DocUtils::ToStr(valid_name), valid_key); - return *this; - } - - OpCodeStack& op_dtype_arg(const DataType& dtype, const ffi::String& key = "") { - return call_arg(codegen_->DType(dtype), key); - } - - private: - OpCodeGenType* codegen_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_CODEGEN_CODE_STACK_H_ diff --git a/src/contrib/msc/core/codegen/codegen_json.cc b/src/contrib/msc/core/codegen/codegen_json.cc deleted file mode 100644 index 73854dd24039..000000000000 --- a/src/contrib/msc/core/codegen/codegen_json.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/codegen/codegen_json.cc - */ - -#include "codegen_json.h" - -#include - -namespace tvm { -namespace contrib { -namespace msc { - -std::vector MSCJSONSerializer::VisitExpr_(const CallNode* call_node) { - const auto& ref_node = graph_->FindNode(SpanUtils::GetAttr(call_node->span, "name")); - std::vector inputs; - for (const auto& arg : call_node->args) { - auto res = VisitExpr(arg); - inputs.insert(inputs.end(), res.begin(), res.end()); - } - auto node = - std::make_shared(ref_node->name, "kernel", inputs, ref_node->outputs.size()); - // add attributes - AddNodeAttr(node, "optype", ref_node->optype); - for (const auto& pair : ref_node->attrs) { - AddNodeAttr(node, pair.first, pair.second); - } - if (!global_options_set_) { - AddNodeAttr(node, "msc_global_options_num", ffi::String(std::to_string(options_.size()))); - for (const auto& pair : options_) { - AddNodeAttr(node, "msc_global_" + pair.first, pair.second); - } - global_options_set_ = true; - } - return AddNode(node, ffi::GetRef(call_node)); -} - -void MSCJSONSerializer::AddNodeAttr(JSONGraphObjectPtr node, const ffi::String& key, - const ffi::String& value) { - node->SetAttr(std::string(key), value); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/codegen/codegen_json.h b/src/contrib/msc/core/codegen/codegen_json.h deleted file mode 100644 index d64717449fda..000000000000 --- a/src/contrib/msc/core/codegen/codegen_json.h +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/codegen/codegen_json.h - * \brief Basic JSONSerializer for MSC runnable BYOC. - */ -#ifndef TVM_CONTRIB_MSC_CORE_CODEGEN_CODEGEN_JSON_H_ -#define TVM_CONTRIB_MSC_CORE_CODEGEN_CODEGEN_JSON_H_ - -#include - -#include -#include -#include - -#include "../../../../relax/backend/contrib/codegen_json/codegen_json.h" -#include "../ir/graph.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::relax; - -using JSONGraphNode = tvm::runtime::json::JSONGraphNode; -using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; -using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr; -using JSONSerializer = backend::contrib::JSONSerializer; - -/*! - * \brief MSCCompileConfig defines config for all BYOC - */ -struct MSCCompileConfig { - std::string graph_json; - std::unordered_map options; - void Load(ffi::json::Object obj) { - if (auto it = obj.find(ffi::String("graph_json")); it != obj.end()) { - graph_json = std::string((*it).second.cast()); - } - for (const auto& kv : obj) { - std::string k = std::string(kv.first.cast()); - if (k != "graph_json") { - options[k] = std::string(kv.second.cast()); - } - } - } -}; - -class MSCJSONSerializer : public JSONSerializer { - public: - /*! - * \brief Constructor - * \param constant_names The names of all constants in the original module. - */ - explicit MSCJSONSerializer(const ffi::Map& constant_names, - const std::string& options) - : JSONSerializer(constant_names) { - namespace json = ::tvm::ffi::json; - MSCCompileConfig config; - config.Load(json::Parse(options).cast()); - TVM_FFI_ICHECK(config.graph_json.size() > 0) << "graph_json is needed to init MSCGraph"; - graph_ = MSCGraph(config.graph_json); - for (const auto& pair : config.options) { - options_.Set(pair.first, pair.second); - } - global_options_set_ = false; - } - - std::vector VisitExpr_(const CallNode* call_node) final; - - const ffi::String GetOption(const ffi::String& key) { - TVM_FFI_ICHECK(options_.count(key)) << "Can not find option " << key; - return options_[key]; - } - - const ffi::Map GetOptions() { return options_; } - - protected: - void AddNodeAttr(JSONGraphObjectPtr node, const ffi::String& key, const ffi::String& value); - - private: - MSCGraph graph_; - ffi::Map options_; - bool global_options_set_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_CODEGEN_CODEGEN_JSON_H_ diff --git a/src/contrib/msc/core/codegen/codegen_utils.cc b/src/contrib/msc/core/codegen/codegen_utils.cc deleted file mode 100644 index 768c9f276e9e..000000000000 --- a/src/contrib/msc/core/codegen/codegen_utils.cc +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/codegen/codegen_utils.cc - */ - -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -const ffi::String CodeGenUtils::IdxNode(const MSCJoint& node, const ffi::String& prefix, - const ffi::String& suffix) { - return prefix + std::to_string(node->index) + suffix; -} - -const ffi::String CodeGenUtils::IdxOutput(const MSCJoint& node, const ffi::String& prefix, int idx, - const ffi::String& suffix) { - const auto& idx_node = IdxNode(node, prefix, suffix); - size_t output_size = node->outputs.size(); - if (output_size == 1 && node->optype != "tuple") { - return idx_node; - } - size_t v_index = CommonUtils::GetIndex(idx, output_size); - return idx_node + "[" + std::to_string(v_index) + "]"; -} - -const ffi::String CodeGenUtils::IdxInput(const MSCJoint& node, const ffi::String& prefix, int idx, - const ffi::String& suffix) { - const auto& pair = node->ProducerAndIdxOf(idx); - return IdxOutput(pair.first, prefix, pair.second, suffix); -} - -const ffi::String CodeGenUtils::IdxWeight(const MSCJoint& node, const ffi::String& wtype, - const ffi::String& suffix) { - return wtype + "_" + std::to_string(node->index) + suffix; -} - -const ffi::Array CodeGenUtils::GetPrims( - const MSCTensor& tensor, const ffi::Map& prims) { - ffi::Array dims; - if (tensor->prims.size() == 0) { - for (size_t i = 0; i < tensor->Ndim(); i++) { - dims.push_back(StringUtils::ToString(tensor->DimAt(i))); - } - return dims; - } - for (size_t i = 0; i < tensor->Ndim(); i++) { - const auto& prim = tensor->PrimAt(i); - dims.push_back(prims.count(prim) ? prims[prim] : prim); - } - return dims; -} - -const ffi::String CodeGenUtils::CommentNode(const MSCJoint& node, const ffi::String& prefix, - const ffi::Map& prims) { - ffi::String comment = node->name + "(" + node->optype + "): <"; - for (size_t i = 0; i < node->inputs.size(); i++) { - comment = comment + IdxInput(node, prefix, i) + (i == node->inputs.size() - 1 ? "> -> <" : ","); - } - for (size_t i = 0; i < node->outputs.size(); i++) { - const auto& t_output = node->OutputAt(i); - const auto& t_prims = GetPrims(t_output, prims); - comment = comment + IdxOutput(node, prefix, i) + "|" + StringUtils::Join(t_prims, ":"); - comment = comment + "|" + t_output->DTypeName(); - if (t_output->layout.defined()) { - comment = comment + "|" + t_output->layout->name; - } - comment = comment + (i == node->outputs.size() - 1 ? ">" : ", "); - } - return comment; -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/codegen/codegen_utils.h b/src/contrib/msc/core/codegen/codegen_utils.h deleted file mode 100644 index d6f5b5c37594..000000000000 --- a/src/contrib/msc/core/codegen/codegen_utils.h +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/codegen/codegen_utils.h - * \brief Common utilities for print. - */ -#ifndef TVM_CONTRIB_MSC_CORE_CODEGEN_CODEGEN_UTILS_H_ -#define TVM_CONTRIB_MSC_CORE_CODEGEN_CODEGEN_UTILS_H_ - -#include -#include - -#include -#include -#include - -#include "../ir/graph.h" -#include "../utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -#define CODEGEN_CONFIG_MEMBERS \ - bool training{false}; \ - bool use_tools{false}; \ - bool use_plugin{false}; \ - bool need_test{true}; \ - std::string tools_scope{""}; \ - std::string tools_tag{"main"}; \ - std::string test_device{"cpu"}; \ - std::string prefix{"res_"}; \ - std::string baseline_folder{"baseline"}; \ - std::vector version{0, 0, 0}; - -#define CODEGEN_CONFIG_PARSE \ - namespace json = ::tvm::ffi::json; \ - if (auto it = obj.find(ffi::String("training")); it != obj.end()) { \ - training = (*it).second.cast(); \ - } \ - if (auto it = obj.find(ffi::String("use_tools")); it != obj.end()) { \ - use_tools = (*it).second.cast(); \ - } \ - if (auto it = obj.find(ffi::String("use_plugin")); it != obj.end()) { \ - use_plugin = (*it).second.cast(); \ - } \ - if (auto it = obj.find(ffi::String("need_test")); it != obj.end()) { \ - need_test = (*it).second.cast(); \ - } \ - if (auto it = obj.find(ffi::String("tools_scope")); it != obj.end()) { \ - tools_scope = std::string((*it).second.cast()); \ - } \ - if (auto it = obj.find(ffi::String("tools_tag")); it != obj.end()) { \ - tools_tag = std::string((*it).second.cast()); \ - } \ - if (auto it = obj.find(ffi::String("test_device")); it != obj.end()) { \ - test_device = std::string((*it).second.cast()); \ - } \ - if (auto it = obj.find(ffi::String("prefix")); it != obj.end()) { \ - prefix = std::string((*it).second.cast()); \ - } \ - if (auto it = obj.find(ffi::String("version")); it != obj.end()) { \ - auto arr = (*it).second.cast(); \ - version.clear(); \ - version.reserve(arr.size()); \ - for (const auto& elem : arr) { \ - version.push_back(static_cast(elem.cast())); \ - } \ - } \ - if (auto it = obj.find(ffi::String("baseline_folder")); it != obj.end()) { \ - baseline_folder = std::string((*it).second.cast()); \ - } - -#define DESCRIBE_PRIM_BINARY(OpType, Symbol, AsFunc) \ - if (prim->optype == OpType) { \ - if (AsFunc) { \ - return std::string(Symbol) + "(" + this->DescribePrim(prim->ParentAt(0)) + "," + \ - this->DescribePrim(prim->ParentAt(1)) + ")"; \ - } \ - return "(" + this->DescribePrim(prim->ParentAt(0)) + Symbol + \ - this->DescribePrim(prim->ParentAt(1)) + ")"; \ - } - -#define CODEGEN_MEMBERS \ - public: \ - virtual const ffi::String DType(const DataType& dtype) { \ - return runtime::DLDataTypeToString(dtype); \ - } \ - \ - protected: \ - const std::shared_ptr config() { return config_; } \ - const ffi::Map prims() { return prims_; } \ - const ffi::String IdxNodeBase(const MSCJoint& node) { \ - return helper_.IdxNodeBase(node, config()->prefix, ""); \ - } \ - const ffi::String IdxInputBase(const MSCJoint& node, int idx = 0, bool process = true) { \ - return helper_.IdxInputBase(node, config()->prefix, idx, "", process && config()->use_tools); \ - } \ - const ffi::String IdxOutputBase(const MSCJoint& node, int idx = 0, bool mark_exit = false) { \ - return helper_.IdxOutputBase(node, config()->prefix, idx, "", \ - mark_exit && config()->use_tools); \ - } \ - const ffi::String IdxWeightBase(const MSCJoint& node, const ffi::String& wtype, \ - bool process = true) { \ - return helper_.IdxWeightBase(node, wtype, "", process && config()->use_tools); \ - } \ - const ffi::Array GetPrims(const MSCTensor& tensor) { \ - return CodeGenUtils::GetPrims(tensor, prims_); \ - } \ - const ffi::String Comment(const MSCJoint& node) { \ - return helper_.Comment(node, config()->prefix, prims_); \ - } \ - int CompareVersion(size_t major, size_t minor, size_t patch) { \ - return CommonUtils::CompareVersion(config()->version, {major, minor, patch}); \ - } \ - \ - private: \ - std::shared_ptr config_; \ - ffi::Map prims_; \ - HelperType helper_; - -/*! - * \brief Utils for CodeGen. - */ -class CodeGenUtils { - public: - /*! - * \brief Get indexed node string. - * \return The String. - */ - TVM_DLL static const ffi::String IdxNode(const MSCJoint& node, const ffi::String& prefix, - const ffi::String& suffix = ""); - - /*! - * \brief Get indexed output string. - * \return The String. - */ - TVM_DLL static const ffi::String IdxOutput(const MSCJoint& node, const ffi::String& prefix, - int idx = 0, const ffi::String& suffix = ""); - - /*! - * \brief Get indexed input string. - * \return The String. - */ - TVM_DLL static const ffi::String IdxInput(const MSCJoint& node, const ffi::String& prefix, - int idx = 0, const ffi::String& suffix = ""); - - /*! - * \brief Get indexed weight string. - * \return The String. - */ - TVM_DLL static const ffi::String IdxWeight(const MSCJoint& node, const ffi::String& wtype, - const ffi::String& suffix = ""); - - /*! - * \brief Infer prims of tensor. - * \return The prims. - */ - TVM_DLL static const ffi::Array GetPrims( - const MSCTensor& tensor, const ffi::Map& prims); - /*! - * \brief Get comment of a node. - * \return The String. - */ - TVM_DLL static const ffi::String CommentNode(const MSCJoint& node, const ffi::String& prefix, - const ffi::Map& prims); -}; - -/*! - * \brief Basic CodeGenHelper - */ -class BaseCodeGenHelper { - public: - const ffi::String GetSuffix(const MSCJoint& node, bool process = false) { - return process ? "c" + std::to_string(node->index) : ""; - } - - virtual const ffi::String IdxNodeBase(const MSCJoint& node, const ffi::String& prefix = "", - const ffi::String& suffix = "") { - return CodeGenUtils::IdxNode(node, prefix, suffix); - } - virtual const ffi::String IdxInputBase(const MSCJoint& node, const ffi::String& prefix = "", - int idx = 0, const ffi::String& suffix = "", - bool process = false) { - const auto& pair = node->ProducerAndIdxOf(idx); - size_t output_size = pair.first->outputs.size(); - if (process && (output_size > 1 || pair.first->optype == "tuple")) { - return CodeGenUtils::IdxNode(pair.first, prefix, suffix) + "_" + std::to_string(pair.second); - } - return CodeGenUtils::IdxInput(node, prefix, idx, suffix + GetSuffix(node, process)); - } - virtual const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", - int idx = 0, const ffi::String& suffix = "", - bool mark_exit = false) { - if (mark_exit) { - if (node->outputs.size() > 1 || node->optype == "tuple") { - return CodeGenUtils::IdxNode(node, prefix, suffix) + "_" + std::to_string(idx) + "_exit"; - } - return CodeGenUtils::IdxOutput(node, prefix, idx, suffix + "_exit"); - } - return CodeGenUtils::IdxOutput(node, prefix, idx, suffix); - } - virtual const ffi::String IdxWeightBase(const MSCJoint& node, const ffi::String& wtype, - const ffi::String& suffix = "", bool process = false) { - return CodeGenUtils::IdxWeight(node, wtype, suffix + GetSuffix(node, process)); - } - virtual const ffi::String Comment( - const MSCJoint& node, const ffi::String& prefix = "", - const ffi::Map& prims = ffi::Map()) { - return CodeGenUtils::CommentNode(node, prefix, prims); - } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_CODEGEN_CODEGEN_UTILS_H_ diff --git a/src/contrib/msc/core/codegen/cpp_codegen.h b/src/contrib/msc/core/codegen/cpp_codegen.h deleted file mode 100644 index ee9fb490c8b3..000000000000 --- a/src/contrib/msc/core/codegen/cpp_codegen.h +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/codegen/cpp_codegen.h - * \brief CPP codegen for MSCGraph. - */ -#ifndef TVM_CONTRIB_MSC_CORE_CODEGEN_CPP_CODEGEN_H_ -#define TVM_CONTRIB_MSC_CORE_CODEGEN_CPP_CODEGEN_H_ - -#include - -#include -#include - -#include "../printer/cpp_printer.h" -#include "base_codegen.h" -#include "code_stack.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -template -class CppCodeGen : public BaseCodeGen { - public: - /*! - * \brief The constructor of PyCodeGen - * \param graph the graph to be generated. - * \param config the options for codegen. - */ - explicit CppCodeGen(const MSCGraph& graph, const std::string& config = "") - : BaseCodeGen(graph, config) { - for (const auto& output : graph->GetOutputs()) { - graph_outputs_.insert(output); - } - } - - /*! \brief Stack the docs for the class declare*/ - virtual void CodeGenClassDeclare() = 0; - - /*! \brief Stack the docs for the class define*/ - virtual void CodeGenClassDefine() = 0; - - /*! \brief Stack the docs for the main func*/ - virtual void CodeGenMain() = 0; - - /*! \brief Stack the docs for the class define*/ - virtual void CodeGenCmake() = 0; - - /*! \brief Get sources*/ - virtual const ffi::Map GetSources( - const std::string& print_options = "") { - ffi::Map sources; - auto add_source = [&print_options, &sources, this](const ffi::String& file) { - CppPrinter printer(print_options); - for (const auto& d : this->stack_.GetDocs()) { - printer.Append(d); - } - sources.Set(file, printer.GetString()); - this->stack_.Reset(); - }; - // class declare - CodeGenClassDeclare(); - add_source(this->graph()->name + ".h"); - // class define - CodeGenClassDefine(); - add_source(this->graph()->name + ".cc"); - // main func - CodeGenMain(); - add_source("main.cc"); - // cmakelists - CodeGenCmake(); - add_source("CMakeLists.txt"); - return sources; - } - - protected: - /*! \brief Describe the prim*/ - virtual const ffi::String DescribePrim(const MSCPrim& prim) { - // binary ops - DESCRIBE_PRIM_BINARY("Min", "std::min", true) - DESCRIBE_PRIM_BINARY("Max", "std::max", true) - // special - if (prim->optype == "if_then_else") { - return "(" + this->DescribePrim(prim->ParentAt(0)) + "?" + - this->DescribePrim(prim->ParentAt(1)) + ":" + this->DescribePrim(prim->ParentAt(2)) + - ")"; - } - return BaseCodeGen::DescribePrim(prim); - } - - /*! \brief Stack the docs for the node*/ - virtual void CodeGenNode(const MSCJoint& node, bool use_tools) { - this->stack_.comment(this->Comment(node)); - // process inputs and weights by tools - if (use_tools) { - const auto pf = tvm::ffi::Function::GetGlobalRequired("msc_tool.codegen_tensor"); - for (size_t i = 0; i < node->inputs.size(); i++) { - const auto& input = node->InputAt(i); - ffi::Any lines = pf(GetTensorCtx(input), input->name, node->name, - this->config()->tools_scope, this->config()->tools_tag); - for (const auto& l : lines.cast>()) { - this->stack_.line(l); - } - } - for (const auto& pair : node->weights) { - ffi::Any lines = pf(GetTensorCtx(pair.second), pair.second->name, node->name, - this->config()->tools_scope, this->config()->tools_tag); - for (const auto& l : lines.cast>()) { - this->stack_.line(l); - } - } - } - for (const auto& d : this->GetOpCodes(node)) { - this->stack_.line(d); - } - // process graph outputs by tools - if (use_tools) { - const auto pf = tvm::ffi::Function::GetGlobalRequired("msc_tool.codegen_tensor"); - for (size_t i = 0; i < node->outputs.size(); i++) { - int index = static_cast(i); - if (graph_outputs_.count(node->OutputAt(index))) { - const auto& output = node->OutputAt(index); - ffi::Any lines = pf(GetTensorCtx(output), output->name, node->name, - this->config()->tools_scope, this->config()->tools_tag); - for (const auto& l : lines.cast>()) { - this->stack_.line(l); - } - } - } - } - } - - /*! \brief Get the tensor context for codegen_tensor*/ - virtual const ffi::Map GetTensorCtx(const MSCTensor& tensor) { - ffi::Map tensor_ctx; - MSCJoint producer; - if (this->graph()->weight_holders.count(tensor->name)) { - producer = this->graph()->FindProducer(tensor); - for (const auto& pair : producer->weights) { - if (pair.second == tensor) { - tensor_ctx.Set("tensor", this->IdxWeightBase(producer, pair.first)); - break; - } - } - TVM_FFI_ICHECK(tensor_ctx.count("tensor")) - << "Can not find weight " << tensor << " from " << producer; - } else { - const auto& pair = this->graph()->FindProducerAndIdx(tensor); - producer = pair.first; - tensor_ctx.Set("tensor", this->IdxOutputBase(pair.first, pair.second)); - } - tensor_ctx.Set("producer", this->IdxNodeBase(producer)); - return tensor_ctx; - } - - /*! \brief Get the step context for codegen_step*/ - virtual const ffi::Map GetStepCtx() { - ffi::Map step_ctx; - std::string version = ""; - for (size_t i = 0; i < this->config()->version.size(); i++) { - version += std::to_string(this->config()->version[i]) + - (i < this->config()->version.size() - 1 ? "." : ""); - } - step_ctx.Set("version", version); - return step_ctx; - } - - void StartNamespace() { - this->stack_.line("namespace tvm {").line("namespace contrib {").line("namespace msc {").line(); - } - - void EndNamespace() { - this->stack_.line() - .line("} // namespace tvm") - .line("} // namespace contrib") - .line("} // namespace msc") - .line(); - } - - private: - std::set graph_outputs_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_CODEGEN_CPP_CODEGEN_H_ diff --git a/src/contrib/msc/core/codegen/py_codegen.h b/src/contrib/msc/core/codegen/py_codegen.h deleted file mode 100644 index e51adda8e460..000000000000 --- a/src/contrib/msc/core/codegen/py_codegen.h +++ /dev/null @@ -1,228 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/codegen/py_codegen.h - * \brief Python codegen for MSCGraph. - */ -#ifndef TVM_CONTRIB_MSC_CORE_CODEGEN_PY_CODEGEN_H_ -#define TVM_CONTRIB_MSC_CORE_CODEGEN_PY_CODEGEN_H_ - -#include - -#include -#include - -#include "../printer/python_printer.h" -#include "base_codegen.h" -#include "code_stack.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -template -class PyCodeGen : public BaseCodeGen { - public: - /*! - * \brief The constructor of PyCodeGen - * \param graph the graph to be generated. - * \param config the options for codegen. - */ - explicit PyCodeGen(const MSCGraph& graph, const std::string& config = "") - : BaseCodeGen(graph, config) { - for (const auto& output : graph->GetOutputs()) { - graph_outputs_.insert(output); - } - } - - /*! \brief Stack the docs for the script*/ - virtual void CodeGenScript() { - CodeGenHeader(); - this->stack_.line().comment("Define the helpers"); - CodeGenHelper(); - this->stack_.line().comment("Define the graph"); - CodeGenGraph(); - if (this->config()->need_test) { - this->stack_.line().comment("Define the test"); - CodeGenTest(); - } - } - - /*! \brief Get sources*/ - virtual const ffi::Map GetSources( - const std::string& print_options = "") { - ffi::Map sources; - PythonPrinter printer(print_options); - CodeGenScript(); - for (const auto& d : this->stack_.GetDocs()) { - printer.Append(d); - } - sources.Set(this->graph()->name + ".py", printer.GetString()); - return sources; - } - - protected: - /*! \brief Describe the prim*/ - virtual const ffi::String DescribePrim(const MSCPrim& prim) { - // binary ops - DESCRIBE_PRIM_BINARY("Min", "min", true) - DESCRIBE_PRIM_BINARY("Max", "max", true) - // special - if (prim->optype == "if_then_else") { - return "(" + this->DescribePrim(prim->ParentAt(1)) + " if " + - this->DescribePrim(prim->ParentAt(0)) + " else " + - this->DescribePrim(prim->ParentAt(2)) + ")"; - } - return BaseCodeGen::DescribePrim(prim); - } - - /*! \brief Stack the docs for the header*/ - virtual void CodeGenHeader() { - this->stack_.line("import os") - .line("import numpy as np") - .line("from typing import List, Dict, Any") - .line("import tvm"); - if (this->config()->use_tools) { - this->stack_.line("from tvm.contrib.msc.core import tools as msc_tools"); - } - this->stack_.line("from tvm.contrib.msc.core import utils as msc_utils"); - } - - /*! \brief Stack the docs for the helpers*/ - virtual void CodeGenHelper() { - if (this->config()->need_test) { - this->stack_.func_def("load_data", "np.ndarray") - .func_arg("name", "str") - .func_arg("shape", "List[int]") - .func_arg("dtype", "str") - .func_start() - .func_call("os.path.join", "path") - .call_arg(DocUtils::ToStr(this->config()->baseline_folder)) - .call_arg("name + \".bin\"") - .cond_if("os.path.isfile(path)") - .func_call("np.fromfile", "data") - .call_arg("path") - .call_arg("dtype", "dtype") - .method_call("reshape") - .call_arg("shape") - .cond_else() - .func_call("np.ones", "data") - .call_arg("(shape)") - .method_call("astype") - .call_arg("dtype") - .cond_end() - .func_end("data"); - } - } - - /*! \brief Stack the docs for the test*/ - void CodeGenTest() { - this->stack_.cond_if("__name__ == \"__main__\"") - .comment("Prepare test datas") - .assign("inputs", "{}") - .assign("golden", "{}"); - for (const auto& i : this->graph()->input_names) { - const auto& input = this->graph()->FindTensor(i); - this->stack_ - .func_call("load_data", DocUtils::ToIndex("inputs", DocUtils::ToStr(input->alias))) - .call_arg(DocUtils::ToStr(input->alias)) - .call_arg(DocUtils::ToList(input->shape, true)) - .call_arg(DocUtils::ToStr(runtime::DLDataTypeToString(input->dtype))); - } - for (const auto& o : this->graph()->output_names) { - const auto& output = this->graph()->FindTensor(o); - this->stack_ - .func_call("load_data", DocUtils::ToIndex("golden", DocUtils::ToStr(output->alias))) - .call_arg(DocUtils::ToStr(output->alias)) - .call_arg(DocUtils::ToList(output->shape, true)) - .call_arg(DocUtils::ToStr(runtime::DLDataTypeToString(output->dtype))); - } - this->stack_.comment("Build and inference the graph"); - CodeGenInference(); - this->stack_.func_call("msc_utils.compare_arrays") - .call_arg("golden") - .call_arg("outputs") - .call_arg(DocUtils::ToStr("detail"), "verbose") - .cond_end(); - } - - /*! \brief Stack the docs for the node*/ - virtual void CodeGenNode(const MSCJoint& node, bool use_tools) { - this->stack_.comment(this->Comment(node)); - // process inputs and weights by tools - if (use_tools) { - for (size_t i = 0; i < node->inputs.size(); i++) { - const auto& input = node->InputAt(i); - this->stack_.func_call("msc_tools.process_tensor", this->IdxInputBase(node, i, true)) - .call_arg(this->IdxInputBase(node, i, false)) - .call_arg(DocUtils::ToStr(input->name)) - .call_arg(DocUtils::ToStr(node->name)) - .call_arg(DocUtils::ToStr(this->config()->tools_scope)) - .call_arg(DocUtils::ToStr(this->config()->tools_tag)); - } - for (const auto& pair : node->weights) { - this->stack_ - .func_call("msc_tools.process_tensor", this->IdxWeightBase(node, pair.first, true)) - .call_arg(this->IdxWeightBase(node, pair.first, false)) - .call_arg(DocUtils::ToStr(pair.second->name)) - .call_arg(DocUtils::ToStr(node->name)) - .call_arg(DocUtils::ToStr(this->config()->tools_scope)) - .call_arg(DocUtils::ToStr(this->config()->tools_tag)); - } - } - for (const auto& d : this->GetOpCodes(node)) { - this->stack_.line(d); - } - // process graph outputs by tools - if (use_tools) { - for (size_t i = 0; i < node->outputs.size(); i++) { - int index = static_cast(i); - if (graph_outputs_.count(node->OutputAt(index))) { - this->stack_.func_call("msc_tools.process_tensor", this->IdxOutputBase(node, index, true)) - .call_arg(this->IdxOutputBase(node, index, false)) - .call_arg(DocUtils::ToStr(node->OutputAt(index)->name)) - .call_arg(DocUtils::ToStr("exit")) - .call_arg(DocUtils::ToStr(this->config()->tools_scope)) - .call_arg(DocUtils::ToStr(this->config()->tools_tag)); - } - } - } - } - - /*! \brief Stack the docs for the graph*/ - virtual void CodeGenGraph() = 0; - - /*! \brief Stack the docs for the graph inference*/ - virtual void CodeGenInference() = 0; - - /*! \brief Get tensor type of the framework*/ - virtual const ffi::String TensorType() const { return "np.ndarray"; } - - private: - std::set graph_outputs_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_CODEGEN_PY_CODEGEN_H_ diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc deleted file mode 100644 index c8ad5309fcff..000000000000 --- a/src/contrib/msc/core/ir/graph.cc +++ /dev/null @@ -1,1655 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/ir/graph.cc - */ - -#include "graph.h" - -#include - -#include -#include -#include -#include - -#include "../printer/prototxt_printer.h" - -namespace tvm { -namespace contrib { -namespace msc { - -MSCTensor::MSCTensor(const ffi::String& name, const DataType& dtype, const ffi::String& layout, - const ffi::Array& shape, const ffi::String& alias, - const ffi::Array& prims) { - ObjectPtr n = ffi::make_object(); - n->name = std::move(name); - n->alias = std::move(alias); - n->dtype = std::move(dtype); - n->shape = std::move(shape); - n->layout = tvm::tir::Layout(layout); - n->prims = prims; - data_ = std::move(n); -} - -MSCTensor::MSCTensor(const JsonMSCTensor& j_tensor) { - ObjectPtr n = ffi::make_object(); - n->FromJson(j_tensor); - data_ = std::move(n); -} - -MSCTensor::MSCTensor(const std::string& json_str) { - ObjectPtr n = ffi::make_object(); - n->FromJson(json_str); - data_ = std::move(n); -} - -const JsonMSCTensor MSCTensorNode::ToJson() const { - JsonMSCTensor j_tensor; - j_tensor.name = name; - j_tensor.alias = alias; - j_tensor.dtype = runtime::DLDataTypeToString(dtype); - if (layout.defined()) { - j_tensor.layout = layout.name(); - } - for (const auto& s : shape) { - j_tensor.shape.push_back(s->value); - } - for (const auto& p : prims) { - j_tensor.prims.push_back(p); - } - return j_tensor; -} - -void MSCTensorNode::FromJson(const JsonMSCTensor& j_tensor) { - name = j_tensor.name; - alias = j_tensor.alias; - dtype = DataType(ffi::StringToDLDataType(j_tensor.dtype)); - if (j_tensor.layout.size() > 0) { - layout = tvm::tir::Layout(j_tensor.layout); - } - for (const auto& s : j_tensor.shape) { - shape.push_back(s); - } - for (const auto& p : j_tensor.prims) { - prims.push_back(p); - } -} - -void MSCTensorNode::FromJson(const std::string& json_str) { - namespace json = ::tvm::ffi::json; - auto parsed = json::Parse(json_str); - JsonMSCTensor j_tensor; - j_tensor.Load(parsed.cast()); - FromJson(j_tensor); -} - -size_t MSCTensorNode::Ndim() const { return shape.size(); } - -const Integer MSCTensorNode::DimAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, Ndim()); - return shape[v_index]; -} - -const Integer MSCTensorNode::DimAt(const ffi::String& axis) const { - auto index = layout.IndexOf(tvm::tir::LayoutAxis::Get(axis)); - return DimAt(index); -} - -const ffi::String MSCTensorNode::PrimAt(int index) const { - if (prims.size() == 0) { - return ""; - } - return prims[CommonUtils::GetIndex(index, Ndim())]; -} - -const ffi::String MSCTensorNode::PrimAt(const ffi::String& axis) const { - return PrimAt(layout.IndexOf(tvm::tir::LayoutAxis::Get(axis))); -} - -int32_t MSCTensorNode::LayoutOf(const ffi::String& axis) const { - return layout.IndexOf(tvm::tir::LayoutAxis::Get(axis)); -} - -const Integer MSCTensorNode::GetSize() const { - Integer size = Integer(1); - for (const auto& s : shape) { - size *= s; - } - return size; -} - -const ffi::String MSCTensorNode::DTypeName() const { return runtime::DLDataTypeToString(dtype); } - -size_t BaseJointNode::AddChild(const BaseJoint& child) const { - for (size_t i = 0; i < children.size(); i++) { - if (Downcast(children[i])->name == child->name) { - return i; - } - } - children.push_back(child); - return children.size() - 1; -} - -const BaseJoint BaseJointNode::ParentAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, parents.size()); - return Downcast(parents[v_index]); -} - -const BaseJoint BaseJointNode::ChildAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, children.size()); - return Downcast(children[v_index]); -} - -bool BaseJointNode::HasAttr(const ffi::String& key) const { return attrs.count(key); } - -bool BaseJointNode::GetAttr(const ffi::String& key, std::string* val) const { - if (attrs.count(key) && attrs[key].size() > 0) { - *val = attrs[key]; - return true; - } - return false; -} - -bool BaseJointNode::GetAttr(const ffi::String& key, int* val) const { - std::string val_str; - if (GetAttr(key, &val_str)) { - int pos = val_str.find(","); - if (pos > 0) { - return false; - } - try { - *val = std::stoi(val_str); - return true; - } catch (const std::exception&) { - return false; - } - } - return false; -} - -bool BaseJointNode::GetAttr(const ffi::String& key, int64_t* val) const { - std::string val_str; - if (GetAttr(key, &val_str)) { - try { - *val = std::stoi(val_str); - return true; - } catch (const std::exception&) { - return false; - } - } - return false; -} - -bool BaseJointNode::GetAttr(const ffi::String& key, float* val) const { - std::string val_str; - if (GetAttr(key, &val_str)) { - try { - *val = std::atof(val_str.c_str()); - return true; - } catch (const std::exception&) { - return false; - } - } - return false; -} - -bool BaseJointNode::GetAttr(const ffi::String& key, bool* val) const { - int val_int; - if (GetAttr(key, &val_int)) { - *val = (val_int != 0); - return true; - } - return false; -} - -bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { - std::string val_str; - if (GetAttr(key, &val_str)) { - int pos = val_str.find(","); - if (pos < 0) { - return false; - } - try { - for (const auto& s : StringUtils::Split(val_str, ",")) { - (*val).push_back(std::string(s)); - } - return true; - } catch (const std::exception&) { - return false; - } - } - return false; -} - -bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { - std::string val_str; - if (GetAttr(key, &val_str)) { - int pos = val_str.find(","); - if (pos < 0) { - return false; - } - try { - for (const auto& s : StringUtils::Split(val_str, ",")) { - (*val).push_back(std::stoi(std::string(s))); - } - return true; - } catch (const std::exception&) { - return false; - } - } - return false; -} - -bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { - std::string val_str; - if (GetAttr(key, &val_str)) { - try { - int pos = val_str.find(","); - if (pos < 0) { - return false; - } - for (const auto& s : StringUtils::Split(val_str, ",")) { - (*val).push_back(std::stol(std::string(s))); - } - return true; - } catch (const std::exception&) { - return false; - } - } - return false; -} -bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { - std::string val_str; - if (GetAttr(key, &val_str)) { - int pos = val_str.find(","); - if (pos < 0) { - return false; - } - try { - for (const auto& s : StringUtils::Split(val_str, ",")) { - (*val).push_back(std::atof(std::string(s).c_str())); - } - return true; - } catch (const std::exception&) { - return false; - } - } - return false; -} - -bool BaseJointNode::GetAttr(const ffi::String& key, std::vector* val) const { - std::string val_str; - if (GetAttr(key, &val_str)) { - int pos = val_str.find(","); - if (pos < 0) { - return false; - } - try { - for (const auto& s : StringUtils::Split(val_str, ",")) { - (*val).push_back(std::stoi(s) != 0); - } - return true; - } catch (const std::exception&) { - return false; - } - } - return false; -} - -MSCJoint::MSCJoint(int index, const ffi::String& name, const ffi::String& shared_ref, - const ffi::String& optype, const ffi::Map& attrs, - const ffi::Array& scope, - const std::vector>& inputs, - const ffi::Array& outputs, - const ffi::Map& weights) { - ObjectPtr n = ffi::make_object(); - n->index = index; - n->name = std::move(name); - n->shared_ref = std::move(shared_ref); - n->optype = std::move(optype); - n->attrs = std::move(attrs); - n->scope = std::move(scope); - ffi::Array parents; - ffi::Array> array_inputs; - ffi::Array added_parents; - for (const auto& pair : inputs) { - // const auto& parent=Downcast(pair.first); - const auto& p_name = pair.first->name; - int p_idx = -1; - for (size_t i = 0; i < added_parents.size(); i++) { - if (added_parents[i] == p_name) { - p_idx = i; - break; - } - } - if (p_idx == -1) { - parents.push_back(pair.first); - added_parents.push_back(p_name); - p_idx = added_parents.size() - 1; - } - ffi::Array input{Integer(p_idx), Integer(pair.second)}; - array_inputs.push_back(input); - } - n->parents = std::move(parents); - n->inputs = std::move(array_inputs); - n->outputs = std::move(outputs); - n->weights = std::move(weights); - data_ = std::move(n); -} - -MSCJoint::MSCJoint(const JsonMSCJoint& j_joint, const ffi::Map& nodes) { - ObjectPtr n = ffi::make_object(); - n->FromJson(j_joint, nodes); - data_ = std::move(n); -} - -MSCJoint::MSCJoint(const std::string& json_str, const ffi::Map& nodes) { - ObjectPtr n = ffi::make_object(); - n->FromJson(json_str, nodes); - data_ = std::move(n); -} - -const MSCJoint MSCJoint::Clone(const MSCJoint& node, - const std::vector>& inputs) { - return MSCJoint(node->index, node->name, node->shared_ref, node->optype, node->attrs, node->scope, - inputs, node->outputs, node->weights); -} - -const JsonMSCJoint MSCJointNode::ToJson() const { - JsonMSCJoint j_joint; - j_joint.index = index; - j_joint.name = name; - j_joint.shared_ref = shared_ref; - j_joint.optype = optype; - for (const auto& pair : attrs) { - j_joint.attrs[pair.first] = pair.second; - } - for (const auto& s : scope) { - j_joint.scope.push_back(s); - } - for (const auto& p : parents) { - j_joint.parents.push_back(Downcast(p)->name); - } - for (const auto& i : GetInputs()) { - j_joint.inputs.push_back(i->name); - } - for (const auto& o : GetOutputs()) { - j_joint.outputs.push_back(o->ToJson()); - } - for (const auto& pair : weights) { - j_joint.weights[pair.first] = pair.second->ToJson(); - } - return j_joint; -} - -void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, - const ffi::Map& nodes) { - index = j_joint.index; - name = j_joint.name; - shared_ref = j_joint.shared_ref; - optype = j_joint.optype; - for (const auto& pair : j_joint.attrs) { - attrs.Set(pair.first, pair.second); - } - for (const auto& s : j_joint.scope) { - scope.push_back(s); - } - for (const auto& p_name : j_joint.parents) { - TVM_FFI_ICHECK(nodes.count(p_name)) << "Can not find parent " << p_name; - parents.push_back(nodes[p_name]); - } - for (const auto& in_name : j_joint.inputs) { - ffi::String producer, index_str; - std::tie(producer, index_str) = StringUtils::SplitOnce(in_name, ":"); - int p_idx = -1; - for (size_t i = 0; i < parents.size(); i++) { - if (ParentAt(i)->name == producer) { - p_idx = i; - break; - } - } - TVM_FFI_ICHECK(p_idx >= 0) << "Can not find parent for " << in_name; - ffi::Array input{Integer(p_idx), Integer(std::stol(index_str))}; - inputs.push_back(input); - } - for (const auto& o : j_joint.outputs) { - outputs.push_back(MSCTensor(o)); - } - for (const auto& pair : j_joint.weights) { - weights.Set(pair.first, MSCTensor(pair.second)); - } -} - -void MSCJointNode::FromJson(const std::string& json_str, - const ffi::Map& nodes) { - namespace json = ::tvm::ffi::json; - auto parsed = json::Parse(json_str); - JsonMSCJoint j_joint; - j_joint.Load(parsed.cast()); - FromJson(j_joint, nodes); -} - -const MSCTensor MSCJointNode::InputAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, inputs.size()); - const auto& p_idx = inputs[v_index][0]; - const auto& out_idx = inputs[v_index][1]; - return ParentAt(p_idx->value)->OutputAt(out_idx->value); -} - -const ffi::Array MSCJointNode::GetInputs() const { - ffi::Array t_inputs; - for (size_t i = 0; i < inputs.size(); i++) { - t_inputs.push_back(InputAt(i)); - } - return t_inputs; -} - -const MSCTensor MSCJointNode::OutputAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, outputs.size()); - return outputs[v_index]; -} - -const ffi::Array MSCJointNode::GetOutputs() const { - ffi::Array t_outputs; - for (size_t i = 0; i < outputs.size(); i++) { - t_outputs.push_back(OutputAt(i)); - } - return t_outputs; -} - -const MSCTensor MSCJointNode::WeightAt(const ffi::String& wtype) const { - TVM_FFI_ICHECK(weights.count(wtype)) << "Can not find " << wtype << " from weights"; - return weights[wtype]; -} - -const MSCJoint MSCJointNode::ParentAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, parents.size()); - return Downcast(parents[v_index]); -} - -const MSCJoint MSCJointNode::ChildAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, children.size()); - return Downcast(children[v_index]); -} - -const MSCJoint MSCJointNode::ProducerOf(int index) const { - const auto& pair = ProducerAndIdxOf(index); - return pair.first; -} - -const MSCJoint MSCJointNode::ProducerOf(const ffi::String& input_name) const { - const auto& pair = ProducerAndIdxOf(input_name); - return pair.first; -} - -const MSCJoint MSCJointNode::ProducerOf(const MSCTensor& input) const { - return ProducerOf(input->name); -} - -const std::pair MSCJointNode::ProducerAndIdxOf(int index) const { - size_t v_index = CommonUtils::GetIndex(index, inputs.size()); - const auto& p_idx = inputs[v_index][0]; - return std::make_pair(ParentAt(p_idx->value), inputs[v_index][1]->value); -} - -const std::pair MSCJointNode::ProducerAndIdxOf(const ffi::String& name) const { - for (size_t i = 0; i < inputs.size(); i++) { - if (InputAt(i)->name == name) { - return ProducerAndIdxOf(i); - } - } - TVM_FFI_THROW(InternalError) << "Can not find producer of " << name; -} - -const std::pair MSCJointNode::ProducerAndIdxOf(const MSCTensor& input) const { - return ProducerAndIdxOf(input->name); -} - -MSCPrim::MSCPrim(int index, const ffi::String& name, const ffi::String& optype, - const ffi::Array& parents, - const ffi::Map& attrs) { - ObjectPtr n = ffi::make_object(); - n->index = index; - n->name = std::move(name); - n->optype = std::move(optype); - n->attrs = std::move(attrs); - for (const auto& p : parents) { - n->parents.push_back(p); - } - data_ = std::move(n); -} - -MSCPrim::MSCPrim(const JsonMSCPrim& j_prim, const ffi::Map& prims) { - ObjectPtr n = ffi::make_object(); - n->FromJson(j_prim, prims); - data_ = std::move(n); -} - -MSCPrim::MSCPrim(const std::string& json_str, const ffi::Map& prims) { - ObjectPtr n = ffi::make_object(); - n->FromJson(json_str, prims); - data_ = std::move(n); -} - -const JsonMSCPrim MSCPrimNode::ToJson() const { - JsonMSCPrim j_prim; - j_prim.index = index; - j_prim.name = name; - j_prim.optype = optype; - for (const auto& pair : attrs) { - j_prim.attrs[pair.first] = pair.second; - } - for (const auto& p : parents) { - j_prim.parents.push_back(Downcast(p)->name); - } - return j_prim; -} - -void MSCPrimNode::FromJson(const JsonMSCPrim& j_prim, - const ffi::Map& prims) { - index = j_prim.index; - name = j_prim.name; - optype = j_prim.optype; - for (const auto& pair : j_prim.attrs) { - attrs.Set(pair.first, pair.second); - } - for (const auto& p_name : j_prim.parents) { - TVM_FFI_ICHECK(prims.count(p_name)) << "Can not find parent " << p_name; - parents.push_back(prims[p_name]); - } -} - -void MSCPrimNode::FromJson(const std::string& json_str, - const ffi::Map& prims) { - namespace json = ::tvm::ffi::json; - auto parsed = json::Parse(json_str); - JsonMSCPrim j_prim; - j_prim.Load(parsed.cast()); - FromJson(j_prim, prims); -} - -const MSCPrim MSCPrimNode::ParentAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, parents.size()); - return Downcast(parents[v_index]); -} - -const MSCPrim MSCPrimNode::ChildAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, children.size()); - return Downcast(children[v_index]); -} - -WeightJoint::WeightJoint(int index, const ffi::String& name, const ffi::String& shared_ref, - const ffi::String& weight_type, const MSCTensor& weight, - const ffi::Array parents, - const ffi::Map& attrs, - const ffi::Array& friends) { - ObjectPtr n = ffi::make_object(); - n->index = index; - n->name = std::move(name); - n->shared_ref = std::move(shared_ref); - n->weight_type = std::move(weight_type); - n->attrs = std::move(attrs); - n->weight = std::move(weight); - for (const auto& p : parents) { - n->parents.push_back(p); - } - n->friends = std::move(friends); - data_ = std::move(n); -} - -WeightJoint::WeightJoint(const JsonWeightJoint& j_joint, - const ffi::Map& nodes) { - ObjectPtr n = ffi::make_object(); - n->FromJson(j_joint, nodes); - data_ = std::move(n); -} - -WeightJoint::WeightJoint(const std::string& json_str, - const ffi::Map& nodes) { - ObjectPtr n = ffi::make_object(); - n->FromJson(json_str, nodes); - data_ = std::move(n); -} - -const JsonWeightJoint WeightJointNode::ToJson() const { - JsonWeightJoint j_joint; - j_joint.index = index; - j_joint.name = name; - j_joint.shared_ref = shared_ref; - j_joint.weight_type = weight_type; - j_joint.weight = weight->ToJson(); - for (const auto& pair : attrs) { - j_joint.attrs[pair.first] = pair.second; - } - for (const auto& p : parents) { - j_joint.parents.push_back(Downcast(p)->name); - } - for (const auto& f : friends) { - j_joint.friends.push_back(Downcast(f)->name); - } - - return j_joint; -} - -void WeightJointNode::FromJson(const JsonWeightJoint& j_joint, - const ffi::Map& nodes) { - index = j_joint.index; - name = j_joint.name; - shared_ref = j_joint.shared_ref; - weight_type = j_joint.weight_type; - weight = MSCTensor(j_joint.weight); - for (const auto& pair : j_joint.attrs) { - attrs.Set(pair.first, pair.second); - } - for (const auto& p_name : j_joint.parents) { - TVM_FFI_ICHECK(nodes.count(p_name)) << "Can not find parent " << p_name; - parents.push_back(nodes[p_name]); - } -} - -void WeightJointNode::FromJson(const std::string& json_str, - const ffi::Map& nodes) { - namespace json = ::tvm::ffi::json; - auto parsed = json::Parse(json_str); - JsonWeightJoint j_joint; - j_joint.Load(parsed.cast()); - FromJson(j_joint, nodes); -} - -const WeightJoint WeightJointNode::ParentAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, parents.size()); - return Downcast(parents[v_index]); -} - -const WeightJoint WeightJointNode::ChildAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, children.size()); - return Downcast(children[v_index]); -} - -const bool BaseGraphNode::HasNode(const ffi::String& name) const { - return nodes.count(name) ? true : false; -} - -MSCGraph::MSCGraph(const ffi::String& name, const ffi::Array& nodes, - const ffi::Array& input_names, - const ffi::Array& output_names, const ffi::Array& prims) { - ObjectPtr n = ffi::make_object(); - n->name = std::move(name); - for (const auto& node : nodes) { - n->node_names.push_back(node->name); - n->nodes.Set(node->name, node); - } - n->input_names = std::move(input_names); - n->output_names = std::move(output_names); - for (const auto& prim : prims) { - n->prim_names.push_back(prim->name); - n->prims.Set(prim->name, prim); - } - n->AnalysisGraph(); - data_ = std::move(n); -} - -MSCGraph::MSCGraph(const JsonMSCGraph& j_graph) { - ObjectPtr n = ffi::make_object(); - n->FromJson(j_graph); - data_ = std::move(n); -} - -MSCGraph::MSCGraph(const std::string& json_str) { - ObjectPtr n = ffi::make_object(); - n->FromJson(json_str); - data_ = std::move(n); -} - -const JsonMSCGraph MSCGraphNode::ToJson() const { - JsonMSCGraph j_graph; - j_graph.name = name; - for (const auto& i : input_names) { - j_graph.inputs.push_back(i); - } - for (const auto& o : output_names) { - j_graph.outputs.push_back(o); - } - for (const auto& n : node_names) { - const auto& node = FindNode(n); - j_graph.nodes.push_back(node->ToJson()); - } - for (const auto& n : prim_names) { - const auto& prim = FindPrim(n); - j_graph.prims.push_back(prim->ToJson()); - } - return j_graph; -} - -void MSCGraphNode::FromJson(const JsonMSCGraph& j_graph) { - name = j_graph.name; - for (const auto& i : j_graph.inputs) { - input_names.push_back(i); - } - for (const auto& o : j_graph.outputs) { - output_names.push_back(o); - } - ffi::Map loaded_nodes; - for (const auto& n : j_graph.nodes) { - const auto& node = MSCJoint(n, loaded_nodes); - loaded_nodes.Set(node->name, node); - for (const auto& p : node->parents) { - Downcast(p)->AddChild(node); - } - node_names.push_back(node->name); - nodes.Set(node->name, node); - } - ffi::Map loaded_prims; - for (const auto& n : j_graph.prims) { - const auto& prim = MSCPrim(n, loaded_prims); - loaded_prims.Set(prim->name, prim); - for (const auto& p : prim->parents) { - Downcast(p)->AddChild(prim); - } - prim_names.push_back(prim->name); - prims.Set(prim->name, prim); - } - AnalysisGraph(); -} - -void MSCGraphNode::FromJson(const std::string& json_str) { - namespace json = ::tvm::ffi::json; - auto parsed = json::Parse(json_str); - JsonMSCGraph j_graph; - j_graph.Load(parsed.cast()); - FromJson(j_graph); -} - -const ffi::String MSCGraphNode::ToPrototxt() const { - PrototxtPrinter printer; - printer.Append(ffi::Map{{"name", name}}); - for (const auto& n : node_names) { - const auto& node = FindNode(n); - // define layer - std::vector> layer; - layer.push_back(std::make_pair("name", node->name)); - layer.push_back(std::make_pair("type", StringUtils::Replace(node->optype, ".", "_"))); - layer.push_back(std::make_pair("top", node->name)); - for (const auto& p : node->parents) { - layer.push_back(std::make_pair("bottom", Downcast(p)->name)); - } - // define layer param - ffi::Map param; - param.Set("idx", Integer(node->index)); - for (size_t i = 0; i < node->inputs.size(); i++) { - param.Set("input_" + std::to_string(i), node->InputAt(i)); - } - for (size_t i = 0; i < node->outputs.size(); i++) { - param.Set("output_" + std::to_string(i), node->OutputAt(i)); - } - for (const auto& pair : node->weights) { - param.Set("param_" + pair.first, pair.second); - } - for (const auto& pair : node->attrs) { - param.Set(pair.first, pair.second); - } - layer.push_back(std::make_pair("layer_param", PrototxtPrinter::ToDictDoc(param))); - // Append the layer Map - printer.Append(ffi::Map{{"layer", PrototxtPrinter::ToDictDoc(layer)}}); - } - return printer.GetString(); -} - -const MSCJoint MSCGraphNode::FindNode(const ffi::String& name) const { - TVM_FFI_ICHECK(nodes.count(name)) << "Can not find node " << name; - return Downcast(nodes[name]); -} - -const MSCPrim MSCGraphNode::FindPrim(const ffi::String& name) const { - TVM_FFI_ICHECK(prims.count(name)) << "Can not find prim " << name; - return prims[name]; -} - -const MSCTensor MSCGraphNode::InputAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, input_names.size()); - return FindTensor(input_names[v_index]); -} - -const ffi::Array MSCGraphNode::GetInputs() const { - ffi::Array t_inputs; - for (size_t i = 0; i < input_names.size(); i++) { - t_inputs.push_back(InputAt(i)); - } - return t_inputs; -} - -const MSCTensor MSCGraphNode::OutputAt(int index) const { - size_t v_index = CommonUtils::GetIndex(index, output_names.size()); - return FindTensor(output_names[v_index]); -} - -const ffi::Array MSCGraphNode::GetOutputs() const { - ffi::Array t_outputs; - for (size_t i = 0; i < output_names.size(); i++) { - t_outputs.push_back(OutputAt(i)); - } - return t_outputs; -} - -const ffi::Array MSCGraphNode::GetEntries() const { - ffi::Array entries; - for (size_t i = 0; i < input_names.size(); i++) { - entries.push_back(FindProducer(input_names[i])); - } - return entries; -} - -const ffi::Array MSCGraphNode::GetExits() const { - ffi::Array exits; - std::set setted_exits; - for (size_t i = 0; i < output_names.size(); i++) { - const auto& exit = FindProducer(output_names[i]); - if (setted_exits.count(exit->name)) { - continue; - } - exits.push_back(exit); - setted_exits.insert(exit->name); - } - return exits; -} - -const bool MSCGraphNode::HasTensor(const ffi::String& name) const { - const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; - if (weight_holders.count(tensor_name)) { - return true; - } - ffi::String host, index; - std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":"); - return nodes.count(host) > 0 ? true : false; -} - -const MSCTensor MSCGraphNode::FindTensor(const ffi::String& name) const { - const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; - if (weight_holders.count(tensor_name)) { - const auto& node = FindNode(weight_holders[tensor_name][0]); - for (const auto& pair : node->weights) { - if (pair.second->name == tensor_name) { - return pair.second; - } - } - TVM_FFI_THROW(InternalError) << "Can not find weight " << name << " from " << node; - } - const auto& pair = FindProducerAndIdx(name); - return pair.first->OutputAt(pair.second); -} - -const MSCJoint MSCGraphNode::FindProducer(const ffi::String& name) const { - const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; - if (weight_holders.count(tensor_name)) { - return FindNode(weight_holders[tensor_name][0]); - } - const auto& pair = FindProducerAndIdx(name); - return pair.first; -} - -const MSCJoint MSCGraphNode::FindProducer(const MSCTensor& tensor) const { - return FindProducer(tensor->name); -} - -const std::pair MSCGraphNode::FindProducerAndIdx(const ffi::String& name) const { - const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; - TVM_FFI_ICHECK(!weight_holders.count(tensor_name)) - << "Weight " << name << " has no producer with index"; - ffi::String host, index; - std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":"); - if (index.size() == 0) { - const auto& node = FindNode(host); - TVM_FFI_ICHECK(node->optype == "constant") - << "Tensor without index should be constant, get " << node; - return std::make_pair(node, 0); - } - return std::make_pair(FindNode(host), std::stoi(index)); -} - -const std::pair MSCGraphNode::FindProducerAndIdx(const MSCTensor& tensor) const { - return FindProducerAndIdx(tensor->name); -} - -const ffi::Array MSCGraphNode::FindConsumers(const ffi::String& name) const { - ffi::Array consumers; - const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; - if (weight_holders.count(tensor_name)) { - for (const auto& h : weight_holders[tensor_name]) { - consumers.push_back(FindNode(h)); - } - } else { - const auto& producer = FindProducer(name); - for (const auto& c : producer->children) { - consumers.push_back(Downcast(c)); - } - } - return consumers; -} - -const ffi::Array MSCGraphNode::FindConsumers(const MSCTensor& tensor) const { - return FindConsumers(tensor->name); -} - -const std::vector> MSCGraphNode::FindConsumersAndIndices( - const ffi::String& name) const { - const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; - TVM_FFI_ICHECK(!weight_holders.count(tensor_name)) << "Weight has no index"; - std::vector> consumers; - for (const auto& c : FindConsumers(name)) { - bool find_tensor = false; - for (size_t i = 0; i < c->inputs.size(); i++) { - if (c->InputAt(i)->name == name) { - consumers.push_back(std::make_pair(c, i)); - find_tensor = true; - break; - } - } - TVM_FFI_ICHECK(find_tensor) << "Can not find tensor " << name << " from " << c; - } - return consumers; -} - -const std::vector> MSCGraphNode::FindConsumersAndIndices( - const MSCTensor& tensor) const { - return FindConsumersAndIndices(tensor->name); -} - -void MSCGraphNode::AnalysisGraph() { - // Add children - for (const auto& n : node_names) { - const auto& node = FindNode(n); - for (const auto& p : node->parents) { - Downcast(p)->AddChild(node); - } - } - // Check inputs and outputs - for (const auto& i : input_names) { - const auto& input = FindTensor(i); - if (input->alias.size() > 0) { - tensor_alias.Set(input->alias, input->name); - } - } - for (const auto& o : output_names) { - FindTensor(o); - } - // Set tensor alias and weight_holders - for (const auto& n : node_names) { - const auto& node = FindNode(n); - for (const auto& o : node->outputs) { - if (o->alias.size() > 0) { - tensor_alias.Set(o->alias, o->name); - } - } - for (const auto& pair : node->weights) { - const auto& w_name = pair.second->name; - if (weight_holders.count(w_name)) { - ffi::Array holders = weight_holders[w_name]; - holders.push_back(n); - weight_holders.Set(w_name, holders); - } else { - weight_holders.Set(w_name, ffi::Array({n})); - if (pair.second->alias.size() > 0) { - tensor_alias.Set(pair.second->alias, pair.second->name); - } - } - } - } -} - -WeightGraph::WeightGraph(const MSCGraph& graph, - const ffi::Map>& main_wtypes, - const ffi::Map& relation_wtypes) { - ObjectPtr n = ffi::make_object(); - n->name = graph->name + "_weights"; - n->Build(graph, main_wtypes, relation_wtypes); - data_ = std::move(n); -} - -WeightGraph::WeightGraph(const JsonWeightGraph& j_graph) { - ObjectPtr n = ffi::make_object(); - n->FromJson(j_graph); - data_ = std::move(n); -} - -WeightGraph::WeightGraph(const std::string& json_str) { - ObjectPtr n = ffi::make_object(); - n->FromJson(json_str); - data_ = std::move(n); -} - -void WeightGraphNode::Build(const MSCGraph& graph, - const ffi::Map>& main_wtypes, - const ffi::Map& relation_wtypes) { - auto sort_nodes = [&graph](const BaseJoint& node_a, const BaseJoint& node_b) { - return graph->FindProducer(node_a->name)->index < graph->FindProducer(node_b->name)->index; - }; - - auto find_parents = [this, &main_wtypes, &relation_wtypes, &sort_nodes](const MSCJoint& node) { - std::vector parents; - std::queue frontier; - std::set explored; - for (const auto& p : node->parents) { - frontier.push(Downcast(p)); - } - while (!frontier.empty()) { - const auto& current = frontier.front(); - if (explored.count(current)) { - frontier.pop(); - continue; - } - explored.insert(current); - if (main_wtypes.count(current->optype)) { - for (const auto& t_type : main_wtypes[current->optype]) { - if (current->weights.count(t_type)) { - parents.push_back(FindNode(current->WeightAt(t_type)->name)); - } - } - } else if (relation_wtypes.count(current->optype)) { - parents.push_back(FindNode(current->OutputAt(0)->name)); - } else { - for (const auto& p : current->parents) { - const auto& new_parent = Downcast(p); - if (!explored.count(new_parent)) { - frontier.push(new_parent); - } - } - } - frontier.pop(); - } - ffi::Array parents_array; - if (parents.size() > 1) { - std::sort(parents.begin(), parents.end(), sort_nodes); - } - for (const auto& p : parents) { - parents_array.push_back(p); - } - return parents_array; - }; - - for (const auto& n : graph->node_names) { - const auto& node = graph->FindNode(n); - if (node->shared_ref.size() > 0) { - continue; - } - if (main_wtypes.count(node->optype) || relation_wtypes.count(node->optype) || - node->weights.size() > 0) { - const auto& w_parents = find_parents(node); - bool bind_friends = true; - if (relation_wtypes.count(node->optype) && relation_wtypes[node->optype] == "multi_inputs") { - bind_friends = false; - } - if (w_parents.size() > 1 && bind_friends) { - for (const auto& p : w_parents) { - Downcast(p)->friends = w_parents; - } - } - if (main_wtypes.count(node->optype)) { - for (const auto& wtype : main_wtypes[node->optype]) { - if (node->weights.count(wtype)) { - const auto& weight = node->WeightAt(wtype); - ffi::Map attrs; - attrs.Set("producer_type", node->optype); - attrs.Set("weight_strategy", "main"); - const auto& w_node = - WeightJoint(node_names.size(), weight->name, "", wtype, weight, w_parents, attrs); - for (const auto& p : w_parents) { - p->AddChild(w_node); - } - nodes.Set(weight->name, w_node); - node_names.push_back(weight->name); - } - } - const BaseJoint& head = FindNode(node_names[node_names.size() - 1]); - for (const auto& pair : node->weights) { - if (!nodes.count(pair.second->name)) { - ffi::Map attrs; - attrs.Set("producer_type", node->optype); - attrs.Set("weight_strategy", "follow"); - const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first, - pair.second, {head}, attrs); - head->AddChild(w_node); - nodes.Set(pair.second->name, w_node); - node_names.push_back(pair.second->name); - } - } - } else if (relation_wtypes.count(node->optype)) { - const auto& tensor = node->OutputAt(0); - ffi::Map attrs; - attrs.Set("producer_type", node->optype); - if (node->optype == "reshape") { - // TODO(archermmt): check non-passby reshape - attrs.Set("weight_strategy", "passby"); - } else { - attrs.Set("weight_strategy", relation_wtypes[node->optype]); - } - const auto& t_node = - WeightJoint(node_names.size(), tensor->name, "", "output", tensor, w_parents, attrs); - for (const auto& p : w_parents) { - p->AddChild(t_node); - } - nodes.Set(tensor->name, t_node); - node_names.push_back(tensor->name); - } else if (node->weights.size() > 0) { - for (const auto& pair : node->weights) { - if (!nodes.count(pair.second->name)) { - ffi::Map attrs; - attrs.Set("producer_type", node->optype); - attrs.Set("weight_strategy", "follow"); - const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first, - pair.second, w_parents, attrs); - for (const auto& p : w_parents) { - p->AddChild(w_node); - } - nodes.Set(pair.second->name, w_node); - node_names.push_back(pair.second->name); - } - } - } - } - } -} - -const WeightJoint WeightGraphNode::FindNode(const ffi::String& name) const { - TVM_FFI_ICHECK(nodes.count(name)) << "Can not find node " << name; - return Downcast(nodes[name]); -} - -const JsonWeightGraph WeightGraphNode::ToJson() const { - JsonWeightGraph j_graph; - j_graph.name = name; - for (const auto& n : node_names) { - const auto& node = FindNode(n); - j_graph.nodes.push_back(node->ToJson()); - } - return j_graph; -} - -void WeightGraphNode::FromJson(const JsonWeightGraph& j_graph) { - name = j_graph.name; - ffi::Map loaded_nodes; - for (const auto& n : j_graph.nodes) { - const auto& node = WeightJoint(n, loaded_nodes); - loaded_nodes.Set(node->name, node); - for (const auto& p : node->parents) { - Downcast(p)->AddChild(node); - } - node_names.push_back(node->name); - nodes.Set(node->name, node); - } - // set friends - for (const auto& j_joint : j_graph.nodes) { - const auto& node = Downcast(nodes[j_joint.name]); - for (const auto& f_name : j_joint.friends) { - TVM_FFI_ICHECK(nodes.count(f_name)) << "Can not find friend " << f_name; - node->friends.push_back(nodes[f_name]); - } - } -} - -void WeightGraphNode::FromJson(const std::string& json_str) { - namespace json = ::tvm::ffi::json; - auto parsed = json::Parse(json_str); - JsonWeightGraph j_graph; - j_graph.Load(parsed.cast()); - FromJson(j_graph); -} - -const ffi::String WeightGraphNode::ToPrototxt() const { - PrototxtPrinter printer; - printer.Append(ffi::Map{{"name", name}}); - for (const auto& n : node_names) { - const auto& node = FindNode(n); - // define layer - std::vector> layer; - layer.push_back(std::make_pair("name", node->name)); - layer.push_back(std::make_pair("type", node->weight_type)); - layer.push_back(std::make_pair("top", node->name)); - for (const auto& p : node->parents) { - layer.push_back(std::make_pair("bottom", Downcast(p)->name)); - } - // define layer param - ffi::Map param; - param.Set("idx", Integer(node->index)); - param.Set("weight", node->weight); - for (size_t i = 0; i < node->friends.size(); i++) { - param.Set("friend_" + std::to_string(i), Downcast(node->friends[i])); - } - for (const auto& pair : node->attrs) { - param.Set(pair.first, pair.second); - } - layer.push_back(std::make_pair("layer_param", PrototxtPrinter::ToDictDoc(param))); - // Append the layer Map - printer.Append(ffi::Map{{"layer", PrototxtPrinter::ToDictDoc(layer)}}); - } - return printer.GetString(); -} - -MSCGraph PruneWeights(const MSCGraph& graph, - const ffi::Map& pruned_tensors) { - ffi::Array nodes; - std::unordered_map> inputs_map; - for (const auto& name : graph->node_names) { - const auto& node = graph->FindNode(name); - // define inputs - std::vector> inputs; - for (const auto& input : node->GetInputs()) { - TVM_FFI_ICHECK(inputs_map.count(input->name)) << "Can not find input " << input; - inputs.push_back(inputs_map[input->name]); - } - // define outputs - ffi::Array outputs; - for (const auto& out : node->outputs) { - const auto& output = pruned_tensors.count(out->name) ? pruned_tensors[out->name] : out; - outputs.push_back(output); - } - // define weights - ffi::Map weights; - for (const auto& pair : node->weights) { - const auto& weight = - pruned_tensors.count(pair.second->name) ? pruned_tensors[pair.second->name] : pair.second; - weights.Set(pair.first, weight); - } - // define attributes - ffi::Map attrs = node->attrs; - if (node->optype == "reshape" && attrs.count("shape") && - pruned_tensors.count(node->OutputAt(0)->name)) { - const auto& new_shape = pruned_tensors[node->OutputAt(0)->name]->shape; - attrs.Set("shape", StringUtils::ToString(new_shape)); - } - // create new node - const auto& new_node = MSCJoint(static_cast(nodes.size()), node->name, node->shared_ref, - node->optype, attrs, node->scope, inputs, outputs, weights); - nodes.push_back(new_node); - for (size_t i = 0; i < new_node->outputs.size(); i++) { - inputs_map[new_node->OutputAt(i)->name] = std::make_pair(new_node, i); - } - for (const auto& p : new_node->parents) { - Downcast(p)->AddChild(new_node); - } - } - ffi::Array prims; - for (const auto& name : graph->prim_names) { - prims.push_back(graph->FindPrim(name)); - } - return MSCGraph(graph->name, nodes, graph->input_names, graph->output_names, prims); -} - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* tensor = static_cast(node.get()); - p->PrintIndent(); - p->stream << tensor->name; - if (tensor->alias.size() > 0) { - p->stream << "(" << tensor->alias << ")"; - } - p->stream << "<"; - for (size_t i = 0; i < tensor->Ndim(); i++) { - const auto& prim = tensor->PrimAt(i); - p->stream << (prim.size() > 0 ? prim : StringUtils::ToString(tensor->shape[i])) - << (i == tensor->Ndim() - 1 ? "|" : ","); - } - p->stream << tensor->dtype; - if (tensor->layout.defined()) { - p->stream << "|" << tensor->layout.name(); - } - p->stream << ">"; - }); - -#define MSC_NODE_BASE_HEAD(Stream, Joint, Type) \ - Stream << Type << "_" << Joint->index << " " << Joint->name; \ - if (Joint->shared_ref.size() > 0) { \ - Stream << "(M: " << Joint->shared_ref << ")"; \ - } \ - Stream << " parents.size() > 0) { \ - for (size_t i = 0; i < Joint->parents.size(); i++) { \ - Stream << Joint->ParentAt(i)->name << (i == Joint->parents.size() - 1 ? "" : ","); \ - } \ - } \ - Stream << "| CHILDERN: "; \ - if (Joint->children.size() > 0) { \ - for (size_t i = 0; i < Joint->children.size(); i++) { \ - Stream << Joint->ChildAt(i)->name << (i == Joint->children.size() - 1 ? "" : ","); \ - } \ - } \ - Stream << ">\n"; - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* joint = static_cast(node.get()); - p->PrintIndent(); - MSC_NODE_BASE_HEAD(p->stream, joint, "N"); - if (joint->inputs.size() > 0) { - p->stream << " IN: "; - for (size_t i = 0; i < joint->inputs.size(); i++) { - p->stream << joint->InputAt(i) << (i == joint->inputs.size() - 1 ? "\n" : ","); - } - } - p->stream << " OUT: "; - for (size_t i = 0; i < joint->outputs.size(); i++) { - p->stream << joint->OutputAt(i) << (i == joint->outputs.size() - 1 ? "\n" : ","); - } - p->stream << " OPTYPE: " << joint->optype << "\n"; - if (joint->scope.size() > 0) { - p->stream << " SCOPE: "; - for (size_t i = 0; i < joint->scope.size(); i++) { - p->stream << joint->scope[i] << (i == joint->scope.size() - 1 ? "\n" : "."); - } - } - if (joint->attrs.size() > 0) { - p->stream << " ATTRS: "; - for (const auto& pair : joint->attrs) { - p->stream << pair.first << "=" << pair.second << " "; - } - p->stream << "\n"; - } - if (joint->weights.size() > 0) { - p->stream << " WEIGHTS: "; - for (const auto& pair : joint->weights) { - p->stream << "\n " << pair.first << ": " << pair.second; - } - p->stream << "\n"; - } - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* prim = static_cast(node.get()); - p->PrintIndent(); - MSC_NODE_BASE_HEAD(p->stream, prim, "P"); - p->stream << " OPTYPE: " << prim->optype; - if (prim->attrs.size() > 0) { - p->stream << "\n ATTRS: "; - for (const auto& pair : prim->attrs) { - p->stream << pair.first << "=" << pair.second << " "; - } - } - p->stream << "\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* joint = static_cast(node.get()); - p->PrintIndent(); - MSC_NODE_BASE_HEAD(p->stream, joint, "W"); - if (joint->friends.size() > 0) { - p->stream << " FRIENDS: "; - for (size_t i = 0; i < joint->friends.size(); i++) { - p->stream << Downcast(joint->friends[i])->name - << (i == joint->friends.size() - 1 ? "\n" : ","); - } - } - p->stream << " WEIGHT_TYPE: " << joint->weight_type; - p->stream << "\n WEIGHT: " << joint->weight; - if (joint->attrs.size() > 0) { - p->stream << "\n ATTRS: "; - for (const auto& pair : joint->attrs) { - p->stream << pair.first << "=" << pair.second << " "; - } - } - p->stream << "\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* graph = static_cast(node.get()); - p->PrintIndent(); - p->stream << graph->name << "\n"; - for (const auto& n : graph->node_names) { - p->stream << graph->FindNode(n) << "\n"; - } - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* graph = static_cast(node.get()); - p->PrintIndent(); - p->stream << graph->name << " input_names.size(); i++) { - p->stream << graph->input_names[i] << (i == graph->input_names.size() - 1 ? "| " : ","); - } - p->stream << "OUTPUTS: "; - for (size_t i = 0; i < graph->output_names.size(); i++) { - p->stream << graph->output_names[i] << (i == graph->output_names.size() - 1 ? ">\n" : ","); - } - for (const auto& n : graph->prim_names) { - p->stream << graph->FindPrim(n) << "\n"; - } - for (const auto& n : graph->node_names) { - p->stream << graph->FindNode(n) << "\n"; - } - }); - -TVM_FFI_STATIC_INIT_BLOCK() { - MSCTensorNode::RegisterReflection(); - BaseJointNode::RegisterReflection(); - MSCJointNode::RegisterReflection(); - MSCPrimNode::RegisterReflection(); - WeightJointNode::RegisterReflection(); - BaseGraphNode::RegisterReflection(); - MSCGraphNode::RegisterReflection(); - WeightGraphNode::RegisterReflection(); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("msc.core.MSCTensor", - [](const ffi::String& name, const DataType& dtype, const ffi::String& layout, - const ffi::Array& shape, const ffi::String& alias, - const ffi::Array& prims) -> MSCTensor { - return MSCTensor(name, dtype, layout, shape, alias, prims); - }) - .def("msc.core.MSCTensorToJson", - [](const MSCTensor& tensor) -> ffi::String { - namespace json = ::tvm::ffi::json; - const auto& tensor_json = tensor->ToJson(); - std::ostringstream os; - os << std::string(json::Stringify(tensor_json.SaveToJSON())); - return os.str(); - }) - .def("msc.core.MSCTensorFromJson", - [](const ffi::String& tensor_json) -> MSCTensor { return MSCTensor(tensor_json); }) - .def("msc.core.MSCJoint", - [](Integer index, const ffi::String& name, const ffi::String& shared_ref, - const ffi::String& optype, const ffi::Map& attrs, - const ffi::Array& scope, const ffi::Array& parents, - const ffi::Array out_indices, const ffi::Array& outputs, - const ffi::Map& weights) -> MSCJoint { - std::vector> inputs; - for (size_t i = 0; i < parents.size(); i++) { - inputs.push_back(std::make_pair(parents[i], out_indices[i]->value)); - } - return MSCJoint(index->value, name, shared_ref, optype, attrs, scope, inputs, outputs, - weights); - }) - .def("msc.core.MSCPrim", - [](Integer index, const ffi::String& name, const ffi::String& optype, - const ffi::Map& attrs, - const ffi::Array& parents) -> MSCPrim { - ffi::Array b_parents; - for (const auto& p : parents) { - b_parents.push_back(p); - } - return MSCPrim(index->value, name, optype, b_parents, attrs); - }) - .def("msc.core.WeightJoint", - [](Integer index, const ffi::String& name, const ffi::String& shared_ref, - const ffi::String& weight_type, const MSCTensor& weight, - const ffi::Array parents, const ffi::Map& attrs, - const ffi::Array& friends) -> WeightJoint { - ffi::Array b_parents, b_friends; - for (const auto& p : parents) { - b_parents.push_back(p); - } - for (const auto& f : friends) { - b_friends.push_back(f); - } - return WeightJoint(index->value, name, shared_ref, weight_type, weight, b_parents, - attrs, b_friends); - }) - .def("msc.core.WeightJointSetAttr", - [](const WeightJoint& node, const ffi::String& key, const ffi::String& value) { - node->attrs.Set(key, value); - }) - .def("msc.core.MSCGraph", - [](const ffi::String& name, const ffi::Array& nodes, - const ffi::Array& input_names, - const ffi::Array& output_names, - const ffi::Array& prims) -> MSCGraph { - return MSCGraph(name, nodes, input_names, output_names, prims); - }) - .def("msc.core.WeightGraph", - [](const MSCGraph& graph, - const ffi::Map>& main_wtypes, - const ffi::Map& relation_wtypes) -> WeightGraph { - return WeightGraph(graph, main_wtypes, relation_wtypes); - }); -} - -// MSC Graph APIS -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("msc.core.MSCGraphHasNode", - [](const MSCGraph& graph, const ffi::String& name) -> Bool { - return Bool(graph->HasNode(name)); - }) - .def("msc.core.MSCGraphFindNode", - [](const MSCGraph& graph, const ffi::String& name) -> MSCJoint { - return graph->FindNode(name); - }) - .def("msc.core.MSCGraphFindPrim", - [](const MSCGraph& graph, const ffi::String& name) -> MSCPrim { - return graph->FindPrim(name); - }) - .def("msc.core.MSCGraphHasTensor", - [](const MSCGraph& graph, const ffi::String& name) -> Bool { - return Bool(graph->HasTensor(name)); - }) - .def("msc.core.MSCGraphFindTensor", - [](const MSCGraph& graph, const ffi::String& name) -> MSCTensor { - return graph->FindTensor(name); - }) - .def("msc.core.MSCGraphSetTensorAlias", - [](const MSCGraph& graph, const MSCTensor& tensor, const ffi::String& alias) { - tensor->alias = alias; - graph->tensor_alias.Set(alias, tensor->name); - }) - .def("msc.core.MSCGraphFindProducer", - [](const MSCGraph& graph, const ffi::String& name) -> MSCJoint { - return graph->FindProducer(name); - }) - .def("msc.core.MSCGraphFindConsumers", - [](const MSCGraph& graph, const ffi::String& name) -> ffi::Array { - return graph->FindConsumers(name); - }) - .def("msc.core.MSCGraphInputAt", - [](const MSCGraph& graph, int index) -> MSCTensor { return graph->InputAt(index); }) - .def("msc.core.MSCGraphOutputAt", - [](const MSCGraph& graph, int index) -> MSCTensor { return graph->OutputAt(index); }) - .def("msc.core.MSCGraphGetInputs", - [](const MSCGraph& graph) -> ffi::Array { return graph->GetInputs(); }) - .def("msc.core.MSCGraphGetOutputs", - [](const MSCGraph& graph) -> ffi::Array { return graph->GetOutputs(); }) - .def("msc.core.MSCGraphToJson", - [](const MSCGraph& graph) -> ffi::String { - namespace json = ::tvm::ffi::json; - const auto& graph_json = graph->ToJson(); - std::ostringstream os; - os << std::string(json::Stringify(graph_json.SaveToJSON())); - return os.str(); - }) - .def("msc.core.MSCGraphFromJson", - [](const ffi::String& graph_json) -> MSCGraph { return MSCGraph(graph_json); }) - .def("msc.core.MSCGraphToPrototxt", - [](const MSCGraph& graph) -> ffi::String { return graph->ToPrototxt(); }); -} - -// Weight Graph APIS -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("msc.core.WeightGraphHasNode", - [](const WeightGraph& graph, const ffi::String& name) -> Bool { - return Bool(graph->HasNode(name)); - }) - .def("msc.core.WeightGraphFindNode", - [](const WeightGraph& graph, const ffi::String& name) -> WeightJoint { - return graph->FindNode(name); - }) - .def("msc.core.WeightGraphToJson", - [](const WeightGraph& graph) -> ffi::String { - namespace json = ::tvm::ffi::json; - const auto& graph_json = graph->ToJson(); - std::ostringstream os; - os << std::string(json::Stringify(graph_json.SaveToJSON())); - return os.str(); - }) - .def("msc.core.WeightGraphFromJson", - [](const ffi::String& graph_json) -> WeightGraph { return WeightGraph(graph_json); }) - .def("msc.core.WeightGraphToPrototxt", - [](const WeightGraph& graph) -> ffi::String { return graph->ToPrototxt(); }) - .def("msc.core.MSCJointInputAt", - [](const MSCJoint& node, int index) -> MSCTensor { return node->InputAt(index); }) - .def("msc.core.MSCJointOutputAt", - [](const MSCJoint& node, int index) -> MSCTensor { return node->OutputAt(index); }) - .def("msc.core.MSCJointWeightAt", - [](const MSCJoint& node, const ffi::String& wtype) -> MSCTensor { - return node->WeightAt(wtype); - }) - .def("msc.core.MSCJointGetInputs", - [](const MSCJoint& node) -> ffi::Array { return node->GetInputs(); }) - .def("msc.core.MSCJointGetOutputs", - [](const MSCJoint& node) -> ffi::Array { return node->GetOutputs(); }) - .def("msc.core.MSCJointGetWeights", - [](const MSCJoint& node) -> ffi::Map { return node->weights; }) - .def("msc.core.MSCJointHasAttr", - [](const MSCJoint& node, const ffi::String& key) -> Bool { - return Bool(node->HasAttr(key)); - }) - .def("msc.core.MSCJointGetAttrs", - [](const MSCJoint& node) -> ffi::Map { return node->attrs; }) - .def("msc.core.WeightJointHasAttr", - [](const WeightJoint& node, const ffi::String& key) -> Bool { - return Bool(node->HasAttr(key)); - }) - .def( - "msc.core.WeightJointGetAttrs", - [](const WeightJoint& node) -> ffi::Map { return node->attrs; }) - .def("msc.core.MSCTensorDTypeName", - [](const MSCTensor& tensor) -> ffi::String { return tensor->DTypeName(); }) - .def("msc.core.MSCTensorDimAt", - [](const MSCTensor& tensor, const ffi::String& axis) -> Integer { - return tensor->DimAt(axis); - }) - .def("msc.core.MSCTensorGetSize", - [](const MSCTensor& tensor) -> Integer { return tensor->GetSize(); }) - .def("msc.core.MSCTensorSetAlias", - [](const MSCTensor& tensor, const ffi::String& alias) { tensor->alias = alias; }) - .def("msc.core.PruneWeights", - [](const MSCGraph& graph, const ffi::Map& pruned_tensors) - -> MSCGraph { return PruneWeights(graph, pruned_tensors); }); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h deleted file mode 100644 index a649cadda0af..000000000000 --- a/src/contrib/msc/core/ir/graph.h +++ /dev/null @@ -1,1154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/ir/graph.h - * \brief Core MSCGraph. - */ -#ifndef TVM_CONTRIB_MSC_CORE_IR_GRAPH_H_ -#define TVM_CONTRIB_MSC_CORE_IR_GRAPH_H_ - -#include -#include -#include - -#include -#include -#include -#include - -#include "../utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief Json serialize and deserialize for MSCTensor. - * MSCTensor is edge in MSCGraph with name, dtype and shape - */ -struct JsonMSCTensor { - std::string name; - std::string alias; - std::string dtype; - std::string layout; - std::vector shape; - std::vector prims; - - ffi::json::Value SaveToJSON() const { - ffi::json::Object obj; - obj.Set(ffi::String("name"), ffi::String(name)); - obj.Set(ffi::String("alias"), ffi::String(alias)); - obj.Set(ffi::String("dtype"), ffi::String(dtype)); - obj.Set(ffi::String("layout"), ffi::String(layout)); - { - ffi::json::Array arr; - for (const auto& v : shape) { - arr.push_back(static_cast(v)); - } - obj.Set(ffi::String("shape"), std::move(arr)); - } - { - ffi::json::Array arr; - for (const auto& s : prims) { - arr.push_back(ffi::String(s)); - } - obj.Set(ffi::String("prims"), std::move(arr)); - } - return obj; - } - - void Load(ffi::json::Object obj) { - int bitmask = 0; - if (auto it = obj.find(ffi::String("name")); it != obj.end()) { - name = std::string((*it).second.cast()); - bitmask |= 1; - } - if (auto it = obj.find(ffi::String("alias")); it != obj.end()) { - alias = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("dtype")); it != obj.end()) { - dtype = std::string((*it).second.cast()); - bitmask |= 2; - } - if (auto it = obj.find(ffi::String("layout")); it != obj.end()) { - layout = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("shape")); it != obj.end()) { - auto arr = (*it).second.cast(); - shape.clear(); - shape.reserve(arr.size()); - for (const auto& elem : arr) { - shape.push_back(elem.cast()); - } - bitmask |= 4; - } - if (auto it = obj.find(ffi::String("prims")); it != obj.end()) { - auto arr = (*it).second.cast(); - prims.clear(); - prims.reserve(arr.size()); - for (const auto& elem : arr) { - prims.push_back(std::string(elem.cast())); - } - } - TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4) << "name, dtype and shape should be given"; - } -}; - -/*! - * \brief Json serialize and deserialize for MSCJoint. - * MSCJoint is node in MSCGraph with name, optype and attrbutes. - * MSCJoint has MSCTensors as inputs, outputs and weights. - */ -struct JsonMSCJoint { - size_t index; - std::string name; - std::string shared_ref; - std::string optype; - std::vector scope; - std::vector parents; - std::vector inputs; - std::vector outputs; - std::unordered_map attrs; - std::unordered_map weights; - - ffi::json::Value SaveToJSON() const { - ffi::json::Object obj; - obj.Set(ffi::String("index"), static_cast(index)); - obj.Set(ffi::String("name"), ffi::String(name)); - obj.Set(ffi::String("shared_ref"), ffi::String(shared_ref)); - obj.Set(ffi::String("optype"), ffi::String(optype)); - { - ffi::json::Array arr; - for (const auto& s : parents) { - arr.push_back(ffi::String(s)); - } - obj.Set(ffi::String("parents"), std::move(arr)); - } - { - ffi::json::Array arr; - for (const auto& s : inputs) { - arr.push_back(ffi::String(s)); - } - obj.Set(ffi::String("inputs"), std::move(arr)); - } - { - ffi::json::Array arr; - for (const auto& item : outputs) { - arr.push_back(item.SaveToJSON()); - } - obj.Set(ffi::String("outputs"), std::move(arr)); - } - { - ffi::json::Object inner; - for (const auto& kv : attrs) { - inner.Set(ffi::String(kv.first), ffi::String(kv.second)); - } - obj.Set(ffi::String("attrs"), std::move(inner)); - } - { - ffi::json::Object inner; - for (const auto& kv : weights) { - inner.Set(ffi::String(kv.first), kv.second.SaveToJSON()); - } - obj.Set(ffi::String("weights"), std::move(inner)); - } - return obj; - } - - void Load(ffi::json::Object obj) { - int bitmask = 0; - if (auto it = obj.find(ffi::String("index")); it != obj.end()) { - index = static_cast((*it).second.cast()); - bitmask |= 1; - } - if (auto it = obj.find(ffi::String("name")); it != obj.end()) { - name = std::string((*it).second.cast()); - bitmask |= 2; - } - if (auto it = obj.find(ffi::String("shared_ref")); it != obj.end()) { - shared_ref = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("optype")); it != obj.end()) { - optype = std::string((*it).second.cast()); - bitmask |= 4; - } - if (auto it = obj.find(ffi::String("parents")); it != obj.end()) { - auto arr = (*it).second.cast(); - parents.clear(); - parents.reserve(arr.size()); - for (const auto& elem : arr) { - parents.push_back(std::string(elem.cast())); - } - } - if (auto it = obj.find(ffi::String("inputs")); it != obj.end()) { - auto arr = (*it).second.cast(); - inputs.clear(); - inputs.reserve(arr.size()); - for (const auto& elem : arr) { - inputs.push_back(std::string(elem.cast())); - } - } - if (auto it = obj.find(ffi::String("outputs")); it != obj.end()) { - auto arr = (*it).second.cast(); - outputs.clear(); - outputs.reserve(arr.size()); - for (const auto& elem : arr) { - JsonMSCTensor item; - item.Load(elem.cast()); - outputs.push_back(std::move(item)); - } - bitmask |= 8; - } - if (auto it = obj.find(ffi::String("attrs")); it != obj.end()) { - auto inner = (*it).second.cast(); - attrs.clear(); - for (const auto& kv : inner) { - attrs[std::string(kv.first.cast())] = - std::string(kv.second.cast()); - } - } - if (auto it = obj.find(ffi::String("weights")); it != obj.end()) { - auto inner = (*it).second.cast(); - weights.clear(); - for (const auto& kv : inner) { - JsonMSCTensor item; - item.Load(kv.second.cast()); - weights[std::string(kv.first.cast())] = std::move(item); - } - } - TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "index, name, optype and outputs should be given"; - } -}; - -/*! - * \brief Json serialize and deserialize for MSCPrim. - * MSCPrim is node in MSCGraph with name, op and attrbutes. - */ -struct JsonMSCPrim { - size_t index; - std::string name; - std::string optype; - std::vector parents; - std::unordered_map attrs; - - ffi::json::Value SaveToJSON() const { - ffi::json::Object obj; - obj.Set(ffi::String("index"), static_cast(index)); - obj.Set(ffi::String("name"), ffi::String(name)); - obj.Set(ffi::String("optype"), ffi::String(optype)); - { - ffi::json::Array arr; - for (const auto& s : parents) { - arr.push_back(ffi::String(s)); - } - obj.Set(ffi::String("parents"), std::move(arr)); - } - { - ffi::json::Object inner; - for (const auto& kv : attrs) { - inner.Set(ffi::String(kv.first), ffi::String(kv.second)); - } - obj.Set(ffi::String("attrs"), std::move(inner)); - } - return obj; - } - - void Load(ffi::json::Object obj) { - int bitmask = 0; - if (auto it = obj.find(ffi::String("index")); it != obj.end()) { - index = static_cast((*it).second.cast()); - bitmask |= 1; - } - if (auto it = obj.find(ffi::String("name")); it != obj.end()) { - name = std::string((*it).second.cast()); - bitmask |= 2; - } - if (auto it = obj.find(ffi::String("optype")); it != obj.end()) { - optype = std::string((*it).second.cast()); - bitmask |= 4; - } - if (auto it = obj.find(ffi::String("parents")); it != obj.end()) { - auto arr = (*it).second.cast(); - parents.clear(); - parents.reserve(arr.size()); - for (const auto& elem : arr) { - parents.push_back(std::string(elem.cast())); - } - } - if (auto it = obj.find(ffi::String("attrs")); it != obj.end()) { - auto inner = (*it).second.cast(); - attrs.clear(); - for (const auto& kv : inner) { - attrs[std::string(kv.first.cast())] = - std::string(kv.second.cast()); - } - } - TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4) << "index, name and optype should be given"; - } -}; - -/*! - * \brief Json serialize and deserialize for WeightJoint. - * WeightJoint is node in WeightGraph with name, wtype and attrbutes. - * WeightJoint has MSCTensor as weight. - */ -struct JsonWeightJoint { - size_t index; - std::string name; - std::string shared_ref; - std::string weight_type; - JsonMSCTensor weight; - std::vector parents; - std::vector friends; - std::unordered_map attrs; - - ffi::json::Value SaveToJSON() const { - ffi::json::Object obj; - obj.Set(ffi::String("index"), static_cast(index)); - obj.Set(ffi::String("name"), ffi::String(name)); - obj.Set(ffi::String("shared_ref"), ffi::String(shared_ref)); - obj.Set(ffi::String("weight_type"), ffi::String(weight_type)); - obj.Set(ffi::String("weight"), weight.SaveToJSON()); - { - ffi::json::Array arr; - for (const auto& s : parents) { - arr.push_back(ffi::String(s)); - } - obj.Set(ffi::String("parents"), std::move(arr)); - } - { - ffi::json::Array arr; - for (const auto& s : friends) { - arr.push_back(ffi::String(s)); - } - obj.Set(ffi::String("friends"), std::move(arr)); - } - { - ffi::json::Object inner; - for (const auto& kv : attrs) { - inner.Set(ffi::String(kv.first), ffi::String(kv.second)); - } - obj.Set(ffi::String("attrs"), std::move(inner)); - } - return obj; - } - - void Load(ffi::json::Object obj) { - int bitmask = 0; - if (auto it = obj.find(ffi::String("index")); it != obj.end()) { - index = static_cast((*it).second.cast()); - bitmask |= 1; - } - if (auto it = obj.find(ffi::String("name")); it != obj.end()) { - name = std::string((*it).second.cast()); - bitmask |= 2; - } - if (auto it = obj.find(ffi::String("shared_ref")); it != obj.end()) { - shared_ref = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("weight_type")); it != obj.end()) { - weight_type = std::string((*it).second.cast()); - bitmask |= 4; - } - if (auto it = obj.find(ffi::String("weight")); it != obj.end()) { - weight.Load((*it).second.cast()); - bitmask |= 8; - } - if (auto it = obj.find(ffi::String("parents")); it != obj.end()) { - auto arr = (*it).second.cast(); - parents.clear(); - parents.reserve(arr.size()); - for (const auto& elem : arr) { - parents.push_back(std::string(elem.cast())); - } - } - if (auto it = obj.find(ffi::String("friends")); it != obj.end()) { - auto arr = (*it).second.cast(); - friends.clear(); - friends.reserve(arr.size()); - for (const auto& elem : arr) { - friends.push_back(std::string(elem.cast())); - } - } - if (auto it = obj.find(ffi::String("attrs")); it != obj.end()) { - auto inner = (*it).second.cast(); - attrs.clear(); - for (const auto& kv : inner) { - attrs[std::string(kv.first.cast())] = - std::string(kv.second.cast()); - } - } - TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) - << "index, name, weight_type and weight should be given"; - } -}; - -/*! - * \brief Json serialize and deserialize for MSCGraph. - * MSCGraph is core of MSC. - * MSCGraph contains MSCJoints as nodes and MSCTensors as edges. - */ -struct JsonMSCGraph { - std::string name; - std::vector inputs; - std::vector outputs; - std::vector nodes; - std::vector prims; - - ffi::json::Value SaveToJSON() const { - ffi::json::Object obj; - obj.Set(ffi::String("name"), ffi::String(name)); - { - ffi::json::Array arr; - for (const auto& s : inputs) { - arr.push_back(ffi::String(s)); - } - obj.Set(ffi::String("inputs"), std::move(arr)); - } - { - ffi::json::Array arr; - for (const auto& s : outputs) { - arr.push_back(ffi::String(s)); - } - obj.Set(ffi::String("outputs"), std::move(arr)); - } - { - ffi::json::Array arr; - for (const auto& item : nodes) { - arr.push_back(item.SaveToJSON()); - } - obj.Set(ffi::String("nodes"), std::move(arr)); - } - { - ffi::json::Array arr; - for (const auto& item : prims) { - arr.push_back(item.SaveToJSON()); - } - obj.Set(ffi::String("prims"), std::move(arr)); - } - return obj; - } - - void Load(ffi::json::Object obj) { - int bitmask = 0; - if (auto it = obj.find(ffi::String("name")); it != obj.end()) { - name = std::string((*it).second.cast()); - bitmask |= 1; - } - if (auto it = obj.find(ffi::String("inputs")); it != obj.end()) { - auto arr = (*it).second.cast(); - inputs.clear(); - inputs.reserve(arr.size()); - for (const auto& elem : arr) { - inputs.push_back(std::string(elem.cast())); - } - bitmask |= 2; - } - if (auto it = obj.find(ffi::String("outputs")); it != obj.end()) { - auto arr = (*it).second.cast(); - outputs.clear(); - outputs.reserve(arr.size()); - for (const auto& elem : arr) { - outputs.push_back(std::string(elem.cast())); - } - bitmask |= 4; - } - if (auto it = obj.find(ffi::String("nodes")); it != obj.end()) { - auto arr = (*it).second.cast(); - nodes.clear(); - nodes.reserve(arr.size()); - for (const auto& elem : arr) { - JsonMSCJoint item; - item.Load(elem.cast()); - nodes.push_back(std::move(item)); - } - bitmask |= 8; - } - if (auto it = obj.find(ffi::String("prims")); it != obj.end()) { - auto arr = (*it).second.cast(); - prims.clear(); - prims.reserve(arr.size()); - for (const auto& elem : arr) { - JsonMSCPrim item; - item.Load(elem.cast()); - prims.push_back(std::move(item)); - } - } - TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "name, inputs, outputs and nodes should be given"; - } -}; - -/*! - * \brief Json serialize and deserialize for WeightGraph. - * WeightGraph is core of MSC.prune. - * WeightGraph contains WeightJoints as nodes. - */ -struct JsonWeightGraph { - std::string name; - std::vector nodes; - - ffi::json::Value SaveToJSON() const { - ffi::json::Object obj; - obj.Set(ffi::String("name"), ffi::String(name)); - { - ffi::json::Array arr; - for (const auto& item : nodes) { - arr.push_back(item.SaveToJSON()); - } - obj.Set(ffi::String("nodes"), std::move(arr)); - } - return obj; - } - - void Load(ffi::json::Object obj) { - int bitmask = 0; - if (auto it = obj.find(ffi::String("name")); it != obj.end()) { - name = std::string((*it).second.cast()); - bitmask |= 1; - } - if (auto it = obj.find(ffi::String("nodes")); it != obj.end()) { - auto arr = (*it).second.cast(); - nodes.clear(); - nodes.reserve(arr.size()); - for (const auto& elem : arr) { - JsonWeightJoint item; - item.Load(elem.cast()); - nodes.push_back(std::move(item)); - } - bitmask |= 2; - } - TVM_FFI_ICHECK_EQ(bitmask, 1 | 2) << "name and nodes should be given"; - } -}; - -/*! - * \brief Tensor in MSCGraph. - */ -class MSCTensorNode : public Object { - public: - /*! \brief The name of tensor. */ - ffi::String name; - /*! \brief The alias of tensor, can be changed. */ - mutable ffi::String alias; - /*! \brief The data type of tensor. */ - DataType dtype; - /*! \brief The layout of tensor. */ - tvm::tir::Layout layout; - /*! \brief The shape of tensor. */ - ffi::Array shape; - /*! \brief The prims of tensor. */ - ffi::Array prims; - /*! \brief Export tensor to json. */ - const JsonMSCTensor ToJson() const; - /*! \brief Load tensor from json struct. */ - void FromJson(const JsonMSCTensor& j_tensor); - /*! \brief Load tensor from json string. */ - void FromJson(const std::string& json_str); - /*! \brief Get the ndim of tensor. */ - size_t Ndim() const; - /*! \brief Get dim at given index. */ - const Integer DimAt(int index) const; - /*! \brief Get dim at given axis. */ - const Integer DimAt(const ffi::String& axis) const; - /*! \brief Get prim at given index. */ - const ffi::String PrimAt(int index) const; - /*! \brief Get prim at given axis. */ - const ffi::String PrimAt(const ffi::String& axis) const; - /*! \brief Get layout index of given axis. */ - int32_t LayoutOf(const ffi::String& axis) const; - /*! \brief Get size of the tensor. */ - const Integer GetSize() const; - /*! \brief Get name of the dtype. */ - const ffi::String DTypeName() const; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("name", &MSCTensorNode::name) - .def_ro("alias", &MSCTensorNode::alias) - .def_ro("dtype", &MSCTensorNode::dtype) - .def_ro("layout", &MSCTensorNode::layout) - .def_ro("shape", &MSCTensorNode::shape) - .def_ro("prims", &MSCTensorNode::prims); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.MSCTensor", MSCTensorNode, Object); -}; - -/*! - * \brief Managed reference to MSCTensorNode. - * \sa MSCTensorNode - */ -class MSCTensor : public ObjectRef { - public: - /*! - * \brief The constructor. - * \param name The name of the tensor. - * \param dtype The data type the tensor. - * \param layout The layout of the tensor. - * \param shape The shape of the tensor. - * \param alias The alias of the tensor. - * \param prims The prims of the tensor shape. - */ - TVM_DLL MSCTensor(const ffi::String& name, const DataType& dtype, const ffi::String& layout, - const ffi::Array& shape, const ffi::String& alias = "", - const ffi::Array& prims = ffi::Array()); - - /*! - * \brief The json constructor. - * \param j_tensor The json describe of the tensor. - */ - TVM_DLL MSCTensor(const JsonMSCTensor& j_tensor); - - /*! - * \brief The json constructor. - * \param json_str The json describe of the tensor. - */ - TVM_DLL MSCTensor(const std::string& json_str); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MSCTensor, ObjectRef, MSCTensorNode); -}; - -/*! - * \brief Basic node in MSCGraph and WeightGraph. - */ -class BaseJoint; -class BaseJointNode : public Object { - public: - /*! \brief The index of node, can be changed. */ - mutable int index; - /*! \brief The name of node. */ - ffi::String name; - /*! \brief The shared_ref of node, can be changed. */ - ffi::String shared_ref; - /*! \brief The attributes of node. */ - mutable ffi::Map attrs; - /*! \brief The parents of node. */ - ffi::Array parents; - /*! \brief The children of node. */ - mutable ffi::Array children; - /*! \brief Add child to the node. */ - size_t AddChild(const BaseJoint& child) const; - /*! \brief Get parent from the node. */ - const BaseJoint ParentAt(int index) const; - /*! \brief Get child from the node. */ - const BaseJoint ChildAt(int index) const; - /*! \brief Check if has the attribute. */ - bool HasAttr(const ffi::String& key) const; - /*! \brief Get the attribute by type. */ - bool GetAttr(const ffi::String& key, std::string* val) const; - bool GetAttr(const ffi::String& key, int* val) const; - bool GetAttr(const ffi::String& key, int64_t* val) const; - bool GetAttr(const ffi::String& key, float* val) const; - bool GetAttr(const ffi::String& key, bool* val) const; - bool GetAttr(const ffi::String& key, std::vector* val) const; - bool GetAttr(const ffi::String& key, std::vector* val) const; - bool GetAttr(const ffi::String& key, std::vector* val) const; - bool GetAttr(const ffi::String& key, std::vector* val) const; - bool GetAttr(const ffi::String& key, std::vector* val) const; - /*! \brief Check and get the attribute by type. */ - template - const T GetTypeAttr(const ffi::String& key) const { - T val; - TVM_FFI_ICHECK(GetAttr(key, &val)) << "Can not get attr " << key; - return val; - } - template - const std::vector GetTypeArrayAttr(const ffi::String& key) const { - std::vector val; - TVM_FFI_ICHECK(GetAttr(key, &val)) << "Can not get attr " << key; - return val; - } - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("index", &BaseJointNode::index) - .def_ro("name", &BaseJointNode::name) - .def_ro("shared_ref", &BaseJointNode::shared_ref) - .def_ro("attrs", &BaseJointNode::attrs) - .def_ro("parents", &BaseJointNode::parents) - .def_ro("children", &BaseJointNode::children); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - static constexpr const uint32_t _type_child_slots = 2; - TVM_FFI_DECLARE_OBJECT_INFO("msc.core.BaseJoint", BaseJointNode, Object); -}; - -/*! - * \brief Managed reference to BaseJointNode. - * \sa BaseJointNode - */ -class BaseJoint : public ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseJoint, ObjectRef, BaseJointNode); -}; - -/*! - * \brief Node in MSCGraph. - */ -class MSCJoint; -class MSCJointNode : public BaseJointNode { - public: - /*! \brief The op type of node. */ - ffi::String optype; - /*! \brief The scope of node. */ - ffi::Array scope; - /*! \brief The inputs of node, can be changed. */ - ffi::Array> inputs; - /*! \brief The outputs of node. */ - ffi::Array outputs; - /*! \brief The weights of node. */ - ffi::Map weights; - /*! \brief Export node to json. */ - const JsonMSCJoint ToJson() const; - /*! \brief Load node from json struct. */ - void FromJson(const JsonMSCJoint& j_joint, const ffi::Map& nodes); - /*! \brief Load node from json string. */ - void FromJson(const std::string& json_str, const ffi::Map& nodes); - /*! \brief Get input from the node. */ - const MSCTensor InputAt(int index) const; - /*! \brief Get inputs from the node. */ - const ffi::Array GetInputs() const; - /*! \brief Get output from the node. */ - const MSCTensor OutputAt(int index) const; - /*! \brief Get outputs from the node. */ - const ffi::Array GetOutputs() const; - /*! \brief Get weight from the node. */ - const MSCTensor WeightAt(const ffi::String& wtype) const; - /*! \brief Get parent from the node. */ - const MSCJoint ParentAt(int index) const; - /*! \brief Get child from the node. */ - const MSCJoint ChildAt(int index) const; - /*! \brief Get Producer of the input. */ - const MSCJoint ProducerOf(int index) const; - const MSCJoint ProducerOf(const ffi::String& input_name) const; - const MSCJoint ProducerOf(const MSCTensor& input) const; - /*! \brief Get Producer and out index of the input. */ - const std::pair ProducerAndIdxOf(int index) const; - const std::pair ProducerAndIdxOf(const ffi::String& name) const; - const std::pair ProducerAndIdxOf(const MSCTensor& input) const; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("optype", &MSCJointNode::optype) - .def_ro("scope", &MSCJointNode::scope) - .def_ro("inputs", &MSCJointNode::inputs) - .def_ro("outputs", &MSCJointNode::outputs) - .def_ro("weights", &MSCJointNode::weights); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.MSCJoint", MSCJointNode, BaseJointNode); -}; - -/*! - * \brief Managed reference to MSCJointNode. - * \sa MSCJointNode - */ -class MSCJoint : public BaseJoint { - public: - /*! - * \brief The constructor. - * \param index The index of the node. - * \param name The name of the node. - * \param shared_ref The shared_ref of the node. - * \param optype The op type the node. - * \param attrs The attributes of the node. - * \param inputs The inputs of the node. - * \param outputs The outputs of the node. - * \param weights The weights of the node. - */ - TVM_DLL MSCJoint(int index, const ffi::String& name, const ffi::String& shared_ref, - const ffi::String& optype, const ffi::Map& attrs, - const ffi::Array& scope, - const std::vector>& inputs, - const ffi::Array& outputs, - const ffi::Map& weights); - - /*! - * \brief The json constructor. - * \param j_joint The json describe of the node. - */ - TVM_DLL MSCJoint(const JsonMSCJoint& j_joint, const ffi::Map& nodes); - - /*! - * \brief The json constructor. - * \param json_str The json describe of the node. - */ - TVM_DLL MSCJoint(const std::string& json_str, const ffi::Map& nodes); - - /*! \brief Clone the node. */ - TVM_DLL static const MSCJoint Clone(const MSCJoint& node, - const std::vector>& inputs); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MSCJoint, BaseJoint, MSCJointNode); -}; - -/*! - * \brief MSCPrim in MSCGraph. - */ -class MSCPrim; -class MSCPrimNode : public BaseJointNode { - public: - /*! \brief The op of prim. */ - ffi::String optype; - /*! \brief Export prim to json. */ - const JsonMSCPrim ToJson() const; - /*! \brief Load prim from json struct. */ - void FromJson(const JsonMSCPrim& j_prim, const ffi::Map& prims); - /*! \brief Load prim from json string. */ - void FromJson(const std::string& json_str, const ffi::Map& prims); - /*! \brief Get parent from the prim. */ - const MSCPrim ParentAt(int index) const; - /*! \brief Get child from the prim. */ - const MSCPrim ChildAt(int index) const; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("optype", &MSCPrimNode::optype); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.MSCPrim", MSCPrimNode, BaseJointNode); -}; - -/*! - * \brief Managed reference to MSCPrimNode. - * \sa MSCPrimNode - */ -class MSCPrim : public BaseJoint { - public: - /*! - * \brief The constructor. - * \param index The index of the prim. - * \param name The name of the prim. - * \param optype The optype of the prim. - * \param parents The parents of the prim. - * \param attrs The attributes of the prim. - */ - TVM_DLL MSCPrim( - int index, const ffi::String& name, const ffi::String& optype, - const ffi::Array& parents, - const ffi::Map& attrs = ffi::Map()); - - /*! - * \brief The json constructor. - * \param j_prim The json describe of the prim. - */ - TVM_DLL MSCPrim(const JsonMSCPrim& j_prim, const ffi::Map& prims); - - /*! - * \brief The json constructor. - * \param json_str The json describe of the prim. - */ - TVM_DLL MSCPrim(const std::string& json_str, const ffi::Map& prims); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MSCPrim, BaseJoint, MSCPrimNode); -}; - -/*! - * \brief Node in WeightGraph. - */ -class WeightJoint; -class WeightJointNode : public BaseJointNode { - public: - /*! \brief The weight reference of weight node. */ - ffi::String weight_type; - /*! \brief The weight of weight node. */ - MSCTensor weight; - /*! \brief The friends of weight node. */ - mutable ffi::Array friends; - /*! \brief Export node to json. */ - const JsonWeightJoint ToJson() const; - /*! \brief Load node from json struct. */ - void FromJson(const JsonWeightJoint& j_joint, const ffi::Map& nodes); - /*! \brief Load node from json string. */ - void FromJson(const std::string& json_str, const ffi::Map& nodes); - /*! \brief Get parent from the node. */ - const WeightJoint ParentAt(int index) const; - /*! \brief Get child from the node. */ - const WeightJoint ChildAt(int index) const; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("weight_type", &WeightJointNode::weight_type) - .def_ro("weight", &WeightJointNode::weight) - .def_ro("friends", &WeightJointNode::friends); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.WeightJoint", WeightJointNode, BaseJointNode); -}; - -/*! - * \brief Managed reference to WeightJointNode. - * \sa WeightJointNode - */ -class WeightJoint : public BaseJoint { - public: - /*! - * \brief The constructor. - * \param index The index of the node. - * \param name The name of the node. - * \param shared_ref The shared_ref of the node. - * \param weight_type The weight type of the node. - * \param weight The weight tensor of the node. - * \param parents The parents of the node. - * \param attrs The attributes of the node. - * \param friends The friends of the node. - */ - TVM_DLL WeightJoint( - int index, const ffi::String& name, const ffi::String& shared_ref, - const ffi::String& weight_type, const MSCTensor& weight, const ffi::Array parents, - const ffi::Map& attrs = ffi::Map(), - const ffi::Array& friends = ffi::Array()); - - /*! - * \brief The json constructor. - * \param j_joint The json describe of the node. - */ - TVM_DLL WeightJoint(const JsonWeightJoint& j_joint, - const ffi::Map& nodes); - - /*! - * \brief The json constructor. - * \param json_str The json describe of the node. - */ - TVM_DLL WeightJoint(const std::string& json_str, const ffi::Map& nodes); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(WeightJoint, BaseJoint, WeightJointNode); -}; - -/*! - * \brief Basic graph class (MSCGraph and WeightGraph). - */ -class BaseGraphNode : public Object { - public: - /*! \brief The name of graph. */ - ffi::String name; - /*! \brief The node names in graph, can be changed. */ - ffi::Array node_names; - /*! \brief The nodes in graph, can be changed. */ - ffi::Map nodes; - /*! \brief Check if node in the graph. */ - const bool HasNode(const ffi::String& name) const; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("name", &BaseGraphNode::name) - .def_ro("nodes", &BaseGraphNode::nodes) - .def_ro("node_names", &BaseGraphNode::node_names); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const uint32_t _type_child_slots = 2; - TVM_FFI_DECLARE_OBJECT_INFO("msc.core.BaseGraph", BaseGraphNode, Object); -}; - -/*! - * \brief Managed reference to BaseGraphNode. - * \sa BaseGraphNode - */ -class BaseGraph : public ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseGraph, ObjectRef, BaseGraphNode); -}; - -/*! - * \brief MSCGraph. - */ -class MSCGraph; -class MSCGraphNode : public BaseGraphNode { - public: - /*! \brief The shape node names in graph. */ - ffi::Array prim_names; - /*! \brief The shape nodes in graph. */ - ffi::Map prims; - /*! \brief The input names of graph. */ - ffi::Array input_names; - /*! \brief The output names of graph. */ - ffi::Array output_names; - /*! \brief The tensor alias in graph, get by AnalysisGraph. */ - mutable ffi::Map tensor_alias; - /*! \brief The weights in graph, get by AnalysisGraph. */ - ffi::Map> weight_holders; - /*! \brief Export graph to json. */ - const JsonMSCGraph ToJson() const; - /*! \brief Load graph from json. */ - void FromJson(const JsonMSCGraph& json_str); - /*! \brief Load graph from json string. */ - void FromJson(const std::string& json_str); - /*! \brief Export graph to prototxt. */ - const ffi::String ToPrototxt() const; - /*! \brief Find node in graph. */ - const MSCJoint FindNode(const ffi::String& name) const; - /*! \brief Find prim in graph. */ - const MSCPrim FindPrim(const ffi::String& name) const; - /*! \brief Get input from the graph. */ - const MSCTensor InputAt(int index) const; - /*! \brief Get inputs from the graph. */ - const ffi::Array GetInputs() const; - /*! \brief Get output from the graph. */ - const MSCTensor OutputAt(int index) const; - /*! \brief Get outputs from the graph. */ - const ffi::Array GetOutputs() const; - /*! \brief Get entries from the graph. */ - const ffi::Array GetEntries() const; - /*! \brief Get exits from the graph. */ - const ffi::Array GetExits() const; - /*! \brief Check if tensor in the graph. */ - const bool HasTensor(const ffi::String& name) const; - /*! \brief Find tensor from the graph. */ - const MSCTensor FindTensor(const ffi::String& name) const; - /*! \brief Find producer of tensor from the graph. */ - const MSCJoint FindProducer(const ffi::String& name) const; - /*! \brief Find producer of tensor from the graph. */ - const MSCJoint FindProducer(const MSCTensor& tensor) const; - /*! \brief Find producer and output index of tensor from the graph. */ - const std::pair FindProducerAndIdx(const ffi::String& name) const; - /*! \brief Find producer and output index of tensor from the graph. */ - const std::pair FindProducerAndIdx(const MSCTensor& tensor) const; - /*! \brief Find consumers of tensor from the graph. */ - const ffi::Array FindConsumers(const ffi::String& name) const; - /*! \brief Find consumers of tensor from the graph. */ - const ffi::Array FindConsumers(const MSCTensor& tensor) const; - /*! \brief Find consumers and input indices of tensor from the graph. */ - const std::vector> FindConsumersAndIndices( - const ffi::String& name) const; - /*! \brief Find consumers and input indices of tensor from the graph. */ - const std::vector> FindConsumersAndIndices( - const MSCTensor& tensor) const; - /*! \brief Analysis the graph and fill info. */ - void AnalysisGraph(); - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("prims", &MSCGraphNode::prims) - .def_ro("prim_names", &MSCGraphNode::prim_names) - .def_ro("input_names", &MSCGraphNode::input_names) - .def_ro("output_names", &MSCGraphNode::output_names) - .def_ro("weight_holders", &MSCGraphNode::weight_holders); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.MSCGraph", MSCGraphNode, BaseGraphNode); -}; - -/*! - * \brief Managed reference to MSCGraphNode. - * \sa MSCGraphNode - */ -class MSCGraph : public BaseGraph { - public: - /*! - * \brief The constructor. - * \param name The name of the node. - * \param nodes The nodes in the graph. - * \param input_names The input names of the graph. - * \param output_names The output names of the graph. - * \param prims The prims in the graph. - */ - TVM_DLL MSCGraph(const ffi::String& name, const ffi::Array& nodes, - const ffi::Array& input_names, - const ffi::Array& output_names, - const ffi::Array& prims = ffi::Array()); - - /*! - * \brief The json constructor. - * \param j_graph The json describe of the graph. - */ - TVM_DLL MSCGraph(const JsonMSCGraph& j_graph); - - /*! - * \brief The json constructor. - * \param json_str The json describe of the graph. - */ - TVM_DLL MSCGraph(const std::string& json_str); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MSCGraph, BaseGraph, MSCGraphNode); -}; - -/*! - * \brief WeightGraph. - */ -class WeightGraphNode : public BaseGraphNode { - public: - /*! \brief build from MSCGraph. */ - void Build(const MSCGraph& graph, - const ffi::Map>& prunable_types, - const ffi::Map& relation_types); - /*! \brief Find node in graph. */ - const WeightJoint FindNode(const ffi::String& name) const; - /*! \brief Export graph to json. */ - const JsonWeightGraph ToJson() const; - /*! \brief Load graph from json. */ - void FromJson(const JsonWeightGraph& json_str); - /*! \brief Load graph from json string. */ - void FromJson(const std::string& json_str); - /*! \brief Export graph to prototxt. */ - const ffi::String ToPrototxt() const; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.WeightGraph", WeightGraphNode, BaseGraphNode); -}; - -/*! - * \brief Managed reference to WeightGraphNode. - * \sa WeightGraphNode - */ -class WeightGraph : public BaseGraph { - public: - /*! - * \brief The constructor based on MSCGraph. - * \param graph The msc graph. - * \param prunable_types The prunable types. - * \param relation_types The relation types. - */ - TVM_DLL WeightGraph(const MSCGraph& graph, - const ffi::Map>& prunable_types, - const ffi::Map& relation_types); - - /*! - * \brief The json constructor. - * \param j_graph The json describe of the graph. - */ - TVM_DLL WeightGraph(const JsonWeightGraph& j_graph); - - /*! - * \brief The json constructor. - * \param json_str The json describe of the graph. - */ - TVM_DLL WeightGraph(const std::string& json_str); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(WeightGraph, BaseGraph, WeightGraphNode); -}; - -MSCGraph PruneWeights(const MSCGraph& graph, - const ffi::Map& pruned_tensors); - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_IR_GRAPH_H_ diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc deleted file mode 100644 index cc713192d68b..000000000000 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ /dev/null @@ -1,873 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/ir/graph_builder.cc - */ - -#include "graph_builder.h" - -#include - -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::relax; - -const std::string GetScalarStr(const runtime::Tensor& data, int float_precision) { - std::string scalar_str; - if (data->dtype.code == kDLFloat) { - const float val = ExprUtils::GetScalar(data); - std::stringstream stream; - stream << std::fixed << std::setprecision(float_precision) << val; - scalar_str = stream.str(); - } else { - const int val = ExprUtils::GetScalar(data); - scalar_str = std::to_string(val); - } - return scalar_str; -} - -void FuncAttrGetter::VisitExpr_(const CallNode* op) { - if (op->attrs.defined()) { - ffi::Map attrs; - AttrGetter getter(&attrs); - getter(op->attrs); - for (const auto& pair : attrs) { - if (attrs_.count(pair.first)) { - int cnt = 1; - ffi::String rep_key = pair.first; - while (attrs_.count(rep_key + "_" + std::to_string(cnt))) { - cnt++; - } - attrs_.Set(pair.first + "_" + std::to_string(cnt), pair.second); - } else { - attrs_.Set(pair.first, pair.second); - } - } - } -} - -void FuncAttrGetter::VisitExpr_(const TupleGetItemNode* op) { - attrs_.Set("index", std::to_string(op->index)); -} - -void FuncValueGetter::VisitExpr_(const CallNode* op) { - for (const auto& arg : op->args) { - if (const auto* s_node = arg.as()) { - values_.push_back(StringUtils::ToString(s_node->value)); - } else if (const auto* s_node = arg.as()) { - bool all_values = std::all_of(s_node->fields.begin(), s_node->fields.end(), - [](const Expr& e) { return e->IsInstance(); }); - if (all_values) { - values_.push_back(StringUtils::ToString(s_node->fields)); - } - } - } -} - -void FuncParamsFinder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - local_funcs_.Set(binding->var, ffi::GetRef(val)); -} - -void FuncParamsFinder::VisitExpr_(const CallNode* call_node) { - ExprVisitor::VisitExpr_(call_node); - Function func; - if (const auto* v_node = call_node->op.as()) { - func = Downcast(ref_module_->Lookup(v_node->name_hint)); - } else if (call_node->op->IsInstance()) { - TVM_FFI_ICHECK(local_funcs_.count(call_node->op)) - << "Can not find local func " << call_node->op; - func = local_funcs_[call_node->op]; - } - if (func.defined()) { - for (size_t i = 0; i < call_node->args.size(); i++) { - const auto& arg = call_node->args[i]; - if (arg->IsInstance() && params_.count(Downcast(arg))) { - params_.Set(func->params[i], params_[Downcast(arg)]); - } else { - params_.Set(func->params[i], arg); - } - } - } -} - -void LayoutsFinder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - local_funcs_.Set(binding->var, ffi::GetRef(val)); -} - -void LayoutsFinder::VisitExpr_(const CallNode* call_node) { - ExprVisitor::VisitExpr_(call_node); - Function func; - if (const auto* v_node = call_node->op.as()) { - func = Downcast(ref_module_->Lookup(v_node->name_hint)); - VisitExpr(func); - } else if (call_node->op->IsInstance()) { - TVM_FFI_ICHECK(local_funcs_.count(call_node->op)) - << "Can not find local func " << call_node->op; - func = local_funcs_[call_node->op]; - } - if (func.defined()) { - const auto& layouts_opt = - func->GetAttr>(msc_attr::kInputLayouts); - if (layouts_opt.defined()) { - for (const auto& pair : layouts_opt.value()) { - layouts_.Set(pair.first, pair.second); - } - } - } -} - -const MSCGraph GraphBuilder::Build(const Function& func) { - // Add input nodes and record inputs; - ffi::Array input_names, output_names; - std::set added_inputs; - // Add prims - for (const auto& p : func->params) { - if (!p->struct_info_.defined()) { - continue; - } - if (p->struct_info_.value()->IsInstance()) { - const auto& shape = ExprUtils::GetShape(p, false); - for (size_t i = 0; i < shape.size(); i++) { - if (shape[i]->IsInstance()) { - ffi::Map attrs; - attrs.Set("producer", p->name_hint()); - attrs.Set("out_idx", "0"); - attrs.Set("dim", std::to_string(i)); - MatchOrCreatePrim(shape[i], "shape", ffi::Array(), attrs); - } - } - } else { - LOG_FATAL << "Unexpected func param " << p << "(" << p->GetTypeKey() << ")"; - } - } - - for (const auto& p : func->params) { - if (expr_tensor_map_.count(p)) { - continue; - } - if (func_params_.count(p) && func_params_[p]->IsInstance()) { - continue; - } - if (func_params_.count(p) && func_params_[p]->IsInstance()) { - const auto& tuple = Downcast(func_params_[p]); - ffi::Array tuple_names; - for (const auto& f : tuple->fields) { - if (expr_tensor_map_.count(f)) { - LOG_INFO << "Replica tuple input " << f; - } else if (const auto* f_node = f.as()) { - AddNode(f, std::nullopt, f_node->name_hint()); - } else { - LOG_FATAL << "Unexpected tuple input " << f << "(" << f->GetTypeKey() << ")"; - } - TVM_FFI_ICHECK(expr_tensor_map_.count(f)) << "Can not find func param from tuple " << f; - for (const auto& name : expr_tensor_map_[f]) { - tuple_names.push_back(name); - } - } - expr_tensor_map_.Set(p, tuple_names); - } else { - AddNode(p, std::nullopt, p->name_hint()); - } - TVM_FFI_ICHECK(expr_tensor_map_.count(p)) << "Can not find func param " << p; - for (const auto& name : expr_tensor_map_[p]) { - if (!added_inputs.count(name)) { - input_names.push_back(name); - added_inputs.insert(name); - } - } - } - VisitExpr(func); - TVM_FFI_ICHECK(expr_tensor_map_.count(func->body->body)) - << "Can not find seqexpr body " << func->body->body; - output_names = expr_tensor_map_[func->body->body]; - // remove const nodes as weights - ffi::Array valid_nodes; - std::set ignore_inputs; - for (const auto& n : nodes_) { - if (weights_.count(n->name) || ignore_nodes_.count(n->name)) { - for (const auto& o : n->outputs) { - ignore_inputs.insert(o->name); - } - } else { - n->index = valid_nodes.size(); - valid_nodes.push_back(n); - if (n->optype != "input") { - for (const auto& o : n->outputs) { - ignore_inputs.insert(o->name); - } - } - } - } - // remove uselese inputs - ffi::Array valid_inputs; - for (const auto& i : input_names) { - if (!ignore_inputs.count(i)) { - valid_inputs.push_back(i); - } - } - // build graph - const auto& graph = MSCGraph(name_, valid_nodes, valid_inputs, output_names, prims_); - // set inputs and outputs alias - if (config_.input_aliases.size() == valid_inputs.size()) { - for (size_t i = 0; i < valid_inputs.size(); i++) { - graph->FindTensor(valid_inputs[i])->alias = config_.input_aliases[i]; - } - } else { - for (size_t i = 0; i < valid_inputs.size(); i++) { - graph->FindTensor(valid_inputs[i])->alias = graph->FindProducer(valid_inputs[i])->name; - } - } - if (config_.output_aliases.size() == output_names.size()) { - for (size_t i = 0; i < output_names.size(); i++) { - graph->FindTensor(output_names[i])->alias = config_.output_aliases[i]; - } - } else { - for (size_t i = 0; i < output_names.size(); i++) { - const auto& output = graph->FindTensor(output_names[i]); - if (output->alias.size() > 0) { - continue; - } - const auto& producer = graph->FindProducer(output_names[i]); - output->alias = producer->outputs.size() == 1 - ? producer->name - : StringUtils::Replace(output_names[i], ":", "_"); - } - } - return graph; -} - -const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional& binding_var, - const ffi::String& name) { - // Get optype, node_name and layout - ffi::String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, msc_attr::kName); - ffi::String optype = "unknown"; - ffi::String layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); - if (func_params_.count(expr) && func_params_[expr]->IsInstance()) { - node_name = SpanUtils::GetAttr(func_params_[expr]->span, msc_attr::kName); - optype = "constant"; - } else if (expr->IsInstance()) { - optype = "input"; - } else if (expr->IsInstance()) { - optype = "constant"; - } else if (expr->IsInstance()) { - optype = "shape"; - } else if (expr->IsInstance()) { - optype = "get_item"; - } else if (expr->IsInstance()) { - optype = "tuple"; - } else if (const auto* call_node = expr.as()) { - if (const auto* op_node = call_node->op.as()) { - if (op_node->name == "relax.call_dps_packed") { - optype = Downcast(call_node->args[0])->global_symbol; - } else { - optype = StringUtils::Replace(op_node->name, "relax.", ""); - } - } else if (const auto* v_node = call_node->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - std::tie(node_name, optype, layout) = ParseFunc(func); - } else if (call_node->op->IsInstance()) { - TVM_FFI_ICHECK(target_funcs_.count(call_node->op)) - << "Can not find target func: " << call_node->op; - std::tie(node_name, optype, layout) = ParseFunc(target_funcs_[call_node->op]); - } else if (call_node->op->IsInstance()) { - std::tie(node_name, optype, layout) = ParseFunc(Downcast(call_node->op)); - } - } - if (layouts_.count(node_name)) { - layout = layouts_[node_name]; - } - - // specail case for tuple - if (optype == "tuple" && expr->IsInstance() && - Downcast(expr)->op->IsInstance()) { - const auto& call_node = Downcast(expr); - TVM_FFI_ICHECK(target_funcs_.count(call_node->op)) - << "Can not find target func: " << call_node->op; - const auto& tuple_func = target_funcs_[call_node->op]; - for (size_t i = 0; i < call_node->args.size(); i++) { - expr_tensor_map_.Set(tuple_func->params[i], expr_tensor_map_[call_node->args[i]]); - } - VisitExpr(tuple_func); - TVM_FFI_ICHECK(expr_tensor_map_.count(tuple_func->body->body)) - << "Can not find seqexpr body " << tuple_func->body->body; - const auto& outputs = expr_tensor_map_[tuple_func->body->body]; - const auto& ref_expr = binding_var.defined() ? binding_var.value() : expr; - expr_tensor_map_.Set(ref_expr, outputs); - TVM_FFI_ICHECK(tensor_input_map_.count(outputs[0])) << "Can not find tensor " << outputs[0]; - return Downcast(tensor_input_map_[outputs[0]].first); - } - - // get plugin - const auto& plugin = IsPlugin(optype) ? GetPlugin(optype) : Plugin(); - - // Extract normal attributes - ffi::Map attrs; - if (plugin.defined()) { - const auto& op = Downcast(expr)->op; - if (target_funcs_.count(op)) { - const auto& opattrs_opt = - target_funcs_[op]->GetAttr>(msc_attr::kOpattrs); - if (opattrs_opt.defined()) { - const auto& opattrs = opattrs_opt.value(); - TVM_FFI_ICHECK_EQ(opattrs.size(), plugin->attrs.size()) - << "opattrs " << opattrs << " size mismatch with " << plugin->attrs.size(); - for (size_t i = 0; i < opattrs.size(); i++) { - attrs.Set(plugin->attrs[i]->name, opattrs[i]); - } - } - } else { - const auto& args = GetPluginInputs(expr); - for (size_t i = 0; i < plugin->attrs.size(); i++) { - const auto& val = args[plugin->inputs.size() + i]; - attrs.Set(plugin->attrs[i]->name, StringUtils::ToString(val)); - } - } - } else if (const auto* call_node = expr.as()) { - if (const auto* v_node = call_node->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - const auto& name_opt = func->GetAttr(relax::attr::kComposite); - if (name_opt.has_value()) { - attrs = FuncAttrGetter().GetAttrs(func); - } - } else if (call_node->op->IsInstance()) { - TVM_FFI_ICHECK(target_funcs_.count(call_node->op)) - << "Can not find target func: " << call_node->op; - attrs = FuncAttrGetter().GetAttrs(target_funcs_[call_node->op]); - } else if (call_node->op->IsInstance()) { - attrs = FuncAttrGetter().GetAttrs(call_node->op); - } else if (call_node->attrs.defined()) { - AttrGetter getter(&attrs); - getter(call_node->attrs); - } - } else if (const auto* const_node = expr.as()) { - if (const_node->is_scalar()) { - attrs.Set("scalar", GetScalarStr(const_node->data, config_.float_precision)); - } - } else if (const auto* shape_node = expr.as()) { - attrs.Set("shape", StringUtils::ToString(shape_node->values)); - } else if (const auto* get_node = expr.as()) { - attrs.Set("index", std::to_string(get_node->index)); - } - - // Extract attributes from arguments - ffi::Array input_types; - if (!plugin.defined() && expr->IsInstance()) { - const auto& call = Downcast(expr); - ffi::Array values; - if (call->op->IsInstance()) { - TVM_FFI_ICHECK(target_funcs_.count(call->op)) << "Can not find target func: " << call->op; - values = FuncValueGetter().GetValues(target_funcs_[call->op]); - } - input_types = ExprUtils::GetInputTypes(optype, call->args.size() + values.size(), true); - for (size_t i = 0; i < call->args.size(); i++) { - const auto& arg = call->args[i]; - if (const auto* s_node = arg.as()) { - attrs.Set(input_types[i], StringUtils::ToString(s_node->values)); - } else if (func_params_.count(arg) && func_params_[arg]->IsInstance()) { - const auto* s_node = func_params_[arg].as(); - attrs.Set(input_types[i], StringUtils::ToString(s_node->values)); - ignore_nodes_.insert(Downcast(arg)->name_hint()); - } else if (const auto* s_node = arg.as()) { - TVM_FFI_ICHECK(input_types[i] != "input") - << i << " th PrimValue of " << optype << " should has special type, get " - << input_types; - attrs.Set(input_types[i], StringUtils::ToString(s_node->value)); - } else if (input_types[i] != "input" && arg->IsInstance()) { - attrs.Set(input_types[i], StringUtils::ToString(arg)); - } - } - for (size_t i = call->args.size(); i < input_types.size(); i++) { - attrs.Set(input_types[i], values[i - call->args.size()]); - } - } - - // Build inputs and weights - ffi::Array input_names; - ffi::Map node_weights; - if (plugin.defined()) { - const auto& call = Downcast(expr); - if (call->args.size() == 1) { - TVM_FFI_ICHECK(expr_tensor_map_.count(call->args[0])) - << "Can not find tuple plugin input " << call->args[0]; - input_names = expr_tensor_map_[call->args[0]]; - } else { - const auto& args = GetPluginInputs(expr); - for (size_t i = 0; i < plugin->inputs.size(); i++) { - TVM_FFI_ICHECK(expr_tensor_map_.count(args[i])) << "Can not find plugin input " << args[i]; - for (const auto& in_name : expr_tensor_map_[args[i]]) { - input_names.push_back(in_name); - } - } - } - } else if (const auto* call_node = expr.as()) { - for (size_t i = 0; i < call_node->args.size(); i++) { - if (attrs.count(input_types[i])) { - continue; - } - const auto& arg = call_node->args[i]; - ffi::Array arg_names; - if (expr_tensor_map_.count(arg)) { - arg_names = expr_tensor_map_[arg]; - } else if (input_types[i] == "input" && arg->IsInstance()) { - const auto* tuple_node = arg.as(); - for (const auto& f : tuple_node->fields) { - TVM_FFI_ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; - for (const auto& in_name : expr_tensor_map_[f]) { - arg_names.push_back(in_name); - } - } - } - ffi::String weight_name; - if (input_types[i] != "input" && arg->IsInstance()) { - weight_name = SpanUtils::GetAttr(arg->span, msc_attr::kName); - } else if (input_types[i] != "input" && func_params_.count(arg) && - func_params_[arg]->IsInstance()) { - weight_name = SpanUtils::GetAttr(func_params_[arg]->span, msc_attr::kName); - ignore_nodes_.insert(Downcast(arg)->name_hint()); - } - // set weights or inputs - if (weight_name.size() > 0) { - const auto& t_name = arg_names[0]; - const auto& pair = tensor_input_map_[t_name]; - const auto& producer = Downcast(pair.first); - if (!weights_.count(weight_name)) { - const auto& ref = producer->OutputAt(pair.second); - MSCTensor weight; - if (input_types[i] == "bias") { - weight = MSCTensor(weight_name, ref->dtype, "O", ffi::Array{ref->GetSize()}); - } else if (input_types[i] == "weight" && - (optype == "msc.linear" || optype == "msc.linear_bias")) { - if (ref->layout.name() == "IO") { - ffi::String valid_layout = ref->layout[1].name() + ref->layout[0].name(); - const auto& valid_shape = ffi::Array({ref->shape[1], ref->shape[0]}); - weight = MSCTensor(weight_name, ref->dtype, valid_layout, valid_shape); - } else { - weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape); - } - } else { - weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape); - } - weights_.Set(weight_name, weight); - } - if (producer->HasAttr("scalar")) { - attrs.Set(input_types[i], producer->GetTypeAttr("scalar")); - } - node_weights.Set(input_types[i], weights_[weight_name]); - } else { - for (const auto& in_name : arg_names) { - input_names.push_back(in_name); - } - } - } - } else if (const auto* tuple_node = expr.as()) { - for (const auto& f : tuple_node->fields) { - TVM_FFI_ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; - for (const auto& in_name : expr_tensor_map_[f]) { - input_names.push_back(in_name); - } - } - } else if (const auto* getitem_node = expr.as()) { - TVM_FFI_ICHECK(expr_tensor_map_.count(getitem_node->tuple)) - << "Can not find tuple " << getitem_node->tuple; - input_names = expr_tensor_map_[getitem_node->tuple]; - } else if (optype == "constant") { - const auto& t_info = Downcast(GetStructInfo(expr)); - const auto& shape_opt = t_info->GetShape(); - TVM_FFI_ICHECK(shape_opt.defined()) << "Constant shape is not defined"; - const auto& weight = - MSCTensor(node_name, t_info->dtype, layout, ArrayUtils::Cast(shape_opt.value())); - node_weights.Set("const", weight); - } - std::vector> inputs; - for (const auto& i : input_names) { - inputs.push_back(tensor_input_map_[i]); - } - - // Redefine layout for special ops - if (optype == "tuple") { - layout = ""; - for (size_t i = 0; i < inputs.size(); i++) { - const auto& in_tensor = Downcast(inputs[i].first)->OutputAt(inputs[i].second); - layout = layout + in_tensor->layout.name(); - layout = layout + (i == inputs.size() - 1 ? "" : ","); - } - } else if (optype == "get_item") { - int idx = std::stoi(attrs["index"]); - const auto& in_tensor = Downcast(inputs[idx].first)->OutputAt(inputs[idx].second); - layout = in_tensor->layout.name(); - } - - // Build output tensor - auto build_output = [this](const StructInfo& sinfo, const ffi::String& node_name, - const ffi::String& layout) { - TVM_FFI_ICHECK(sinfo->IsInstance()) - << "sinfo should be TensorStructInfo, get " << sinfo->GetTypeKey(); - const auto& t_info = Downcast(sinfo); - const auto& shape = ArrayUtils::Cast(ExprUtils::GetShape(t_info)); - ffi::Array prims; - bool has_prims = false; - if (shape.size() > 0) { - for (const auto& s : t_info->GetShape().value()) { - if (prim_map_.count(s)) { - prims.push_back(prim_map_[s]->name); - has_prims = true; - } else { - prims.push_back(StringUtils::ToString(s)); - } - } - } - if (has_prims) { - return MSCTensor(node_name, t_info->dtype, layout, shape, "", prims); - } - return MSCTensor(node_name, t_info->dtype, layout, shape); - }; - - // Gather outputs - ffi::Array outputs; - const auto& sinfo = GetStructInfo(expr); - ffi::Array layouts = StringUtils::Split(layout, ","); - size_t num_output = 1; - if (const auto* tuple_sinfo = sinfo.as()) { - num_output = tuple_sinfo->fields.size(); - } - if (layouts.size() == 0) { - layouts = ffi::Array(num_output, ""); - } - TVM_FFI_ICHECK_EQ(layouts.size(), num_output) - << "Layouts " << layouts << " msimatch with output size " << num_output; - if (sinfo->IsInstance()) { - const auto& t_name = node_name + ":" + std::to_string(0); - outputs.push_back(build_output(sinfo, t_name, layouts[0])); - } else if (const auto* s_sinfo = sinfo.as()) { - ffi::Array shape{s_sinfo->ndim}; - const auto& t_name = node_name + ":" + std::to_string(0); - const auto& dtype = DataType(ffi::StringToDLDataType("int32")); - outputs.push_back(MSCTensor(t_name, dtype, layouts[0], shape)); - } else if (const auto* tuple_sinfo = sinfo.as()) { - size_t field_size = optype == "nn.batch_norm" ? 1 : num_output; - for (size_t i = 0; i < field_size; i++) { - const auto& t_name = node_name + ":" + std::to_string(i); - outputs.push_back(build_output(tuple_sinfo->fields[i], t_name, layouts[i])); - } - } else { - TVM_FFI_THROW(InternalError) << "Unexpected struct info (" << sinfo->GetTypeKey() << ")" - << sinfo; - } - - // Build node - ffi::Array scope; - if (optype != "input" && optype != "constant") { - scope = StringUtils::Split(scope_name_, "."); - } - const auto& shared_ref = SpanUtils::GetAttr(expr->span, msc_attr::kSharedRef); - const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype, attrs, scope, inputs, - outputs, node_weights); - ffi::Array output_names; - for (size_t i = 0; i < outputs.size(); i++) { - output_names.push_back(outputs[i]->name); - tensor_input_map_[outputs[i]->name] = std::make_pair(node, i); - } - nodes_.push_back(node); - const auto& ref_expr = binding_var.defined() ? binding_var.value() : expr; - expr_tensor_map_.Set(ref_expr, output_names); - return node; -} - -void GraphBuilder::VisitBindingBlock(const BindingBlock& block) { - ffi::String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); - if (block_name.size() == 0) { - block_name = "block"; - } - const ffi::String& prefix = StringUtils::Join(block_stack_, "."); - if (setted_blocks_.count(prefix + "." + block_name)) { - int cnt = 1; - while (setted_blocks_.count(prefix + "." + block_name + "_" + std::to_string(cnt))) { - cnt++; - } - block_name = block_name + "_" + std::to_string(cnt); - } - scope_name_ = prefix + "." + block_name; - setted_blocks_.insert(scope_name_); - block_stack_.push_back(block_name); - ExprVisitor::VisitBindingBlock(block); - block_stack_.pop_back(); -} - -#define ADD_BINARY_PRIM(TypeName) \ - if (prim->IsInstance()) { \ - const auto& binary = Downcast(prim); \ - return MatchOrCreatePrim(prim, "", {AddPrim(binary->a), AddPrim(binary->b)}); \ - } - -const MSCPrim GraphBuilder::AddPrim(const PrimExpr& prim) { - if (prim_map_.count(prim)) { - return prim_map_[prim]; - } - - // binary - ADD_BINARY_PRIM(tvm::tir::Add) - ADD_BINARY_PRIM(tvm::tir::Sub) - ADD_BINARY_PRIM(tvm::tir::Mul) - ADD_BINARY_PRIM(tvm::tir::Div) - ADD_BINARY_PRIM(tvm::tir::Mod) - ADD_BINARY_PRIM(tvm::tir::FloorDiv) - ADD_BINARY_PRIM(tvm::tir::FloorMod) - ADD_BINARY_PRIM(tvm::tir::Max) - ADD_BINARY_PRIM(tvm::tir::Min) - - // compare - ADD_BINARY_PRIM(tvm::tir::EQ) - ADD_BINARY_PRIM(tvm::tir::NE) - ADD_BINARY_PRIM(tvm::tir::LT) - ADD_BINARY_PRIM(tvm::tir::LE) - ADD_BINARY_PRIM(tvm::tir::GT) - ADD_BINARY_PRIM(tvm::tir::GE) - - // scalar - if (prim->IsInstance()) { - ffi::Map attrs; - attrs.Set("value", StringUtils::ToString(prim)); - return MatchOrCreatePrim(prim, "Int", ffi::Array(), attrs); - } - - // call - if (const auto* c_node = prim.as()) { - ffi::String optype; - ffi::Array parents; - if (const auto* op_node = c_node->op.as()) { - optype = StringUtils::Replace(op_node->name, "tir.", ""); - } else { - optype = "Prim"; - } - for (const auto& a : c_node->args) { - parents.push_back(AddPrim(a)); - } - return MatchOrCreatePrim(prim, optype, parents); - } - return MatchOrCreatePrim(prim); -} - -const MSCPrim GraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const ffi::String& optype, - const ffi::Array& parents, - const ffi::Map& attrs) { - if (prim_map_.count(prim)) { - return prim_map_[prim]; - } - const auto& op_ = - optype.size() == 0 ? StringUtils::Replace(prim->GetTypeKey(), "tir.", "") : optype; - for (const auto& p : prims_) { - if (p->optype != op_ || p->attrs.size() != attrs.size() || - p->parents.size() != parents.size()) { - continue; - } - bool attrs_match = std::all_of(p->attrs.begin(), p->attrs.end(), [&attrs](const auto& pair) { - return attrs.count(pair.first) && attrs[pair.first] == pair.second; - }); - if (!attrs_match) { - continue; - } - bool parents_match = true; - for (size_t i = 0; i < parents.size(); i++) { - if (p->ParentAt(i)->name != parents[i]->name) { - parents_match = false; - break; - } - } - if (!parents_match) { - continue; - } - prim_map_.Set(prim, p); - return p; - } - ffi::String name; - if (const auto* v_node = prim.as()) { - name = v_node->name_hint; - } else { - name = StringUtils::Upper(op_) + "_" + std::to_string(prims_.size()); - } - const auto& node = MSCPrim(prims_.size(), name, op_, parents, attrs); - prims_.push_back(node); - prim_map_.Set(prim, node); - return node; -} - -void GraphBuilder::VisitExpr_(const ConstantNode* op) { - if (!expr_tensor_map_.count(ffi::GetRef(op))) { - AddNode(ffi::GetRef(op)); - } -} - -void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) { - const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(ffi::GetRef(val), binding->var, name); -} - -void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) { - const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(ffi::GetRef(val), binding->var, name); -} - -void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) { - ExprVisitor::VisitBinding_(binding, call_node); - const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; - try { - AddNode(ffi::GetRef(call_node), binding->var, name); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to add node from " << binding->var << " : " << binding->value - << ", reason: " << err.what(); - throw err; - } -} - -void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { - ExprVisitor::VisitBinding_(binding, val); - const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(ffi::GetRef(val), binding->var, name); -} - -void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { - ExprVisitor::VisitBinding_(binding, val); - const ffi::String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(ffi::GetRef(val), binding->var, name); -} - -void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const VarNode* val) { - ExprVisitor::VisitBinding_(binding, val); - const auto& output = ffi::GetRef(val); - TVM_FFI_ICHECK(expr_tensor_map_.count(output)) << "Can not find var " << output; - expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); -} - -void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) { - ExprVisitor::VisitBinding_(binding, val); - const auto& output = ffi::GetRef(val); - TVM_FFI_ICHECK(expr_tensor_map_.count(output)) << "Can not find dataflow var " << output; - expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); -} - -void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - const auto& name_opt = val->GetAttr(relax::attr::kComposite); - TVM_FFI_ICHECK(name_opt.has_value()) << "Unexpected target func without composite"; - TVM_FFI_ICHECK(config_.target.size() > 0 && - StringUtils::StartsWith(name_opt.value(), config_.target)) - << "Target should be given for target function"; - target_funcs_.Set(binding->var, ffi::GetRef(val)); -} - -const std::tuple GraphBuilder::ParseFunc( - const Function& func) { - ffi::String node_name, optype, layout; - const auto& name_opt = func->GetAttr(msc_attr::kUnique); - // get node_name - if (name_opt.has_value()) { - node_name = name_opt.value(); - } - // get optype - const auto& codegen_opt = func->GetAttr(relax::attr::kCodegen); - const auto& optype_opt = func->GetAttr(msc_attr::kOptype); - const auto& composite_opt = func->GetAttr(relax::attr::kComposite); - if (codegen_opt.has_value()) { - optype = codegen_opt.value(); - } else if (optype_opt.has_value()) { - optype = optype_opt.value(); - } else if (composite_opt.has_value()) { - optype = composite_opt.value(); - if (config_.target.size() > 0) { - optype = StringUtils::Replace(composite_opt.value(), config_.target + ".", ""); - } - } - // get layout - const auto& layout_opt = func->GetAttr(msc_attr::kLayout); - if (layout_opt.has_value()) { - layout = layout_opt.value(); - } - return std::make_tuple(node_name, optype, layout); -} - -void GraphBuilder::VisitPrimExpr(const PrimExpr& prim) { - ExprVisitor::VisitPrimExpr(prim); - if (!prim->IsInstance() && !prim->IsInstance()) { - AddPrim(prim); - } -} - -ffi::Array GraphBuilder::GetPluginInputs(const Expr& expr) { - TVM_FFI_ICHECK(expr->IsInstance()) << "plugin expr should be call"; - const auto& call = Downcast(expr); - TVM_FFI_ICHECK(call->args[1]->IsInstance()) << "plugin argument 1 should be call"; - return Downcast(call->args[1])->fields; -} - -ffi::Map WeightsExtractor::GetWeights(const Function& func) { - VisitExpr(func); - return weights_; -} - -void WeightsExtractor::VisitExpr_(const ConstantNode* op) { - const auto& name = SpanUtils::GetAttr(op->span, msc_attr::kName); - const auto& layout = SpanUtils::GetAttr(op->span, msc_attr::kLayout); - const auto& sinfo = GetStructInfo(ffi::GetRef(op)); - TVM_FFI_ICHECK(sinfo->IsInstance()) - << "Constant StrcutInfo should be TensorStructInfo"; - const auto& t_info = Downcast(sinfo); - const auto& opt_shape = t_info->GetShape(); - const auto& shape = - opt_shape.defined() ? ArrayUtils::Cast(opt_shape.value()) : ffi::Array(); - const auto& weight = MSCTensor(name, t_info->dtype, layout, shape); - weights_.Set(weight, op->data); -} - -void WeightsExtractor::VisitExpr_(const CallNode* op) { - ExprVisitor::VisitExpr_(op); - if (const auto* v_node = op->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - VisitExpr(func); - } -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("msc.core.BuildFromRelax", - [](const IRModule& module, const ffi::String& entry_name, - const ffi::String& options) -> MSCGraph { - auto builder = GraphBuilder(module, entry_name, options); - const auto& func_name = builder.config().byoc_entry.size() > 0 - ? ffi::String(builder.config().byoc_entry) - : entry_name; - const auto& func = Downcast(module->Lookup(func_name)); - return builder.Build(func); - }) - .def( - "msc.core.GetRelaxWeights", - [](const IRModule& module, const ffi::String& entry_name) -> ffi::Map { - const auto& func = Downcast(module->Lookup(entry_name)); - return WeightsExtractor(module).GetWeights(func); - }); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h deleted file mode 100644 index 536623488a21..000000000000 --- a/src/contrib/msc/core/ir/graph_builder.h +++ /dev/null @@ -1,395 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/ir/graph_builder.h - * \brief Builder of MSCGraph. - */ -#ifndef TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_ -#define TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "../utils.h" -#include "graph.h" -#include "plugin.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::relax; - -using Expr = tvm::RelaxExpr; -using tvm::runtime::Tensor; - -/*! - * \brief Config for building MSCGraph. - * Define the configuration for building MSCGraph - */ -struct MSCRBuildConfig { - bool prune_graph{false}; - bool use_var_name{false}; - int float_precision = 6; - std::string byoc_entry; - std::string sort_by; - std::string target = ""; - std::string graph_name = ""; - std::vector input_aliases; - std::vector output_aliases; - std::unordered_map> input_types; - - void LoadInputTypes(ffi::json::Object obj) { - namespace json = ::tvm::ffi::json; - for (const auto& kv : obj) { - auto arr = kv.second.cast(); - std::vector vec; - vec.reserve(arr.size()); - for (const auto& elem : arr) { - vec.push_back(std::string(elem.cast())); - } - input_types[std::string(kv.first.cast())] = std::move(vec); - } - } - - void Load(ffi::json::Object obj) { - namespace json = ::tvm::ffi::json; - if (auto it = obj.find(ffi::String("prune_graph")); it != obj.end()) { - prune_graph = (*it).second.cast(); - } - if (auto it = obj.find(ffi::String("use_var_name")); it != obj.end()) { - use_var_name = (*it).second.cast(); - } - if (auto it = obj.find(ffi::String("float_precision")); it != obj.end()) { - float_precision = static_cast((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("byoc_entry")); it != obj.end()) { - byoc_entry = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("sort_by")); it != obj.end()) { - sort_by = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("target")); it != obj.end()) { - target = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("graph_name")); it != obj.end()) { - graph_name = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("input_aliases")); it != obj.end()) { - auto arr = (*it).second.cast(); - input_aliases.clear(); - input_aliases.reserve(arr.size()); - for (const auto& elem : arr) { - input_aliases.push_back(std::string(elem.cast())); - } - } - if (auto it = obj.find(ffi::String("output_aliases")); it != obj.end()) { - auto arr = (*it).second.cast(); - output_aliases.clear(); - output_aliases.reserve(arr.size()); - for (const auto& elem : arr) { - output_aliases.push_back(std::string(elem.cast())); - } - } - if (auto it = obj.find(ffi::String("input_types")); it != obj.end()) { - LoadInputTypes((*it).second.cast()); - } - } -}; - -class AttrGetter { - public: - /*! - * \brief Get the attributes as ffi::Map - * \param attrs the attributes. - */ - explicit AttrGetter(ffi::Map* attrs) : attrs_(attrs) {} - - void operator()(const Attrs& attrs) { - if (const auto* dict_attrs = attrs.as()) { - for (const auto& [key, value] : dict_attrs->dict) { - this->VisitAny(key, value); - } - } else { - const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index()); - if (attrs_tinfo->metadata != nullptr) { - tvm::ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) { - Any field_value = tvm::ffi::reflection::FieldGetter(field_info)(attrs); - this->VisitAny(ffi::String(field_info->name), field_value); - }); - } - } - } - - private: - void VisitAny(ffi::String key, Any value) { - switch (value.type_index()) { - case kTVMFFINone: { - attrs_->Set(key, ""); - break; - } - case kTVMFFIBool: { - attrs_->Set(key, std::to_string(value.cast())); - break; - } - case kTVMFFIInt: { - attrs_->Set(key, std::to_string(value.cast())); - break; - } - case kTVMFFIFloat: { - attrs_->Set(key, std::to_string(value.cast())); - break; - } - case kTVMFFIDataType: { - attrs_->Set(key, runtime::DLDataTypeToString(value.cast())); - break; - } - case kTVMFFISmallStr: - case kTVMFFIStr: { - attrs_->Set(key, value.cast()); - break; - } - default: { - if (value.type_index() >= kTVMFFIStaticObjectBegin) { - attrs_->Set(key, StringUtils::ToString(value.cast())); - } else { - TVM_FFI_THROW(InternalError) << "Unsupported type: " << value.type_index(); - } - break; - } - } - } - - private: - ffi::Map* attrs_; -}; - -class FuncAttrGetter : public ExprVisitor { - public: - /*! \brief Get the attributes as ffi::Map*/ - ffi::Map GetAttrs(const Expr& expr) { - VisitExpr(expr); - return attrs_; - } - - void VisitExpr_(const CallNode* op) final; - - void VisitExpr_(const TupleGetItemNode* op) final; - - private: - ffi::Map attrs_; -}; - -class FuncValueGetter : public ExprVisitor { - public: - /*! \brief Get the attributes from prim value as ffi::Map*/ - ffi::Array GetValues(const Expr& expr) { - VisitExpr(expr); - return values_; - } - - void VisitExpr_(const CallNode* op) final; - - private: - ffi::Array values_; -}; - -class FuncParamsFinder : public ExprVisitor { - public: - /*! - * \brief The constructor of FuncParamsFinder - * \param ref_module the reference module. - */ - explicit FuncParamsFinder(const IRModule& ref_module) : ExprVisitor() { - ref_module_ = ref_module; - } - - /*! \brief Find the func params and bind with arguments*/ - ffi::Map FindParams(const Expr& expr) { - VisitExpr(expr); - return params_; - } - - void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final; - - void VisitExpr_(const CallNode* op) final; - - private: - IRModule ref_module_; - ffi::Map params_; - ffi::Map local_funcs_; -}; - -class LayoutsFinder : public ExprVisitor { - public: - /*! - * \brief The constructor of LayoutsFinder - * \param ref_module the reference module. - */ - explicit LayoutsFinder(const IRModule& ref_module) : ExprVisitor() { ref_module_ = ref_module; } - - /*! \brief Find the layouts form attrs*/ - ffi::Map FindLayouts(const Expr& expr) { - VisitExpr(expr); - return layouts_; - } - - void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final; - - void VisitExpr_(const CallNode* op) final; - - private: - IRModule ref_module_; - ffi::Map layouts_; - ffi::Map local_funcs_; -}; - -class GraphBuilder : public ExprVisitor { - public: - /*! - * \brief The constructor of GraphBuilder - * \param ref_module the reference module. - * \param name the name of the graph. - * \param options the options of build the graph. - */ - explicit GraphBuilder(const IRModule& ref_module, const ffi::String& name, - const std::string& options = "") - : ExprVisitor() { - ref_module_ = ref_module; - if (options.size() > 0) { - namespace json = ::tvm::ffi::json; - config_.Load(json::Parse(options).cast()); - } - name_ = config_.graph_name.size() > 0 ? ffi::String(config_.graph_name) : name; - if (config_.byoc_entry.size() > 0) { - func_params_ = FuncParamsFinder(ref_module).FindParams(ref_module->Lookup(name)); - } - layouts_ = LayoutsFinder(ref_module).FindLayouts(ref_module->Lookup(name)); - } - - /*! \brief Build MSCGraph from relax function*/ - const MSCGraph Build(const Function& func); - - /*! \brief Get the config of builder */ - const MSCRBuildConfig config() { return config_; } - - /*! \brief Create and add MSCJoint from expr*/ - const MSCJoint AddNode(const Expr& expr, const ffi::Optional& binding_var = std::nullopt, - const ffi::String& name = ""); - - /*! \brief Create and add MSCPrim from prim*/ - const MSCPrim AddPrim(const PrimExpr& prim); - - const MSCPrim MatchOrCreatePrim( - const PrimExpr& prim, const ffi::String& op = "", - const ffi::Array& parents = ffi::Array(), - const ffi::Map& attrs = ffi::Map()); - - void VisitBindingBlock(const BindingBlock& block) final; - - void VisitExpr_(const ConstantNode* op) final; - - void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) final; - - void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) final; - - void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final; - - void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final; - - void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final; - - void VisitBinding_(const VarBindingNode* binding, const VarNode* val) final; - - void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) final; - - void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final; - - void VisitPrimExpr(const PrimExpr& prim) final; - - private: - /*! \brief Get the node_name, optype, layout for func*/ - const std::tuple ParseFunc(const Function& func); - - /*! \brief Get the plugin inputs*/ - ffi::Array GetPluginInputs(const Expr& expr); - - ffi::String name_; - IRModule ref_module_; - MSCRBuildConfig config_; - ffi::Map layouts_; - ffi::Array nodes_; - ffi::Map weights_; - ffi::Map> expr_tensor_map_; - std::unordered_map> tensor_input_map_; - std::set ignore_nodes_; - // scope name - ffi::String scope_name_; - std::set setted_blocks_; - ffi::Array block_stack_; - // BYOC maps - ffi::Map target_funcs_; - ffi::Map func_params_; - // prims - ffi::Array prims_; - ffi::Map prim_map_; -}; - -class WeightsExtractor : public ExprVisitor { - public: - /*! - * \brief The constructor of GraphBuilder - * \param ref_module the reference module. - * \param name the name of the graph. - * \param options the options of build the graph. - */ - explicit WeightsExtractor(const IRModule& ref_module) : ExprVisitor() { - ref_module_ = ref_module; - } - - /*! \brief Visit the constant and save weights */ - ffi::Map GetWeights(const Function& func); - - void VisitExpr_(const ConstantNode* op) final; - - void VisitExpr_(const CallNode* op) final; - - private: - ffi::Map weights_; - ffi::Map local_funcs_; - IRModule ref_module_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_ diff --git a/src/contrib/msc/core/ir/plugin.cc b/src/contrib/msc/core/ir/plugin.cc deleted file mode 100644 index e0f816f5ca36..000000000000 --- a/src/contrib/msc/core/ir/plugin.cc +++ /dev/null @@ -1,332 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/ir/plugin.cc - */ - -#include "plugin.h" - -#include - -#include -#include -#include -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -namespace json = ::tvm::ffi::json; - -PluginAttr::PluginAttr(const ffi::String& name, const ffi::String& type, - const ffi::String& default_value, const ffi::String& describe) { - ObjectPtr n = ffi::make_object(); - n->name = std::move(name); - n->type = std::move(type); - n->default_value = std::move(default_value); - n->describe = std::move(describe); - data_ = std::move(n); -} - -PluginAttr::PluginAttr(const JsonPluginAttr& j_attr) { - ObjectPtr n = ffi::make_object(); - n->FromJson(j_attr); - data_ = std::move(n); -} - -PluginAttr::PluginAttr(const std::string& json_str) { - ObjectPtr n = ffi::make_object(); - n->FromJson(json_str); - data_ = std::move(n); -} - -const JsonPluginAttr PluginAttrNode::ToJson() const { - JsonPluginAttr j_attr; - j_attr.name = name; - j_attr.type = type; - j_attr.default_value = default_value; - j_attr.describe = describe; - return j_attr; -} - -void PluginAttrNode::FromJson(const JsonPluginAttr& j_attr) { - name = j_attr.name; - type = j_attr.type; - default_value = j_attr.default_value; - describe = j_attr.describe; -} - -void PluginAttrNode::FromJson(const std::string& json_str) { - auto parsed = json::Parse(json_str); - JsonPluginAttr j_attr; - j_attr.Load(parsed.cast()); - FromJson(j_attr); -} - -PluginTensor::PluginTensor(const ffi::String& name, const ffi::String& dtype, const Integer& ndim, - const ffi::String& device, const ffi::String& describe) { - ObjectPtr n = ffi::make_object(); - n->name = std::move(name); - n->dtype = std::move(dtype); - n->ndim = std::move(ndim); - n->device = std::move(device); - n->describe = std::move(describe); - data_ = std::move(n); -} - -PluginTensor::PluginTensor(const JsonPluginTensor& j_tensor) { - ObjectPtr n = ffi::make_object(); - n->FromJson(j_tensor); - data_ = std::move(n); -} - -PluginTensor::PluginTensor(const std::string& json_str) { - ObjectPtr n = ffi::make_object(); - n->FromJson(json_str); - data_ = std::move(n); -} - -const JsonPluginTensor PluginTensorNode::ToJson() const { - JsonPluginTensor j_tensor; - j_tensor.name = name; - j_tensor.dtype = dtype; - j_tensor.ndim = ndim->value; - j_tensor.device = device; - j_tensor.describe = describe; - return j_tensor; -} - -void PluginTensorNode::FromJson(const JsonPluginTensor& j_tensor) { - name = j_tensor.name; - dtype = j_tensor.dtype; - ndim = Integer(j_tensor.ndim); - device = j_tensor.device; - describe = j_tensor.describe; -} - -void PluginTensorNode::FromJson(const std::string& json_str) { - auto parsed = json::Parse(json_str); - JsonPluginTensor j_tensor; - j_tensor.Load(parsed.cast()); - FromJson(j_tensor); -} - -PluginExtern::PluginExtern(const ffi::String& name, const ffi::String& header, - const ffi::String& source, const ffi::String& lib, - const ffi::String& describe) { - ObjectPtr n = ffi::make_object(); - n->name = std::move(name); - n->header = std::move(header); - n->source = std::move(source); - n->lib = std::move(lib); - n->describe = std::move(describe); - data_ = std::move(n); -} - -PluginExtern::PluginExtern(const JsonPluginExtern& j_extern) { - ObjectPtr n = ffi::make_object(); - n->FromJson(j_extern); - data_ = std::move(n); -} - -PluginExtern::PluginExtern(const std::string& json_str) { - ObjectPtr n = ffi::make_object(); - n->FromJson(json_str); - data_ = std::move(n); -} - -const JsonPluginExtern PluginExternNode::ToJson() const { - JsonPluginExtern j_extern; - j_extern.name = name; - j_extern.header = header; - j_extern.source = source; - j_extern.lib = lib; - j_extern.describe = describe; - return j_extern; -} - -void PluginExternNode::FromJson(const JsonPluginExtern& j_extern) { - name = j_extern.name; - header = j_extern.header; - source = j_extern.source; - lib = j_extern.lib; - describe = j_extern.describe; -} - -void PluginExternNode::FromJson(const std::string& json_str) { - auto parsed = json::Parse(json_str); - JsonPluginExtern j_extern; - j_extern.Load(parsed.cast()); - FromJson(j_extern); -} - -Plugin::Plugin(const ffi::String& name, const ffi::String& version, const ffi::String& describe, - const ffi::Array& attrs, const ffi::Array& inputs, - const ffi::Array& outputs, const ffi::Array& buffers, - const ffi::Map& externs, - const ffi::Map>& support_dtypes, - const ffi::Map& options) { - ObjectPtr n = ffi::make_object(); - n->name = std::move(name); - n->version = std::move(version); - n->describe = std::move(describe); - n->attrs = std::move(attrs); - n->inputs = std::move(inputs); - n->outputs = std::move(outputs); - n->buffers = std::move(buffers); - n->externs = std::move(externs); - n->support_dtypes = std::move(support_dtypes); - n->options = std::move(options); - data_ = std::move(n); -} - -Plugin::Plugin(const JsonPlugin& j_plugin) { - ObjectPtr n = ffi::make_object(); - n->FromJson(j_plugin); - data_ = std::move(n); -} - -Plugin::Plugin(const std::string& json_str) { - ObjectPtr n = ffi::make_object(); - n->FromJson(json_str); - data_ = std::move(n); -} - -const JsonPlugin PluginNode::ToJson() const { - JsonPlugin j_plugin; - j_plugin.name = name; - j_plugin.version = version; - j_plugin.describe = describe; - for (const auto& a : attrs) { - j_plugin.attrs.push_back(a->ToJson()); - } - for (const auto& t : inputs) { - j_plugin.inputs.push_back(t->ToJson()); - } - for (const auto& t : outputs) { - j_plugin.inputs.push_back(t->ToJson()); - } - for (const auto& t : buffers) { - j_plugin.inputs.push_back(t->ToJson()); - } - for (const auto& pair : externs) { - j_plugin.externs[pair.first] = pair.second->ToJson(); - } - for (const auto& pair : support_dtypes) { - std::vector dtypes; - for (const auto& d : pair.second) { - dtypes.push_back(d); - } - j_plugin.support_dtypes[pair.first] = dtypes; - } - for (const auto& pair : options) { - j_plugin.options[pair.first] = pair.second; - } - return j_plugin; -} - -void PluginNode::FromJson(const JsonPlugin& j_plugin) { - name = j_plugin.name; - version = j_plugin.version; - describe = j_plugin.describe; - for (const auto& a : j_plugin.attrs) { - attrs.push_back(PluginAttr(a)); - } - for (const auto& t : j_plugin.inputs) { - inputs.push_back(PluginTensor(t)); - } - for (const auto& t : j_plugin.outputs) { - outputs.push_back(PluginTensor(t)); - } - for (const auto& t : j_plugin.buffers) { - buffers.push_back(PluginTensor(t)); - } - for (const auto& pair : j_plugin.externs) { - externs.Set(pair.first, PluginExtern(pair.second)); - } - for (const auto& pair : j_plugin.support_dtypes) { - ffi::Array dtypes; - for (const auto& d : pair.second) { - dtypes.push_back(d); - } - support_dtypes.Set(pair.first, dtypes); - } - for (const auto& pair : j_plugin.options) { - options.Set(pair.first, pair.second); - } -} - -void PluginNode::FromJson(const std::string& json_str) { - auto parsed = json::Parse(json_str); - JsonPlugin j_plugin; - j_plugin.Load(parsed.cast()); - FromJson(j_plugin); -} - -int PluginNode::FindDtypeRefIdx(const PluginTensor& tensor) const { - for (size_t i = 0; i < inputs.size(); i++) { - if (inputs[i]->dtype == tensor->dtype) { - return i; - } - } - return -1; -} - -int PluginNode::FindDeviceRefIdx(const PluginTensor& tensor) const { - for (size_t i = 0; i < inputs.size(); i++) { - if (inputs[i]->device == tensor->device) { - return i; - } - } - return -1; -} - -const ffi::Array ListPluginNames() { return PluginRegistry::Global()->ListAllNames(); } - -const Plugin GetPlugin(const ffi::String& name) { return PluginRegistry::Global()->Get(name); } - -bool IsPlugin(const ffi::String& name) { return PluginRegistry::Global()->Registered(name); } - -TVM_FFI_STATIC_INIT_BLOCK() { - PluginAttrNode::RegisterReflection(); - PluginTensorNode::RegisterReflection(); - PluginExternNode::RegisterReflection(); - PluginNode::RegisterReflection(); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("msc.core.RegisterPlugin", - [](const ffi::String& name, const ffi::String& json_str) { - PluginRegistry::Global()->Register(name, json_str); - }) - .def("msc.core.ListPluginNames", - []() -> ffi::Array { return ListPluginNames(); }) - .def("msc.core.GetPlugin", [](const ffi::String& name) -> Plugin { return GetPlugin(name); }) - .def("msc.core.IsPlugin", - [](const ffi::String& name) -> Bool { return Bool(IsPlugin(name)); }); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/ir/plugin.h b/src/contrib/msc/core/ir/plugin.h deleted file mode 100644 index 46b41d4fe6f2..000000000000 --- a/src/contrib/msc/core/ir/plugin.h +++ /dev/null @@ -1,739 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/ir/plugin.h - * \brief Plugin describe for msc. - */ -#ifndef TVM_CONTRIB_MSC_CORE_IR_PLUGIN_H_ -#define TVM_CONTRIB_MSC_CORE_IR_PLUGIN_H_ - -#include -#include -#include - -#include -#include -#include - -#include "../../../../ir/attr_registry.h" -#include "../utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief Json serialize and deserialize for Plugin Attribute. - */ -struct JsonPluginAttr { - std::string name; - std::string type; - std::string default_value; - std::string describe; - - ffi::json::Value SaveToJSON() const { - ffi::json::Object obj; - obj.Set(ffi::String("name"), ffi::String(name)); - obj.Set(ffi::String("type"), ffi::String(type)); - obj.Set(ffi::String("default_value"), ffi::String(default_value)); - obj.Set(ffi::String("describe"), ffi::String(describe)); - return obj; - } - - void Load(ffi::json::Object obj) { - int bitmask = 0; - if (auto it = obj.find(ffi::String("name")); it != obj.end()) { - name = std::string((*it).second.cast()); - bitmask |= 1; - } - if (auto it = obj.find(ffi::String("type")); it != obj.end()) { - type = std::string((*it).second.cast()); - bitmask |= 2; - } - if (auto it = obj.find(ffi::String("default_value")); it != obj.end()) { - default_value = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("describe")); it != obj.end()) { - describe = std::string((*it).second.cast()); - } - TVM_FFI_ICHECK_EQ(bitmask, 1 | 2) << "name and type should be given for plugin attr"; - if (describe.size() == 0) { - describe = "Plugin attribute " + name + "(" + type + ")"; - } - } -}; - -/*! - * \brief Json serialize and deserialize for Plugin Tensor. - */ -struct JsonPluginTensor { - std::string name; - std::string dtype; - int64_t ndim{-1}; - std::string device{"default"}; - std::string describe; - - ffi::json::Value SaveToJSON() const { - ffi::json::Object obj; - obj.Set(ffi::String("name"), ffi::String(name)); - obj.Set(ffi::String("dtype"), ffi::String(dtype)); - obj.Set(ffi::String("ndim"), ndim); - obj.Set(ffi::String("device"), ffi::String(device)); - obj.Set(ffi::String("describe"), ffi::String(describe)); - return obj; - } - - void Load(ffi::json::Object obj) { - int bitmask = 0; - if (auto it = obj.find(ffi::String("name")); it != obj.end()) { - name = std::string((*it).second.cast()); - bitmask |= 1; - } - if (auto it = obj.find(ffi::String("dtype")); it != obj.end()) { - dtype = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("ndim")); it != obj.end()) { - ndim = (*it).second.cast(); - } - if (auto it = obj.find(ffi::String("device")); it != obj.end()) { - device = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("describe")); it != obj.end()) { - describe = std::string((*it).second.cast()); - } - TVM_FFI_ICHECK_EQ(bitmask, 1) << "name should be given for plugin tensor"; - if (describe.size() == 0) { - describe = "Plugin tensor " + name + "(" + dtype + " on " + device + ")"; - } - } -}; - -/*! - * \brief Json serialize and deserialize for Plugin Extern. - */ -struct JsonPluginExtern { - std::string name; - std::string header; - std::string source; - std::string lib; - std::string describe; - - ffi::json::Value SaveToJSON() const { - ffi::json::Object obj; - obj.Set(ffi::String("name"), ffi::String(name)); - obj.Set(ffi::String("header"), ffi::String(header)); - obj.Set(ffi::String("source"), ffi::String(source)); - obj.Set(ffi::String("lib"), ffi::String(lib)); - obj.Set(ffi::String("describe"), ffi::String(describe)); - return obj; - } - - void Load(ffi::json::Object obj) { - int bitmask = 0; - if (auto it = obj.find(ffi::String("name")); it != obj.end()) { - name = std::string((*it).second.cast()); - bitmask |= 1; - } - if (auto it = obj.find(ffi::String("header")); it != obj.end()) { - header = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("source")); it != obj.end()) { - source = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("lib")); it != obj.end()) { - lib = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("describe")); it != obj.end()) { - describe = std::string((*it).second.cast()); - } - TVM_FFI_ICHECK_EQ(bitmask, 1) << "name should be given for plugin extern"; - if (describe.size() == 0) { - describe = "Plugin function " + name + "(from " + header + ")"; - } - } -}; - -/*! - * \brief Json serialize and deserialize for Plugin. - */ -struct JsonPlugin { - std::string name; - std::string version; - std::string describe; - std::vector attrs; - std::vector inputs; - std::vector outputs; - std::vector buffers; - std::unordered_map externs; - std::unordered_map> support_dtypes; - std::unordered_map options; - - ffi::json::Value SaveToJSON() const { - ffi::json::Object obj; - obj.Set(ffi::String("name"), ffi::String(name)); - obj.Set(ffi::String("version"), ffi::String(version)); - obj.Set(ffi::String("describe"), ffi::String(describe)); - { - ffi::json::Array arr; - for (const auto& item : attrs) { - arr.push_back(item.SaveToJSON()); - } - obj.Set(ffi::String("attrs"), std::move(arr)); - } - { - ffi::json::Array arr; - for (const auto& item : inputs) { - arr.push_back(item.SaveToJSON()); - } - obj.Set(ffi::String("inputs"), std::move(arr)); - } - { - ffi::json::Array arr; - for (const auto& item : outputs) { - arr.push_back(item.SaveToJSON()); - } - obj.Set(ffi::String("outputs"), std::move(arr)); - } - { - ffi::json::Array arr; - for (const auto& item : buffers) { - arr.push_back(item.SaveToJSON()); - } - obj.Set(ffi::String("buffers"), std::move(arr)); - } - { - ffi::json::Object inner; - for (const auto& kv : externs) { - inner.Set(ffi::String(kv.first), kv.second.SaveToJSON()); - } - obj.Set(ffi::String("externs"), std::move(inner)); - } - // support_dtypes: map> - { - ffi::json::Object sd_obj; - for (const auto& kv : support_dtypes) { - ffi::json::Array arr; - for (const auto& s : kv.second) { - arr.push_back(ffi::String(s)); - } - sd_obj.Set(ffi::String(kv.first), std::move(arr)); - } - obj.Set(ffi::String("support_dtypes"), std::move(sd_obj)); - } - { - ffi::json::Object inner; - for (const auto& kv : options) { - inner.Set(ffi::String(kv.first), ffi::String(kv.second)); - } - obj.Set(ffi::String("options"), std::move(inner)); - } - return obj; - } - - void Load(ffi::json::Object obj) { - int bitmask = 0; - if (auto it = obj.find(ffi::String("name")); it != obj.end()) { - name = std::string((*it).second.cast()); - bitmask |= 1; - } - if (auto it = obj.find(ffi::String("version")); it != obj.end()) { - version = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("describe")); it != obj.end()) { - describe = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("attrs")); it != obj.end()) { - auto arr = (*it).second.cast(); - attrs.clear(); - attrs.reserve(arr.size()); - for (const auto& elem : arr) { - JsonPluginAttr item; - item.Load(elem.cast()); - attrs.push_back(std::move(item)); - } - } - if (auto it = obj.find(ffi::String("inputs")); it != obj.end()) { - auto arr = (*it).second.cast(); - inputs.clear(); - inputs.reserve(arr.size()); - for (const auto& elem : arr) { - JsonPluginTensor item; - item.Load(elem.cast()); - inputs.push_back(std::move(item)); - } - bitmask |= 2; - } - if (auto it = obj.find(ffi::String("outputs")); it != obj.end()) { - auto arr = (*it).second.cast(); - outputs.clear(); - outputs.reserve(arr.size()); - for (const auto& elem : arr) { - JsonPluginTensor item; - item.Load(elem.cast()); - outputs.push_back(std::move(item)); - } - bitmask |= 4; - } - if (auto it = obj.find(ffi::String("buffers")); it != obj.end()) { - auto arr = (*it).second.cast(); - buffers.clear(); - buffers.reserve(arr.size()); - for (const auto& elem : arr) { - JsonPluginTensor item; - item.Load(elem.cast()); - buffers.push_back(std::move(item)); - } - } - if (auto it = obj.find(ffi::String("externs")); it != obj.end()) { - auto inner = (*it).second.cast(); - externs.clear(); - for (const auto& kv : inner) { - JsonPluginExtern item; - item.Load(kv.second.cast()); - externs[std::string(kv.first.cast())] = std::move(item); - } - } - // support_dtypes: map> - { - auto it = obj.find(ffi::String("support_dtypes")); - if (it != obj.end()) { - auto inner = (*it).second.cast(); - support_dtypes.clear(); - for (const auto& kv : inner) { - auto arr = kv.second.cast(); - std::vector vec; - vec.reserve(arr.size()); - for (const auto& elem : arr) { - vec.push_back(std::string(elem.cast())); - } - support_dtypes[std::string(kv.first.cast())] = std::move(vec); - } - } - } - if (auto it = obj.find(ffi::String("options")); it != obj.end()) { - auto inner = (*it).second.cast(); - options.clear(); - for (const auto& kv : inner) { - options[std::string(kv.first.cast())] = - std::string(kv.second.cast()); - } - } - TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4) << "name, inputs and outputs should be given for plugin"; - if (externs.size() > 0) { - TVM_FFI_ICHECK(externs.count("infer_output")) << "infer_output should be given as extern"; - bool has_compute = false; - for (const auto& pair : externs) { - if (StringUtils::EndsWith(pair.first, "_compute")) { - has_compute = true; - } - } - TVM_FFI_ICHECK(has_compute) << "No compute function found, please check"; - } - if (describe.size() == 0) { - describe = "Plugin " + name + "(" + version + ")"; - } - } -}; - -/*! - * \brief Attribute in Plugin. - */ -class PluginAttrNode : public Object { - public: - /*! \brief The name of attribute. */ - ffi::String name; - /*! \brief The type of attribute. */ - ffi::String type; - /*! \brief The default_value of attribute. */ - ffi::String default_value; - /*! \brief The describe of attribute. */ - ffi::String describe; - - /*! \brief Export attribute to json. */ - const JsonPluginAttr ToJson() const; - /*! \brief Load attribute from json struct. */ - void FromJson(const JsonPluginAttr& j_attr); - /*! \brief Load attribute from json string. */ - void FromJson(const std::string& json_str); - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("name", &PluginAttrNode::name) - .def_ro("type", &PluginAttrNode::type) - .def_ro("default_value", &PluginAttrNode::default_value) - .def_ro("describe", &PluginAttrNode::describe); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.PluginAttr", PluginAttrNode, Object); -}; - -/*! - * \brief Managed reference to PluginAttrNode. - * \sa PluginAttrNode - */ -class PluginAttr : public ObjectRef { - public: - /*! - * \brief The constructor. - * \param name The name of the attribute. - * \param type The type of the attribute. - * \param default_value The default_value of the attribute. - * \param describe The describe of the attribute. - */ - TVM_DLL PluginAttr(const ffi::String& name, const ffi::String& type, - const ffi::String& default_value, const ffi::String& describe); - - /*! - * \brief The json constructor. - * \param j_attr The json describe of the attribute. - */ - TVM_DLL PluginAttr(const JsonPluginAttr& j_attr); - - /*! - * \brief The json constructor. - * \param json_str The json describe of the attribute. - */ - TVM_DLL PluginAttr(const std::string& json_str); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PluginAttr, ObjectRef, PluginAttrNode); -}; - -/*! - * \brief Tensor in Plugin. - */ -class PluginTensorNode : public Object { - public: - /*! \brief The name of tensor. */ - ffi::String name; - /*! \brief The dtype of tensor. */ - ffi::String dtype; - /*! \brief The ndim of tensor. */ - Integer ndim; - /*! \brief The device of tensor. */ - ffi::String device; - /*! \brief The describe of tensor. */ - ffi::String describe; - - /*! \brief Export tensor to json. */ - const JsonPluginTensor ToJson() const; - /*! \brief Load tensor from json struct. */ - void FromJson(const JsonPluginTensor& j_attr); - /*! \brief Load tensor from json string. */ - void FromJson(const std::string& json_str); - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("name", &PluginTensorNode::name) - .def_ro("dtype", &PluginTensorNode::dtype) - .def_ro("ndim", &PluginTensorNode::ndim) - .def_ro("device", &PluginTensorNode::device) - .def_ro("describe", &PluginTensorNode::describe); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.PluginTensor", PluginTensorNode, Object); -}; - -/*! - * \brief Managed reference to PluginTensorNode. - * \sa PluginTensorNode - */ -class PluginTensor : public ObjectRef { - public: - /*! - * \brief The constructor. - * \param name The name of the tensor. - * \param dtype The dtype of the tensor. - * \param ndim The ndim of the tensor. - * \param device The device of the tensor. - * \param describe The describe of the tensor. - */ - TVM_DLL PluginTensor(const ffi::String& name, const ffi::String& dtype, const Integer& ndim, - const ffi::String& device, const ffi::String& describe); - - /*! - * \brief The json constructor. - * \param j_tensor The json describe of the tensor. - */ - TVM_DLL PluginTensor(const JsonPluginTensor& j_tensor); - - /*! - * \brief The json constructor. - * \param json_str The json describe of the tensor. - */ - TVM_DLL PluginTensor(const std::string& json_str); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PluginTensor, ObjectRef, PluginTensorNode); -}; - -/*! - * \brief Extern symbol in Plugin. - */ -class PluginExternNode : public Object { - public: - /*! \brief The name of extern. */ - ffi::String name; - /*! \brief The header of extern. */ - ffi::String header; - /*! \brief The source of extern. */ - ffi::String source; - /*! \brief The lib of extern. */ - ffi::String lib; - /*! \brief The describe of extern. */ - ffi::String describe; - - /*! \brief Export extern to json. */ - const JsonPluginExtern ToJson() const; - /*! \brief Load extern from json struct. */ - void FromJson(const JsonPluginExtern& j_attr); - /*! \brief Load extern from json string. */ - void FromJson(const std::string& json_str); - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("name", &PluginExternNode::name) - .def_ro("header", &PluginExternNode::header) - .def_ro("source", &PluginExternNode::source) - .def_ro("lib", &PluginExternNode::lib) - .def_ro("describe", &PluginExternNode::describe); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.PluginExtern", PluginExternNode, Object); -}; - -/*! - * \brief Managed reference to PluginExternNode. - * \sa PluginExternNode - */ -class PluginExtern : public ObjectRef { - public: - /*! - * \brief The constructor. - * \param name The name of the extern. - * \param header The header of the extern. - * \param source The source of the extern. - * \param lib The lib of the extern. - * \param describe The describe of the extern. - */ - TVM_DLL PluginExtern(const ffi::String& name, const ffi::String& header, - const ffi::String& source, const ffi::String& lib, - const ffi::String& describe); - - /*! - * \brief The json constructor. - * \param j_extern The json describe of the extern. - */ - TVM_DLL PluginExtern(const JsonPluginExtern& j_extern); - - /*! - * \brief The json constructor. - * \param json_str The json describe of the extern. - */ - TVM_DLL PluginExtern(const std::string& json_str); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PluginExtern, ObjectRef, PluginExternNode); -}; - -/*! - * \brief The Plugin in MSC. - */ -class PluginNode : public Object { - public: - /*! \brief The name of plugin. */ - ffi::String name; - /*! \brief The version of plugin. */ - ffi::String version; - /*! \brief The describe of plugin. */ - ffi::String describe; - /*! \brief The attributes of plugin. */ - ffi::Array attrs; - /*! \brief The inputs of plugin. */ - ffi::Array inputs; - /*! \brief The outputs of plugin. */ - ffi::Array outputs; - /*! \brief The buffers of plugin. */ - ffi::Array buffers; - /*! \brief The externs of plugin. */ - ffi::Map externs; - /*! \brief The support_dtypes of plugin. */ - ffi::Map> support_dtypes; - /*! \brief The options of plugin. */ - ffi::Map options; - - /*! \brief Export plugin to json. */ - const JsonPlugin ToJson() const; - /*! \brief Load plugin from json struct. */ - void FromJson(const JsonPlugin& j_attr); - /*! \brief Load plugin from json string. */ - void FromJson(const std::string& json_str); - - /*! \brief Find input ref index for dtype. */ - int FindDtypeRefIdx(const PluginTensor& tensor) const; - /*! \brief Find input ref index for device. */ - int FindDeviceRefIdx(const PluginTensor& tensor) const; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("name", &PluginNode::name) - .def_ro("version", &PluginNode::version) - .def_ro("describe", &PluginNode::describe) - .def_ro("attrs", &PluginNode::attrs) - .def_ro("inputs", &PluginNode::inputs) - .def_ro("outputs", &PluginNode::outputs) - .def_ro("buffers", &PluginNode::buffers) - .def_ro("externs", &PluginNode::externs) - .def_ro("support_dtypes", &PluginNode::support_dtypes) - .def_ro("options", &PluginNode::options); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.core.Plugin", PluginNode, Object); -}; - -/*! - * \brief Managed reference to PluginNode. - * \sa PluginNode - */ -class Plugin : public ObjectRef { - public: - /*! - * \brief The constructor. - * \param name The name of the plugin. - * \param version The version of the plugin. - * \param describe The describe of the plugin. - * \param attrs The attrs of the plugin. - * \param inputs The inputs of the plugin. - * \param outputs The outputs of the plugin. - * \param buffers The buffers of the plugin. - * \param externs The externs of the plugin. - * \param support_dtypes The support_dtypes of the plugin. - * \param options The options of the plugin. - */ - TVM_DLL Plugin(const ffi::String& name, const ffi::String& version, const ffi::String& describe, - const ffi::Array& attrs, const ffi::Array& inputs, - const ffi::Array& outputs, const ffi::Array& buffers, - const ffi::Map& externs, - const ffi::Map>& support_dtypes, - const ffi::Map& options); - - /*! - * \brief The json constructor. - * \param j_plugin The json describe of the plugin. - */ - TVM_DLL Plugin(const JsonPlugin& j_plugin); - - /*! - * \brief The json constructor. - * \param json_str The json describe of the plugin. - */ - TVM_DLL Plugin(const std::string& json_str); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Plugin, ObjectRef, PluginNode); -}; - -class PluginRegistry { - public: - /*! - * \brief Register a new plugin. - * \param name The name of the item. - * \param json_str The json_str. - * \return The corresponding entry. - */ - bool Register(const ffi::String& name, const ffi::String& json_str) { - plugin_map_[name] = Plugin(json_str); - return true; - } - - /*! - * \brief Check if an plugin is registered. - * \param name The name of the item. - * \return Whether the plugin is registered. - */ - bool Registered(const ffi::String& name) const { - auto it = plugin_map_.find(name); - return it != plugin_map_.end(); - } - - /*! - * \brief Get an plugin from the registry. - * \param name The name of the item. - * \return The corresponding plugin. - */ - const Plugin Get(const ffi::String& name) const { - auto it = plugin_map_.find(name); - TVM_FFI_ICHECK(it != plugin_map_.end()) << "Can not find plugin " << name; - return it->second; - } - - /*! - * \brief List all the plugin names in the registry. - * \return The plugin names. - */ - ffi::Array ListAllNames() const { - ffi::Array names; - for (const auto& kv : plugin_map_) { - names.push_back(kv.first); - } - return names; - } - - /*! - * \return a global singleton of the registry. - */ - static PluginRegistry* Global() { - static PluginRegistry* inst = new PluginRegistry(); - return inst; - } - - private: - // map from name to plugins. - std::unordered_map plugin_map_; -}; - -/*! - * \brief List all plugin names. - * \return the corresponding plugin names. - */ -const ffi::Array ListPluginNames(); - -/*! - * \brief Get the registered plugin. - * \param name The name of the Plugin. - * \return the corresponding plugin. - */ -const Plugin GetPlugin(const ffi::String& name); - -/*! - * \brief Check if an plugin is registered. - * \param name The name of the item. - * \return Whether the plugin is registered. - */ -bool IsPlugin(const ffi::String& name); - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_IR_PLUGIN_H_ diff --git a/src/contrib/msc/core/printer/cpp_printer.cc b/src/contrib/msc/core/printer/cpp_printer.cc deleted file mode 100644 index 54de66638c06..000000000000 --- a/src/contrib/msc/core/printer/cpp_printer.cc +++ /dev/null @@ -1,363 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/cpp_printer.cc - */ - -#include "cpp_printer.h" - -namespace tvm { -namespace contrib { -namespace msc { - -void CppPrinter::PrintTypedDoc(const LiteralDoc& doc) { - const ffi::Any& value = doc->value; - bool defined = false; - if (value == nullptr) { - output_ << "nullptr"; - defined = true; - } else if (const auto* int_imm = value.as()) { - if (int_imm->dtype.is_bool()) { - output_ << (int_imm->value ? "true" : "false"); - defined = true; - } - } - if (!defined) { - MSCBasePrinter::PrintTypedDoc(doc); - } -} - -void CppPrinter::PrintTypedDoc(const IndexDoc& doc) { - TVM_FFI_ICHECK(doc->indices.size() == 1) << "CppPrinter only support 1 size indices"; - PrintDoc(doc->value, false); - output_ << "["; - PrintDoc(doc->indices[0], false); - output_ << "]"; -} - -void CppPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { - PrintDoc(doc->value, false); - if (StringUtils::EndsWith(doc->name, DocSymbol::NextLine())) { - const auto& v_name = StringUtils::Replace(doc->name, DocSymbol::NextLine(), ""); - if (!doc->value->IsInstance()) { - IncreaseIndent(); - PrintDoc(IdDoc(".")); - DecreaseIndent(); - } - output_ << v_name; - } else { - if (!doc->value->IsInstance()) { - output_ << "."; - } - output_ << doc->name; - } -} - -void CppPrinter::PrintTypedDoc(const CallDoc& doc) { - EnterEndlineScope(false); - PrintDoc(doc->callee, false); - output_ << "("; - PrintJoinedDocs(doc->args); - TVM_FFI_ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) - << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values."; - if (doc->args.size() > 0 && doc->kwargs_keys.size() > 0) { - output_ << ", "; - } - PrintJoinedDocs(doc->kwargs_values); - output_ << ")"; - ExitEndlineScope(); - Endline(); -} - -void CppPrinter::PrintTypedDoc(const AssignDoc& doc) { - TVM_FFI_ICHECK(doc->lhs.defined()) << "lhs should be given for assign"; - if (doc->annotation.defined()) { - if (!IsEmptyDoc(doc->annotation.value())) { - PrintDoc(doc->annotation.value(), false); - output_ << " "; - } - } - PrintDoc(doc->lhs, false); - if (doc->rhs.defined()) { - output_ << " = "; - EnterEndlineScope(false); - PrintDoc(doc->rhs.value(), false); - ExitEndlineScope(); - Endline(); - } -} - -void CppPrinter::PrintTypedDoc(const IfDoc& doc) { - MaybePrintComment(doc, true); - output_ << "if ("; - PrintDoc(doc->predicate, false); - output_ << ") {"; - PrintIndentedBlock(doc->then_branch); - if (!doc->else_branch.empty()) { - NewLine(); - output_ << "} else {"; - PrintIndentedBlock(doc->else_branch); - } - NewLine(); - output_ << "}"; -} - -void CppPrinter::PrintTypedDoc(const WhileDoc& doc) { - MaybePrintComment(doc, true); - output_ << "while ("; - PrintDoc(doc->predicate, false); - output_ << ") {"; - PrintIndentedBlock(doc->body); - NewLine(); - output_ << "}"; -} - -void CppPrinter::PrintTypedDoc(const ForDoc& doc) { - MaybePrintComment(doc, true); - if (doc->rhs->IsInstance()) { - const auto& tuple = Downcast(doc->rhs); - TVM_FFI_ICHECK_EQ(tuple->elements.size(), 2) << "For with tuple should has 2 elements"; - output_ << "for (size_t "; - PrintDoc(doc->lhs, false); - output_ << " = "; - PrintDoc(tuple->elements[0], false); - output_ << "; "; - PrintDoc(doc->lhs, false); - output_ << " < "; - PrintDoc(tuple->elements[1], false); - output_ << "; "; - PrintDoc(doc->lhs, false); - output_ << "++"; - } else { - output_ << "for (const auto& "; - PrintDoc(doc->lhs, false); - output_ << " : "; - PrintDoc(doc->rhs, false); - } - output_ << ") {"; - PrintIndentedBlock(doc->body); - NewLine(); - output_ << "}"; -} - -void CppPrinter::PrintTypedDoc(const ScopeDoc& doc) { - MaybePrintComment(doc, true); - TVM_FFI_ICHECK(doc->rhs.defined()) << "rhs should be given for scope"; - PrintDoc(doc->rhs, false); - PrintIndentedBlock(doc->body); -} - -void CppPrinter::PrintTypedDoc(const FunctionDoc& doc) { - MaybePrintComment(doc, true); - for (const AssignDoc& arg_doc : doc->args) { - TVM_FFI_ICHECK(!arg_doc->comment.has_value()) - << "Function arg cannot have comment attached to them."; - } - if (doc->return_type.defined()) { - if (!IsEmptyDoc(doc->return_type.value())) { - PrintDoc(doc->return_type.value(), false); - output_ << " "; - } - } else { - output_ << "void "; - } - PrintDoc(doc->name, false); - output_ << "("; - PrintJoinedDocs(doc->args, ", "); - output_ << ")"; - if (doc->decorators.size() > 0) { - output_ << " "; - PrintJoinedDocs(doc->decorators, " "); - } - if (doc->body.size() > 0) { - output_ << " {"; - PrintIndentedBlock(doc->body); - if (doc->return_type.defined()) { - if (!IsEmptyDoc(doc->return_type.value())) { - Endline(); - } - } - NewLine(); - output_ << "}"; - } else { - Endline(); - } - NewLine(false); -} - -void CppPrinter::PrintTypedDoc(const ClassDoc& doc) { - MaybePrintComment(doc, true); - output_ << "class "; - PrintDoc(doc->name, false); - output_ << " {"; - for (const StmtDoc& d : doc->body) { - PrintDoc(d); - } - output_ << "}"; - Endline(); - output_ << " // class "; - PrintDoc(doc->name, false); - NewLine(false); -} - -void CppPrinter::PrintTypedDoc(const CommentDoc& doc) { - if (doc->comment.has_value()) { - output_ << "// " << doc->comment.value(); - } -} - -void CppPrinter::PrintTypedDoc(const DeclareDoc& doc) { - if (doc->type.defined()) { - PrintDoc(doc->type.value(), false); - output_ << " "; - } - PrintDoc(doc->variable, false); - if (doc->init_args.size() > 0) { - if (doc->use_constructor) { - output_ << "("; - PrintJoinedDocs(doc->init_args, ", "); - output_ << ")"; - } else { - output_ << "{"; - PrintJoinedDocs(doc->init_args, ", "); - output_ << "}"; - } - } - Endline(); -} - -void CppPrinter::PrintTypedDoc(const PointerDoc& doc) { output_ << doc->name << "->"; } - -void CppPrinter::PrintTypedDoc(const StrictListDoc& doc) { - if (doc->allow_empty || doc->list->elements.size() > 0) { - PrintDoc(doc->list, false); - } else { - output_ << "{}"; - } -} - -void CppPrinter::PrintTypedDoc(const StructDoc& doc) { - MaybePrintComment(doc, true); - output_ << "struct "; - PrintDoc(doc->name, false); - output_ << " {"; - IncreaseIndent(); - for (const StmtDoc& d : doc->body) { - PrintDoc(d); - } - DecreaseIndent(); - NewLine(false); - output_ << "}"; - Endline(); - output_ << " // struct "; - PrintDoc(doc->name, false); - NewLine(false); -} - -void CppPrinter::PrintTypedDoc(const ConstructorDoc& doc) { - MaybePrintComment(doc, true); - for (const AssignDoc& arg_doc : doc->args) { - TVM_FFI_ICHECK(!arg_doc->comment.has_value()) - << "Constructor arg cannot have comment attached to them."; - } - PrintDoc(doc->name, false); - output_ << "("; - PrintJoinedDocs(doc->args, ", "); - output_ << ")"; - if (doc->body.size() > 0) { - output_ << " {"; - PrintIndentedBlock(doc->body); - NewLine(); - output_ << "}"; - } else { - Endline(); - } - NewLine(false); -} - -void CppPrinter::PrintTypedDoc(const LambdaDoc& doc) { - MaybePrintComment(doc, true); - for (const AssignDoc& arg_doc : doc->args) { - TVM_FFI_ICHECK(!arg_doc->comment.has_value()) - << "Function arg cannot have comment attached to them."; - } - output_ << "auto "; - PrintDoc(doc->name, false); - output_ << " = ["; - PrintJoinedDocs(doc->refs, ", "); - output_ << "]("; - PrintJoinedDocs(doc->args, ", "); - output_ << ")"; - if (doc->body.size() > 0) { - output_ << " {"; - PrintIndentedBlock(doc->body); - Endline(); - NewLine(); - output_ << "};"; - } else { - Endline(); - } - NewLine(false); -} - -void CppPrinter::PrintTypedDoc(const SwitchDoc& doc) { - MaybePrintComment(doc, true); - TVM_FFI_ICHECK_EQ(doc->predicates.size(), doc->branchs.size()) - << "predicates " << doc->predicates.size() << " mismatch with branchs " - << doc->branchs.size(); - for (size_t i = 0; i < doc->predicates.size(); i++) { - if (i == 0) { - output_ << "if ("; - } else { - NewLine(); - output_ << "} else if ("; - } - PrintDoc(doc->predicates[i], false); - output_ << ") {"; - PrintIndentedBlock(doc->branchs[i]); - } - if (!doc->default_branch.empty()) { - NewLine(); - output_ << "} else {"; - PrintIndentedBlock(doc->default_branch); - } - NewLine(); - output_ << "}"; -} - -bool CppPrinter::IsEmptyDoc(const ExprDoc& doc) { - if (!doc->IsInstance()) { - return false; - } - const auto& id_doc = Downcast(doc); - return id_doc->name == DocSymbol::Empty(); -} - -void CppPrinter::PrintIndentedBlock(const ffi::Array& docs) { - IncreaseIndent(); - for (const StmtDoc& d : docs) { - PrintDoc(d); - } - DecreaseIndent(); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/printer/cpp_printer.h b/src/contrib/msc/core/printer/cpp_printer.h deleted file mode 100644 index fa55b13ddcb9..000000000000 --- a/src/contrib/msc/core/printer/cpp_printer.h +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/cpp_printer.h - * \brief Cpp Printer. - */ - -#ifndef TVM_CONTRIB_MSC_CORE_PRINTER_CPP_PRINTER_H_ -#define TVM_CONTRIB_MSC_CORE_PRINTER_CPP_PRINTER_H_ - -#include -#include - -#include "../utils.h" -#include "msc_base_printer.h" -#include "print_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -/*! - * \brief CppPrinter change list of docs to cpp format - * \sa Doc - */ -class CppPrinter : public MSCBasePrinter { - public: - /*! - * \brief The constructor of PythonPrinter - * \param options the options for printer. - */ - explicit CppPrinter(const std::string& options = "") : MSCBasePrinter(options) { - endlines_.push_back(true); - } - - protected: - /*! * \brief Print a LiteralDoc to cpp format*/ - void PrintTypedDoc(const LiteralDoc& doc) final; - - /*! * \brief Print a IndexDoc to cpp format*/ - void PrintTypedDoc(const IndexDoc& doc) final; - - /*! * \brief Print a AttrAccessDoc to cpp format*/ - void PrintTypedDoc(const AttrAccessDoc& doc) final; - - /*! * \brief Print a CallDoc to cpp format*/ - void PrintTypedDoc(const CallDoc& doc) final; - - /*! * \brief Print a AssignDoc to cpp format*/ - void PrintTypedDoc(const AssignDoc& doc) final; - - /*! * \brief Print a IfDoc to cpp format*/ - void PrintTypedDoc(const IfDoc& doc) final; - - /*! * \brief Print a WhileDoc to cpp format*/ - void PrintTypedDoc(const WhileDoc& doc) final; - - /*! * \brief Print a ForDoc to cpp format*/ - void PrintTypedDoc(const ForDoc& doc) final; - - /*! * \brief Print a ScopeDoc to cpp format*/ - void PrintTypedDoc(const ScopeDoc& doc) final; - - /*! * \brief Print a FunctionDoc to cpp format*/ - void PrintTypedDoc(const FunctionDoc& doc) final; - - /*! * \brief Print a ClassDoc to cpp format*/ - void PrintTypedDoc(const ClassDoc& doc) final; - - /*! * \brief Print a CommentDoc to cpp format*/ - void PrintTypedDoc(const CommentDoc& doc) final; - - /*! * \brief Print a DeclareDoc to cpp format*/ - void PrintTypedDoc(const DeclareDoc& doc) final; - - /*! * \brief Print a PointerDoc to cpp format*/ - void PrintTypedDoc(const PointerDoc& doc) final; - - /*! * \brief Print a StrictListDoc to cpp format*/ - void PrintTypedDoc(const StrictListDoc& doc) final; - - /*! * \brief Print a StructDoc to cpp format*/ - void PrintTypedDoc(const StructDoc& doc) final; - - /*! * \brief Print a ConstructorDoc to cpp format*/ - void PrintTypedDoc(const ConstructorDoc& doc) final; - - /*! * \brief Print a LambdaDoc to cpp format*/ - void PrintTypedDoc(const LambdaDoc& doc) final; - - /*! * \brief Print a SwitchDoc to cpp format*/ - void PrintTypedDoc(const SwitchDoc& doc) final; - - private: - /*! \brief endline scopes*/ - std::vector endlines_; - - /*! \brief Enter a endline scope*/ - void EnterEndlineScope(bool endline = false) { endlines_.push_back(endline); } - - /*! \brief Exit a endline scope*/ - void ExitEndlineScope() { - TVM_FFI_ICHECK(endlines_.size() > 1) << "No endline scope found"; - endlines_.pop_back(); - } - - /*! \brief enable enbline*/ - void EnableEndline() { - TVM_FFI_ICHECK(endlines_.size() > 0) << "No endline scope found"; - endlines_[endlines_.size() - 1] = true; - } - - /*! \brief disable enbline*/ - void DisableEndline() { - TVM_FFI_ICHECK(endlines_.size() > 0) << "No endline scope found"; - endlines_[endlines_.size() - 1] = false; - } - - /*! \brief Print endline*/ - void Endline() { - TVM_FFI_ICHECK(endlines_.size() > 0) << "No endline scope found"; - if (endlines_[endlines_.size() - 1]) { - output_ << ";"; - } - } - - /*! \brief Check if the doc is empty doc*/ - bool IsEmptyDoc(const ExprDoc& doc); - - /*! \brief Print block with indent*/ - void PrintIndentedBlock(const ffi::Array& docs); -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_CORE_PRINTER_CPP_PRINTER_H_ diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc deleted file mode 100644 index aeecd79750ff..000000000000 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/msc_base_printer.cc - */ - -#include "msc_base_printer.h" - -#include -#include - -#include "../utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -void MSCBasePrinter::PrintDoc(const Doc& doc, bool new_line) { - if (new_line) { - NewLine(); - lines_++; - } - if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else if (auto doc_node = doc.as()) { - PrintTypedDoc(doc_node.value()); - } else { - TVM_FFI_THROW(InternalError) << "Do not know how to print " << doc->GetTypeKey(); - throw; - } -} - -void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { - const Any& value = doc->value; - if (value == nullptr) { - output_ << "\"\""; - } else if (auto opt_int_imm = value.as()) { - IntImm int_imm = *std::move(opt_int_imm); - output_ << int_imm->value; - } else if (auto opt_float_imm = value.as()) { - FloatImm float_imm = *std::move(opt_float_imm); - output_.precision(config_.float_precision); - if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { - output_ << '"' << float_imm->value << '"'; - } else { - output_ << float_imm->value; - } - } else if (auto opt_str = value.as()) { - output_ << "\"" << tvm::support::StrEscape((*opt_str).data(), (*opt_str).size()) << "\""; - } else { - TVM_FFI_THROW(TypeError) << "Unsupported literal value type: " << value.GetTypeKey(); - } -} - -void MSCBasePrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; } - -void MSCBasePrinter::PrintTypedDoc(const ListDoc& doc) { - output_ << "["; - PrintJoinedDocs(doc->elements); - output_ << "]"; -} - -void MSCBasePrinter::PrintTypedDoc(const TupleDoc& doc) { - output_ << "("; - if (doc->elements.size() == 1) { - PrintDoc(doc->elements[0]); - output_ << ","; - } else { - PrintJoinedDocs(doc->elements); - } - output_ << ")"; -} - -void MSCBasePrinter::PrintTypedDoc(const ReturnDoc& doc) { - output_ << "return "; - PrintDoc(doc->value, false); - MaybePrintComment(doc); -} - -void MSCBasePrinter::PrintTypedDoc(const StmtBlockDoc& doc) { - for (const StmtDoc& stmt : doc->stmts) { - NewLine(); - PrintDoc(stmt); - } -} - -void MSCBasePrinter::PrintTypedDoc(const ExprStmtDoc& doc) { - PrintDoc(doc->expr, false); - MaybePrintComment(doc); -} - -void MSCBasePrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) { - if (stmt->comment.has_value()) { - if (multi_lines) { - for (const auto& l : StringUtils::Split(stmt->comment.value(), "\n")) { - PrintDoc(CommentDoc(l)); - } - } else { - PrintDoc(CommentDoc(stmt->comment.value()), false); - } - } -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/printer/msc_base_printer.h b/src/contrib/msc/core/printer/msc_base_printer.h deleted file mode 100644 index 048eb25f8c90..000000000000 --- a/src/contrib/msc/core/printer/msc_base_printer.h +++ /dev/null @@ -1,287 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/msc_base_printer.h - * \brief Base Printer for all MSC printers. - */ -#ifndef TVM_CONTRIB_MSC_CORE_PRINTER_MSC_BASE_PRINTER_H_ -#define TVM_CONTRIB_MSC_CORE_PRINTER_MSC_BASE_PRINTER_H_ - -#include -#include - -#include - -#include "../../../../../src/support/str_escape.h" -#include "msc_doc.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -/*! - * \brief MSCPrinterConfig is base for config class in MSC - * \sa Doc - */ -struct MSCPrinterConfig { - size_t indent{0}; - size_t float_precision{6}; - std::string indent_space{" "}; - std::string separator{", "}; - void Load(ffi::json::Object obj) { - if (auto it = obj.find(ffi::String("indent")); it != obj.end()) { - indent = static_cast((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("float_precision")); it != obj.end()) { - float_precision = static_cast((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("indent_space")); it != obj.end()) { - indent_space = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("separator")); it != obj.end()) { - separator = std::string((*it).second.cast()); - } - } -}; - -/*! - * \brief MSCBasePrinter is responsible for printing Doc tree into text format - * \sa Doc - */ -class MSCBasePrinter { - public: - /*! - * \brief The constructor of MSCBasePrinter - * \param options the options for printer. - */ - explicit MSCBasePrinter(const std::string& options = "") { - if (options.size() > 0) { - namespace json = ::tvm::ffi::json; - config_.Load(json::Parse(options).cast()); - } - indent_ = config_.indent; - } - - virtual ~MSCBasePrinter() = default; - - /*! - * \brief Append a doc into the final content - * \sa GetString - */ - void Append(const Doc& doc, bool new_line = true) { PrintDoc(doc, new_line); } - - /*! - * \brief Get the printed string of all Doc appended - * \sa Append - */ - ffi::String GetString() const { return output_.str(); } - - protected: - /*! \brief Print doc*/ - void PrintDoc(const Doc& doc, bool new_line = true); - - /*! \brief Virtual method to print a LiteralDoc*/ - virtual void PrintTypedDoc(const LiteralDoc& doc); - - /*! \brief Virtual method to print an IdDoc*/ - virtual void PrintTypedDoc(const IdDoc& doc); - - /*! \brief Virtual method to print a ListDoc*/ - virtual void PrintTypedDoc(const ListDoc& doc); - - /*! \brief Virtual method to print a TupleDoc*/ - virtual void PrintTypedDoc(const TupleDoc& doc); - - /*! \brief Virtual method to print a ReturnDoc*/ - virtual void PrintTypedDoc(const ReturnDoc& doc); - - /*! \brief Virtual method to print a StmtBlockDoc*/ - virtual void PrintTypedDoc(const StmtBlockDoc& doc); - - /*! \brief Virtual method to print a ExprStmtDoc*/ - virtual void PrintTypedDoc(const ExprStmtDoc& doc); - - /*! \brief Virtual method to print an IndexDoc*/ - virtual void PrintTypedDoc(const IndexDoc& doc) { - TVM_FFI_THROW(InternalError) << "Index is not implemented"; - } - - /*! \brief Virtual method to print a CallDoc*/ - virtual void PrintTypedDoc(const CallDoc& doc) { - TVM_FFI_THROW(InternalError) << "Call is not implemented"; - } - - /*! \brief Virtual method to print an AttrAccessDoc*/ - virtual void PrintTypedDoc(const AttrAccessDoc& doc) { - TVM_FFI_THROW(InternalError) << "AttrAccess is not implemented"; - } - - /*! \brief Virtual method to print a DictDoc*/ - virtual void PrintTypedDoc(const DictDoc& doc) { - TVM_FFI_THROW(InternalError) << "Dict is not implemented"; - } - - /*! \brief Virtual method to print a SliceDoc*/ - virtual void PrintTypedDoc(const SliceDoc& doc) { - TVM_FFI_THROW(InternalError) << "Slice is not implemented"; - } - - /*! \brief Virtual method to print an AssignDoc*/ - virtual void PrintTypedDoc(const AssignDoc& doc) { - TVM_FFI_THROW(InternalError) << "Assign is not implemented"; - } - - /*! \brief Virtual method to print an IfDoc*/ - virtual void PrintTypedDoc(const IfDoc& doc) { - TVM_FFI_THROW(InternalError) << "If is not implemented"; - } - - /*! \brief Virtual method to print a WhileDoc*/ - virtual void PrintTypedDoc(const WhileDoc& doc) { - TVM_FFI_THROW(InternalError) << "While is not implemented"; - } - - /*! \brief Virtual method to print a ForDoc*/ - virtual void PrintTypedDoc(const ForDoc& doc) { - TVM_FFI_THROW(InternalError) << "For is not implemented"; - } - - /*! \brief Virtual method to print a ScopeDoc*/ - virtual void PrintTypedDoc(const ScopeDoc& doc) { - TVM_FFI_THROW(InternalError) << "Scope is not implemented"; - } - - /*! \brief Virtual method to print an AssertDoc*/ - virtual void PrintTypedDoc(const AssertDoc& doc) { - TVM_FFI_THROW(InternalError) << "Assert is not implemented"; - } - - /*! \brief Virtual method to print a FunctionDoc*/ - virtual void PrintTypedDoc(const FunctionDoc& doc) { - TVM_FFI_THROW(InternalError) << "Function is not implemented"; - } - - /*! \brief Virtual method to print a ClassDoc*/ - virtual void PrintTypedDoc(const ClassDoc& doc) { - TVM_FFI_THROW(InternalError) << "Class is not implemented"; - } - - /*! \brief Virtual method to print a CommentDoc*/ - virtual void PrintTypedDoc(const CommentDoc& doc) { - TVM_FFI_THROW(InternalError) << "Comment is not implemented"; - } - - /*! \brief Virtual method to print a DeclareDoc*/ - virtual void PrintTypedDoc(const DeclareDoc& doc) { - TVM_FFI_THROW(InternalError) << "Declare is not implemented"; - } - - /*! \brief Virtual method to print a StrictListDoc*/ - virtual void PrintTypedDoc(const StrictListDoc& doc) { - TVM_FFI_THROW(InternalError) << "StrictList is not implemented"; - } - - /*! \brief Virtual method to print a PointerDoc*/ - virtual void PrintTypedDoc(const PointerDoc& doc) { - TVM_FFI_THROW(InternalError) << "PointerDoc is not implemented"; - } - - /*! \brief Virtual method to print a StructDoc*/ - virtual void PrintTypedDoc(const StructDoc& doc) { - TVM_FFI_THROW(InternalError) << "StructDoc is not implemented"; - } - - /*! \brief Virtual method to print a ConstructorDoc*/ - virtual void PrintTypedDoc(const ConstructorDoc& doc) { - TVM_FFI_THROW(InternalError) << "ConstructorDoc is not implemented"; - } - - /*! \brief Virtual method to print a SwitchDoc*/ - virtual void PrintTypedDoc(const SwitchDoc& doc) { - TVM_FFI_THROW(InternalError) << "SwitchDoc is not implemented"; - } - - /*! \brief Virtual method to print a LambdaDoc*/ - virtual void PrintTypedDoc(const LambdaDoc& doc) { - TVM_FFI_THROW(InternalError) << "LambdaDoc is not implemented"; - } - - /*! \brief Print docs to joined doc */ - template - void PrintJoinedDocs(const ffi::Array& docs, const ffi::String& separator = ", ") { - for (size_t i = 0; i < docs.size(); i++) { - PrintDoc(docs[i], false); - output_ << (i == docs.size() - 1 ? "" : separator); - } - } - - /*! \brief Print comment for stmt*/ - virtual void MaybePrintComment(const StmtDoc& stmt, bool multi_lines = false); - - /*! - * \brief Start line into the output stream. - * \sa output_ - */ - std::ostream& NewLine(bool with_indent = true) { - if (lines_ > 0) { - output_ << "\n"; - } - if (with_indent) { - for (size_t i = 0; i < indent_; i++) { - output_ << config_.indent_space; - } - } - return output_; - } - - /*! \brief Increase the indent level*/ - void IncreaseIndent() { indent_ += 1; } - - /*! \brief Decrease the indent level*/ - void DecreaseIndent() { - if (indent_ >= 1) { - indent_ -= 1; - } - } - - /*! \brief Get the output stream*/ - const MSCPrinterConfig config() { return config_; } - - /*! \brief The output stream of printer*/ - std::ostringstream output_; - - private: - /*! \brief The current level of indent */ - size_t indent_ = 0; - - /*! \brief The lines num */ - size_t lines_ = 0; - - /*! \brief The config for printer */ - MSCPrinterConfig config_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_CORE_PRINTER_MSC_BASE_PRINTER_H_ diff --git a/src/contrib/msc/core/printer/msc_doc.cc b/src/contrib/msc/core/printer/msc_doc.cc deleted file mode 100644 index e1cae35be132..000000000000 --- a/src/contrib/msc/core/printer/msc_doc.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/msc_doc.cc - */ - -#include "msc_doc.h" - -#include - -namespace tvm { -namespace contrib { -namespace msc { - -DeclareDoc::DeclareDoc(ffi::Optional type, ExprDoc variable, ffi::Array init_args, - bool use_constructor) { - ObjectPtr n = ffi::make_object(); - n->type = type; - n->variable = variable; - n->init_args = init_args; - n->use_constructor = use_constructor; - this->data_ = std::move(n); -} - -StrictListDoc::StrictListDoc(ListDoc list, bool allow_empty) { - ObjectPtr n = ffi::make_object(); - n->list = list; - n->allow_empty = allow_empty; - this->data_ = std::move(n); -} - -PointerDoc::PointerDoc(ffi::String name) { - ObjectPtr n = ffi::make_object(); - n->name = name; - this->data_ = std::move(n); -} - -StructDoc::StructDoc(IdDoc name, ffi::Array decorators, ffi::Array body) { - ObjectPtr n = ffi::make_object(); - n->name = name; - n->decorators = decorators; - n->body = body; - this->data_ = std::move(n); -} - -ConstructorDoc::ConstructorDoc(IdDoc name, ffi::Array args, ffi::Array body) { - ObjectPtr n = ffi::make_object(); - n->name = name; - n->args = args; - n->body = body; - this->data_ = std::move(n); -} - -SwitchDoc::SwitchDoc(ffi::Array predicates, ffi::Array> branchs, - ffi::Array default_branch) { - ObjectPtr n = ffi::make_object(); - n->predicates = predicates; - n->branchs = branchs; - n->default_branch = default_branch; - this->data_ = std::move(n); -} - -LambdaDoc::LambdaDoc(IdDoc name, ffi::Array args, ffi::Array refs, - ffi::Array body) { - ObjectPtr n = ffi::make_object(); - n->name = name; - n->args = args; - n->refs = refs; - n->body = body; - this->data_ = std::move(n); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - DeclareDocNode::RegisterReflection(); - StrictListDocNode::RegisterReflection(); - PointerDocNode::RegisterReflection(); - StructDocNode::RegisterReflection(); - ConstructorDocNode::RegisterReflection(); - SwitchDocNode::RegisterReflection(); - LambdaDocNode::RegisterReflection(); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/printer/msc_doc.h b/src/contrib/msc/core/printer/msc_doc.h deleted file mode 100644 index fe0f6c68338f..000000000000 --- a/src/contrib/msc/core/printer/msc_doc.h +++ /dev/null @@ -1,340 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/msc_doc.h - * \brief Extra docs for MSC - */ -#ifndef TVM_CONTRIB_MSC_CORE_PRINTER_MSC_DOC_H_ -#define TVM_CONTRIB_MSC_CORE_PRINTER_MSC_DOC_H_ - -#include -#include - -#include - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -/*! - * \brief Doc that declare a var with type. - * - * \sa DeclareDoc - */ -class DeclareDocNode : public ExprDocNode { - public: - /*! \brief The type of the variable */ - ffi::Optional type; - /*! \brief The variable */ - ExprDoc variable{ffi::UnsafeInit{}}; - /*! \brief The init arguments for the variable. */ - ffi::Array init_args; - /*! \brief Whether to use constructor(otherwise initializer) */ - bool use_constructor{true}; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("type", &DeclareDocNode::type) - .def_ro("variable", &DeclareDocNode::variable) - .def_ro("init_args", &DeclareDocNode::init_args) - .def_ro("use_constructor", &DeclareDocNode::use_constructor); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.DeclareDoc", DeclareDocNode, ExprDocNode); -}; - -/*! - * \brief Reference type of DeclareDocNode. - * - * \sa DeclareDocNode - */ -class DeclareDoc : public ExprDoc { - public: - /*! - * \brief Constructor of DeclareDoc. - * \param type The type of the variable. - * \param variable The variable. - * \param init_args The init arguments of the variable. - * \param use_constructor Whether to use constructor(otherwise initializer). - */ - explicit DeclareDoc(ffi::Optional type, ExprDoc variable, ffi::Array init_args, - bool use_constructor); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(DeclareDoc, ExprDoc, DeclareDocNode); -}; - -/*! - * \brief Doc that build a strict list, which check the empty. - * - * \sa StrictListDoc - */ -class StrictListDocNode : public ExprDocNode { - public: - /*! \brief The inner list doc */ - ListDoc list; - /*! \brief Whether to allow empty */ - bool allow_empty{true}; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("list", &StrictListDocNode::list) - .def_ro("allow_empty", &StrictListDocNode::allow_empty); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.StrictListDoc", StrictListDocNode, - ExprDocNode); -}; - -/*! - * \brief Reference type of StrictListDocNode. - * - * \sa StrictListDocNode - */ -class StrictListDoc : public ExprDoc { - public: - /*! - * \brief Constructor of StrictListDoc. - * \param list The inner list doc. - * \param allow_empty Whether to allow empty. - */ - explicit StrictListDoc(ListDoc list, bool allow_empty); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(StrictListDoc, ExprDoc, StrictListDocNode); -}; - -/*! - * \brief Doc that represents pointer. - * - * \sa PointerDoc - */ -class PointerDocNode : public ExprDocNode { - public: - /*! \brief The name of the identifier */ - ffi::String name; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("name", &PointerDocNode::name); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.PointerDoc", PointerDocNode, ExprDocNode); -}; - -/*! - * \brief Reference type of PointerDocNode. - * - * \sa PointerDocNode - */ -class PointerDoc : public ExprDoc { - public: - /*! - * \brief Constructor of PointerDoc. - * \param name The name of identifier. - */ - explicit PointerDoc(ffi::String name); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PointerDoc, ExprDoc, PointerDocNode); -}; - -/*! - * \brief Doc that represents struct definition. - * - * \sa StructDoc - */ -class StructDocNode : public StmtDocNode { - public: - /*! \brief The name of class. */ - IdDoc name{ffi::UnsafeInit{}}; - /*! \brief Decorators of class. */ - ffi::Array decorators; - /*! \brief The body of class. */ - ffi::Array body; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("name", &StructDocNode::name) - .def_ro("decorators", &StructDocNode::decorators) - .def_ro("body", &StructDocNode::body); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.StructDoc", StructDocNode, StmtDocNode); -}; - -/*! - * \brief Reference type of StructDocNode. - * - * \sa StructDocNode - */ -class StructDoc : public StmtDoc { - public: - /*! - * \brief Constructor of StructDoc. - * \param name The name of class. - * \param decorators The decorator of class. - * \param body The body of class. - */ - explicit StructDoc(IdDoc name, ffi::Array decorators, ffi::Array body); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(StructDoc, StmtDoc, StructDocNode); -}; - -/*! - * \brief Doc that represents constructor definition. - * - * \sa ConstructorDoc - */ -class ConstructorDocNode : public StmtDocNode { - public: - /*! \brief The name of function. */ - IdDoc name{ffi::UnsafeInit{}}; - /*! - * \brief The arguments of function. - * - * The `lhs` means argument name, - * `annotation` means argument type, - * and `rhs` means default value. - */ - ffi::Array args; - /*! \brief The body of function. */ - ffi::Array body; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("name", &ConstructorDocNode::name) - .def_ro("args", &ConstructorDocNode::args) - .def_ro("body", &ConstructorDocNode::body); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.ConstructorDoc", ConstructorDocNode, - StmtDocNode); -}; - -/*! - * \brief Reference type of ConstructorDocNode. - * - * \sa ConstructorDocNode - */ -class ConstructorDoc : public StmtDoc { - public: - /*! - * \brief Constructor of ConstructorDoc. - * \param name The name of function.. - * \param args The arguments of function. - * \param body The body of function. - */ - explicit ConstructorDoc(IdDoc name, ffi::Array args, ffi::Array body); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ConstructorDoc, StmtDoc, ConstructorDocNode); -}; - -/*! - * \brief Doc that represent switch statement. - * - * \sa SwitchDoc - */ -class SwitchDocNode : public StmtDocNode { - public: - /*! \brief The predicates of the switch statement. */ - ffi::Array predicates; - /*! \brief The branchs of the switch statement. */ - ffi::Array> branchs; - /*! \brief The default_branch of the switch statement. */ - ffi::Array default_branch; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("predicates", &SwitchDocNode::predicates) - .def_ro("branchs", &SwitchDocNode::branchs) - .def_ro("default_branch", &SwitchDocNode::default_branch); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.SwitchDoc", SwitchDocNode, StmtDocNode); -}; - -/*! - * \brief Reference type of SwitchDocNode. - * - * \sa SwitchDocNode - */ -class SwitchDoc : public StmtDoc { - public: - /*! - * \brief Constructor of SwitchDoc. - * \param predicates The predicates of the switch statement. - * \param branchs The branchs of the switch statement. - * \param default_branch The default_branch of the switch statement. - */ - explicit SwitchDoc(ffi::Array predicates, ffi::Array> branchs, - ffi::Array default_branch); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SwitchDoc, StmtDoc, SwitchDocNode); -}; - -/*! - * \brief Doc that represents lambda definition. - * - * \sa LambdaDoc - */ -class LambdaDocNode : public StmtDocNode { - public: - /*! \brief The name of lambda. */ - IdDoc name{ffi::UnsafeInit{}}; - /*! - * \brief The arguments of lambda. - * - * The `lhs` means argument name, - * `annotation` means argument type, - * and `rhs` means default value. - */ - ffi::Array args; - /*! \brief References of lambda. */ - ffi::Array refs; - /*! \brief The body of lambda. */ - ffi::Array body; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("name", &LambdaDocNode::name) - .def_ro("args", &LambdaDocNode::args) - .def_ro("refs", &LambdaDocNode::refs) - .def_ro("body", &LambdaDocNode::body); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("msc.script.printer.LambdaDoc", LambdaDocNode, StmtDocNode); -}; - -/*! - * \brief Reference type of LambdaDocNode. - * - * \sa LambdaDoc - */ -class LambdaDoc : public StmtDoc { - public: - /*! - * \brief Constructor of LambdaDoc. - * \param name The name of lambda. - * \param args The arguments of lambda. - * \param refs The references of lambda. - * \param body The body of lambda. - */ - explicit LambdaDoc(IdDoc name, ffi::Array args, ffi::Array refs, - ffi::Array body); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LambdaDoc, StmtDoc, LambdaDocNode); -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_PRINTER_MSC_DOC_H_ diff --git a/src/contrib/msc/core/printer/print_utils.cc b/src/contrib/msc/core/printer/print_utils.cc deleted file mode 100644 index e6ab2b28c152..000000000000 --- a/src/contrib/msc/core/printer/print_utils.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/print_utils.cc - */ -#include "print_utils.h" - -#include - -namespace tvm { -namespace contrib { -namespace msc { - -const ffi::String DocSymbol::Empty() { return "::EMPTY"; } - -const ffi::String DocSymbol::NextLine() { return "::NEXT_LINE"; } - -const ExprDoc DocUtils::ToDoc(int64_t val) { return LiteralDoc::Int(val, std::nullopt); } - -const ExprDoc DocUtils::ToDoc(int val) { return ToDoc(static_cast(val)); } - -const ExprDoc DocUtils::ToDoc(size_t val) { return ToDoc(static_cast(val)); } - -const ExprDoc DocUtils::ToDoc(const IntImm& val) { return ToDoc(val->value); } - -const ExprDoc DocUtils::ToDoc(const Integer& val) { return ToDoc(val->value); } - -const ExprDoc DocUtils::ToDoc(double val) { return LiteralDoc::Float(val, std::nullopt); } - -const ExprDoc DocUtils::ToDoc(float val) { return ToDoc(static_cast(val)); } - -const ExprDoc DocUtils::ToDoc(const FloatImm& val) { return ToDoc(val->value); } - -const ExprDoc DocUtils::ToDoc(const char* val) { return IdDoc(std::string(val)); } - -const ExprDoc DocUtils::ToDoc(const ffi::String& val) { return IdDoc(val); } - -const ExprDoc DocUtils::ToDoc(bool val) { return LiteralDoc::Boolean(val, std::nullopt); } - -const ExprDoc DocUtils::ToDoc(const ExprDoc& val) { return val; } - -const ExprDoc DocUtils::ToStr(const ffi::String& val) { return LiteralDoc::Str(val, std::nullopt); } - -const PointerDoc DocUtils::ToPtr(const ffi::String& val) { return PointerDoc(val); } - -const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool allow_empty) { - if (values.size() > 0 || allow_empty) { - ffi::Array elements; - for (const auto& v : values) { - elements.push_back(ToStr(v)); - } - return StrictListDoc(ListDoc(elements), allow_empty); - } - return StrictListDoc(ListDoc(), false); -} - -const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool allow_empty) { - std::vector v_values; - for (const auto& v : values) { - v_values.push_back(v); - } - return ToStrList(v_values, allow_empty); -} - -const StrictListDoc DocUtils::ToStrList(const ffi::Array& values, bool allow_empty) { - std::vector v_values; - for (const auto& v : values) { - v_values.push_back(v); - } - return ToStrList(v_values, allow_empty); -} - -const ffi::Array DocUtils::ToStmts(const ffi::Array& docs) { - ffi::Array stmts; - for (const auto& d : docs) { - if (d->IsInstance()) { - stmts.push_back(Downcast(d)); - } else if (d->IsInstance()) { - stmts.push_back(ExprStmtDoc(Downcast(d))); - } else { - TVM_FFI_THROW(InternalError) << "Unecpected doc type " << d->GetTypeKey(); - } - } - return stmts; -} - -const StmtBlockDoc DocUtils::ToStmtBlock(const ffi::Array& docs) { - return StmtBlockDoc(ToStmts(docs)); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/printer/print_utils.h b/src/contrib/msc/core/printer/print_utils.h deleted file mode 100644 index 3ccc1cdc22cc..000000000000 --- a/src/contrib/msc/core/printer/print_utils.h +++ /dev/null @@ -1,233 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/print_utils.h - * \brief Common utilities for print. - */ -#ifndef TVM_CONTRIB_MSC_CORE_PRINTER_PRINT_UTILS_H_ -#define TVM_CONTRIB_MSC_CORE_PRINTER_PRINT_UTILS_H_ - -#include - -#include -#include - -#include "msc_doc.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -/*! - * \brief Symbols for Doc. - */ - -class DocSymbol { - public: - /*! * \brief The empty symbol*/ - TVM_DLL static const ffi::String Empty(); - - /*! * \brief The next line symbol*/ - TVM_DLL static const ffi::String NextLine(); -}; - -/*! - * \brief Utils for Doc. - */ -class DocUtils { - public: - /*! - * \brief Change object to Doc. - * \return The Doc. - */ - TVM_DLL static const ExprDoc ToDoc(int val); - TVM_DLL static const ExprDoc ToDoc(int64_t val); - TVM_DLL static const ExprDoc ToDoc(size_t val); - TVM_DLL static const ExprDoc ToDoc(const IntImm& val); - TVM_DLL static const ExprDoc ToDoc(const Integer& val); - TVM_DLL static const ExprDoc ToDoc(float val); - TVM_DLL static const ExprDoc ToDoc(double val); - TVM_DLL static const ExprDoc ToDoc(const FloatImm& val); - TVM_DLL static const ExprDoc ToDoc(const char* val); - TVM_DLL static const ExprDoc ToDoc(const ffi::String& val); - TVM_DLL static const ExprDoc ToDoc(bool val); - TVM_DLL static const ExprDoc ToDoc(const ExprDoc& val); - TVM_DLL static const ExprDoc ToStr(const ffi::String& val); - TVM_DLL static const PointerDoc ToPtr(const ffi::String& val); - - /*! - * \brief Change object to DeclareDoc. - * \return The DeclareDoc. - */ - template - TVM_DLL static const DeclareDoc ToDeclare(const ffi::String& type, const T& variable, - size_t len = 0, bool use_constructor = true) { - ffi::Optional type_doc; - if (type.size() == 0) { - type_doc = std::nullopt; - } else { - type_doc = IdDoc(type); - } - if (len == 0) { - return DeclareDoc(type_doc, ToDoc(variable), ffi::Array(), use_constructor); - } - ffi::Array doc_indices{DocUtils::ToDoc(len)}; - return DeclareDoc(type_doc, IndexDoc(ToDoc(variable), doc_indices), ffi::Array(), - use_constructor); - } - - /*! - * \brief Change object to AssignDoc. - * \return The AssignDoc. - */ - template - TVM_DLL static const AssignDoc ToAssign(const LT& lhs, const RT& rhs, - const ffi::String& annotation = "") { - if (annotation.size() == 0) { - return AssignDoc(ToDoc(lhs), ToDoc(rhs), std::nullopt); - } - return AssignDoc(ToDoc(lhs), ToDoc(rhs), IdDoc(annotation)); - } - template - TVM_DLL static const AssignDoc ToAssign(const T& lhs, const ffi::String& rhs, - const ffi::String& annotation = "") { - ffi::Optional rhs_doc; - if (rhs.size() > 0) { - rhs_doc = IdDoc(rhs); - } else { - rhs_doc = std::nullopt; - } - ffi::Optional annotation_doc; - if (annotation.size() > 0) { - annotation_doc = IdDoc(annotation); - } else { - annotation_doc = std::nullopt; - } - return AssignDoc(ToDoc(lhs), rhs_doc, annotation_doc); - } - - /*! - * \brief Change object to AttrAccessDoc. - * \return The AttrAccessDoc. - */ - template - TVM_DLL static const AttrAccessDoc ToAttrAccess(const T& value, const ffi::String& name) { - return AttrAccessDoc(ToDoc(value), name); - } - - /*! - * \brief Change object to List of Docs. - * \return The List of Docs. - */ - template - TVM_DLL static const ffi::Array ToDocList(const std::vector& values) { - ffi::Array elements; - for (const auto& v : values) { - elements.push_back(ToDoc(v)); - } - return elements; - } - template - TVM_DLL static const ffi::Array ToDocList(const ffi::Array& values) { - std::vector v_values; - for (const auto& v : values) { - v_values.push_back(v); - } - return ToDocList(v_values); - } - - /*! - * \brief Change object to ListDoc. - * \return The ListDoc. - */ - template - TVM_DLL static const StrictListDoc ToList(const std::vector& values, - bool allow_empty = false) { - if (values.size() > 0 || allow_empty) { - return StrictListDoc(ListDoc(ToDocList(values)), allow_empty); - } - return StrictListDoc(ListDoc(), false); - } - template - TVM_DLL static const StrictListDoc ToList(const ffi::Array& values, bool allow_empty = false) { - std::vector v_values; - for (const auto& v : values) { - v_values.push_back(v); - } - return ToList(v_values, allow_empty); - } - - /*! - * \brief Change object to ListDoc for string elemenets. - * \return The ListDoc. - */ - TVM_DLL static const StrictListDoc ToStrList(const std::vector& values, - bool allow_empty = false); - TVM_DLL static const StrictListDoc ToStrList(const std::vector& values, - bool allow_empty = false); - TVM_DLL static const StrictListDoc ToStrList(const ffi::Array& values, - bool allow_empty = false); - - /*! - * \brief Change object to IndexDoc. - * \return The IndexDoc. - */ - template - TVM_DLL static const IndexDoc ToIndex(const VT& value, const IT& index) { - ffi::Array doc_indices; - doc_indices.push_back(ToDoc(index)); - return IndexDoc(ToDoc(value), doc_indices); - } - template - TVM_DLL static const IndexDoc ToIndices(const VT& value, const std::vector& indices) { - ffi::Array doc_indices; - for (const auto& i : indices) { - doc_indices.push_back(ToDoc(i)); - } - return IndexDoc(ToDoc(value), doc_indices); - } - template - TVM_DLL static const IndexDoc ToIndices(const VT& value, const ffi::Array& indices) { - ffi::Array doc_indices; - for (const auto& i : indices) { - doc_indices.push_back(ToDoc(i)); - } - return IndexDoc(ToDoc(value), doc_indices); - } - - /*! - * \brief Convert the docs to Stmts. - * \return The Stmts. - */ - TVM_DLL static const ffi::Array ToStmts(const ffi::Array& docs); - - /*! - * \brief Convert the docs to StmtBlock. - * \return The StmtBlockDoc. - */ - TVM_DLL static const StmtBlockDoc ToStmtBlock(const ffi::Array& docs); -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_PRINTER_PRINT_UTILS_H_ diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc b/src/contrib/msc/core/printer/prototxt_printer.cc deleted file mode 100644 index 299712ce9adc..000000000000 --- a/src/contrib/msc/core/printer/prototxt_printer.cc +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/prototxt_printer.cc - */ - -#include "prototxt_printer.h" - -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -LiteralDoc PrototxtPrinter::ToLiteralDoc(const ffi::Any& obj) { - if (auto opt_str = obj.as()) { - return LiteralDoc::Str(*opt_str, std::nullopt); - } else if (obj.as()) { - return LiteralDoc::Int(Downcast(obj)->value, std::nullopt); - } else if (obj.as()) { - return LiteralDoc::Float(Downcast(obj)->value, std::nullopt); - } - std::ostringstream obj_des; - obj_des << obj; - return LiteralDoc::Str(obj_des.str(), std::nullopt); -} - -DictDoc PrototxtPrinter::ToDictDoc(const ffi::Map& dict) { - ffi::Array keys; - ffi::Array values; - for (const auto& pair : dict) { - keys.push_back(IdDoc(pair.first)); - if (pair.second.as()) { - values.push_back(Downcast(pair.second)); - } else { - values.push_back(ToLiteralDoc(pair.second)); - } - } - return DictDoc(keys, values); -} - -DictDoc PrototxtPrinter::ToDictDoc(const std::vector>& dict) { - ffi::Array keys; - ffi::Array values; - for (const auto& pair : dict) { - keys.push_back(IdDoc(pair.first)); - if (pair.second.as()) { - values.push_back(Downcast(pair.second)); - } else { - values.push_back(ToLiteralDoc(pair.second)); - } - } - return DictDoc(keys, values); -} - -void PrototxtPrinter::Append(const ffi::Map& dict) { - DictDoc doc = ToDictDoc(dict); - PrintDoc(doc, false); -} - -void PrototxtPrinter::Append(const std::vector>& dict) { - DictDoc doc = ToDictDoc(dict); - PrintDoc(doc, false); -} - -void PrototxtPrinter::AppendPair(const ffi::String& key, const ffi::Any& value) { - ffi::Map dict; - dict.Set(key, value); - return Append(dict); -} - -void PrototxtPrinter::PrintTypedDoc(const DictDoc& doc) { - TVM_FFI_ICHECK_EQ(doc->keys.size(), doc->values.size()) - << "DictDoc should have equal number of elements in keys and values."; - for (size_t i = 0; i < doc->keys.size(); i++) { - TVM_FFI_ICHECK(doc->keys[i].as()) - << "Prototxt key should be IdDoc, get " << doc->keys[i]->GetTypeKey(); - PrintDoc(doc->keys[i]); - if (doc->values[i].as()) { - output_ << " {"; - IncreaseIndent(); - PrintDoc(doc->values[i], false); - DecreaseIndent(); - NewLine() << "}"; - } else { - output_ << ": "; - PrintDoc(doc->values[i], false); - } - } -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/printer/prototxt_printer.h b/src/contrib/msc/core/printer/prototxt_printer.h deleted file mode 100644 index f304dcdd5819..000000000000 --- a/src/contrib/msc/core/printer/prototxt_printer.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/prototxt_printer.h - * \brief Prototxt Printer. - */ - -#ifndef TVM_CONTRIB_MSC_CORE_PRINTER_PROTOTXT_PRINTER_H_ -#define TVM_CONTRIB_MSC_CORE_PRINTER_PROTOTXT_PRINTER_H_ - -#include -#include -#include - -#include "msc_base_printer.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -/*! - * \brief PrototxtPrinter change list of dict to prototxt format - * \sa Doc - */ -class PrototxtPrinter : public MSCBasePrinter { - public: - /*! - * \brief The constructor of PrototxtPrinter - * \param options the options for printer. - */ - explicit PrototxtPrinter(const std::string& options = "") : MSCBasePrinter(options) {} - - /*! \brief Change object to LiteralDoc*/ - static LiteralDoc ToLiteralDoc(const ffi::Any& obj); - - /*! \brief Change map to DictDoc*/ - static DictDoc ToDictDoc(const ffi::Map& dict); - - /*! \brief Change ordered pairs to DictDoc*/ - static DictDoc ToDictDoc(const std::vector>& dict); - - /*! \brief Append a map into the final content*/ - void Append(const ffi::Map& dict); - - /*! \brief Append ordered pairs into the final content*/ - void Append(const std::vector>& dict); - - /*! \brief Append a map pair into the final content*/ - void AppendPair(const ffi::String& key, const ffi::Any& value); - - protected: - /*! * \brief Print a DictDoc to prototxt format*/ - void PrintTypedDoc(const DictDoc& doc) final; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_CORE_PRINTER_PROTOTXT_PRINTER_H_ diff --git a/src/contrib/msc/core/printer/python_printer.cc b/src/contrib/msc/core/printer/python_printer.cc deleted file mode 100644 index 3966d8b3e5fe..000000000000 --- a/src/contrib/msc/core/printer/python_printer.cc +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/python_printer.cc - */ - -#include "python_printer.h" - -#include "../utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -void PythonPrinter::PrintTypedDoc(const LiteralDoc& doc) { - const ffi::Any& value = doc->value; - bool defined = false; - if (value == nullptr) { - output_ << "None"; - defined = true; - } else if (const auto* int_imm = value.as()) { - if (int_imm->dtype.is_bool()) { - output_ << (int_imm->value ? "True" : "False"); - defined = true; - } - } - if (!defined) { - MSCBasePrinter::PrintTypedDoc(doc); - } -} - -void PythonPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { - PrintDoc(doc->value, false); - output_ << "." << doc->name; -} - -void PythonPrinter::PrintTypedDoc(const IndexDoc& doc) { - PrintDoc(doc->value, false); - if (doc->indices.size() == 0) { - output_ << "[()]"; - } else { - output_ << "["; - PrintJoinedDocs(doc->indices, ", "); - output_ << "]"; - } -} - -void PythonPrinter::PrintTypedDoc(const CallDoc& doc) { - PrintDoc(doc->callee, false); - output_ << "("; - PrintJoinedDocs(doc->args); - TVM_FFI_ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) - << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values."; - if (doc->args.size() > 0 && doc->kwargs_keys.size() > 0) { - output_ << ", "; - } - for (size_t i = 0; i < doc->kwargs_keys.size(); i++) { - output_ << doc->kwargs_keys[i] << "="; - PrintDoc(doc->kwargs_values[i], false); - output_ << (i == doc->kwargs_keys.size() - 1 ? "" : ", "); - } - output_ << ")"; -} - -void PythonPrinter::PrintTypedDoc(const AssignDoc& doc) { - if (const auto* tuple_doc = doc->lhs.as()) { - PrintJoinedDocs(tuple_doc->elements, ", "); - } else { - PrintDoc(doc->lhs, false); - } - - if (doc->annotation) { - output_ << ": "; - PrintDoc(doc->annotation.value(), false); - } - if (doc->rhs) { - output_ << " = "; - if (const auto* tuple_doc = doc->rhs.as()) { - if (tuple_doc->elements.size() > 1) { - PrintJoinedDocs(tuple_doc->elements, ", "); - } else { - PrintDoc(doc->rhs.value(), false); - } - } else { - PrintDoc(doc->rhs.value(), false); - } - } - MaybePrintComment(doc); -} - -void PythonPrinter::PrintTypedDoc(const IfDoc& doc) { - MaybePrintComment(doc, true); - output_ << "if "; - PrintDoc(doc->predicate, false); - output_ << ":"; - - PrintIndentedBlock(doc->then_branch); - - if (!doc->else_branch.empty()) { - NewLine(); - output_ << "else:"; - PrintIndentedBlock(doc->else_branch); - } -} - -void PythonPrinter::PrintTypedDoc(const ForDoc& doc) { - MaybePrintComment(doc, true); - if (doc->rhs->IsInstance()) { - const auto& tuple = Downcast(doc->rhs); - TVM_FFI_ICHECK_EQ(tuple->elements.size(), 2) << "For with tuple should has 2 elements"; - output_ << "for "; - PrintDoc(doc->lhs, false); - output_ << " in range("; - PrintDoc(tuple->elements[0], false); - output_ << ", "; - PrintDoc(tuple->elements[1], false); - output_ << "):"; - } else { - output_ << "for "; - PrintDoc(doc->lhs, false); - output_ << " in "; - PrintDoc(doc->rhs, false); - output_ << ":"; - } - PrintIndentedBlock(doc->body); -} - -void PythonPrinter::PrintTypedDoc(const ScopeDoc& doc) { - MaybePrintComment(doc, true); - output_ << "with "; - PrintDoc(doc->rhs, false); - if (doc->lhs != nullptr) { - output_ << " as "; - PrintDoc(doc->lhs.value(), false); - } - output_ << ":"; - - PrintIndentedBlock(doc->body); -} - -void PythonPrinter::PrintTypedDoc(const FunctionDoc& doc) { - for (const AssignDoc& arg_doc : doc->args) { - TVM_FFI_ICHECK(!arg_doc->comment.has_value()) - << "Function arg cannot have comment attached to them."; - } - - PrintDecorators(doc->decorators); - - output_ << "def "; - PrintDoc(doc->name, false); - - output_ << "("; - PrintJoinedDocs(doc->args, ", "); - output_ << ")"; - - if (doc->return_type.defined()) { - output_ << " -> "; - PrintDoc(doc->return_type.value(), false); - } - - output_ << ":"; - - if (doc->comment.has_value()) { - IncreaseIndent(); - MaybePrintComment(doc, true); - DecreaseIndent(); - } - PrintIndentedBlock(doc->body); - NewLine(false); -} - -void PythonPrinter::PrintTypedDoc(const ClassDoc& doc) { - PrintDecorators(doc->decorators); - - output_ << "class "; - PrintDoc(doc->name, false); - output_ << ":"; - - MaybePrintComment(doc, true); - PrintIndentedBlock(doc->body); -} - -void PythonPrinter::PrintTypedDoc(const CommentDoc& doc) { - if (doc->comment.has_value()) { - output_ << "# " << doc->comment.value(); - } -} - -void PythonPrinter::PrintTypedDoc(const StrictListDoc& doc) { - if (doc->allow_empty || doc->list->elements.size() > 0) { - PrintDoc(doc->list, false); - } else { - output_ << "None"; - } -} - -void PythonPrinter::PrintTypedDoc(const SwitchDoc& doc) { - MaybePrintComment(doc, true); - TVM_FFI_ICHECK_EQ(doc->predicates.size(), doc->branchs.size()) - << "predicates " << doc->predicates.size() << " mismatch with branchs " - << doc->branchs.size(); - for (size_t i = 0; i < doc->predicates.size(); i++) { - if (i == 0) { - output_ << "if "; - } else { - NewLine(); - output_ << "elif "; - } - PrintDoc(doc->predicates[i], false); - output_ << ":"; - PrintIndentedBlock(doc->branchs[i]); - } - if (!doc->default_branch.empty()) { - NewLine(); - output_ << "else:"; - PrintIndentedBlock(doc->default_branch); - } -} - -void PythonPrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) { - if (stmt->comment.has_value() && multi_lines) { - NewLine(); - output_ << "\"\"\""; - for (const auto& l : StringUtils::Split(stmt->comment.value(), "\n")) { - PrintDoc(IdDoc(l)); - } - NewLine(); - output_ << "\"\"\""; - NewLine(); - } else { - MSCBasePrinter::MaybePrintComment(stmt, multi_lines); - } -} - -void PythonPrinter::PrintIndentedBlock(const ffi::Array& docs) { - IncreaseIndent(); - for (const StmtDoc& d : docs) { - PrintDoc(d); - } - if (docs.empty()) { - NewLine() << "pass"; - } - DecreaseIndent(); -} - -void PythonPrinter::PrintDecorators(const ffi::Array& decorators) { - for (const ExprDoc& decorator : decorators) { - output_ << "@"; - PrintDoc(decorator, false); - NewLine(); - } -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/printer/python_printer.h b/src/contrib/msc/core/printer/python_printer.h deleted file mode 100644 index 3e09b1fcdabc..000000000000 --- a/src/contrib/msc/core/printer/python_printer.h +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/printer/python_printer.h - * \brief Python Printer. - */ - -#ifndef TVM_CONTRIB_MSC_CORE_PRINTER_PYTHON_PRINTER_H_ -#define TVM_CONTRIB_MSC_CORE_PRINTER_PYTHON_PRINTER_H_ - -#include - -#include "msc_base_printer.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -/*! - * \brief PythonPrinter change list of docs to python format - * \sa Doc - */ -class PythonPrinter : public MSCBasePrinter { - public: - /*! - * \brief The constructor of PythonPrinter - * \param options the options for printer. - */ - explicit PythonPrinter(const std::string& options = "") : MSCBasePrinter(options) {} - - protected: - /*! * \brief Print a LiteralDoc to python format*/ - void PrintTypedDoc(const LiteralDoc& doc) final; - - /*! * \brief Print a AttrAccessDoc to python format*/ - void PrintTypedDoc(const AttrAccessDoc& doc) final; - - /*! * \brief Print a IndexDoc to python format*/ - void PrintTypedDoc(const IndexDoc& doc) final; - - /*! * \brief Print a CallDoc to python format*/ - void PrintTypedDoc(const CallDoc& doc) final; - - /*! * \brief Print a AssignDoc to python format*/ - void PrintTypedDoc(const AssignDoc& doc) final; - - /*! * \brief Print a IfDoc to python format*/ - void PrintTypedDoc(const IfDoc& doc) final; - - /*! * \brief Print a ForDoc to python format*/ - void PrintTypedDoc(const ForDoc& doc) final; - - /*! * \brief Print a ScopeDoc to python format*/ - void PrintTypedDoc(const ScopeDoc& doc) final; - - /*! * \brief Print a FunctionDoc to python format*/ - void PrintTypedDoc(const FunctionDoc& doc) final; - - /*! * \brief Print a ClassDoc to python format*/ - void PrintTypedDoc(const ClassDoc& doc) final; - - /*! * \brief Print a CommentDoc to python format*/ - void PrintTypedDoc(const CommentDoc& doc) final; - - /*! * \brief Print a StrictListDoc to python format*/ - void PrintTypedDoc(const StrictListDoc& doc) final; - - /*! * \brief Print a SwitchDoc to python format*/ - void PrintTypedDoc(const SwitchDoc& doc) final; - - /*! \brief Print comment for stmt in python format*/ - void MaybePrintComment(const StmtDoc& stmt, bool multi_lines = false) final; - - private: - /*! \brief Print block with indent*/ - void PrintIndentedBlock(const ffi::Array& docs); - - /*! \brief Print decorators for function and class*/ - void PrintDecorators(const ffi::Array& decorators); -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_CORE_PRINTER_PYTHON_PRINTER_H_ diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc deleted file mode 100644 index 08cff58e68b6..000000000000 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "../utils.h" - -namespace tvm { -namespace relax { -using namespace tvm::contrib::msc; - -std::tuple, ffi::Map> NormalizeNamedBindings( - const Function& func, const ffi::Map& untyped_params) { - TVM_FFI_ICHECK(func.defined()); - TVM_FFI_ICHECK(untyped_params.defined()); - - // Map from string to the variable(s) with that name. - std::unordered_map> string_lookup; - std::unordered_set var_set; - for (const auto& param : func->params) { - string_lookup[param->name_hint()].push_back(param); - var_set.insert(param.get()); - } - - ffi::Map relax_var_remap; - - auto normalize_key = [&](ffi::Any obj) -> relax::Var { - if (auto opt_str = obj.as()) { - std::string str = opt_str.value(); - auto it = string_lookup.find(str); - TVM_FFI_ICHECK(it != string_lookup.end()) - << "Function does not have parameter with name \"" << str << "\". " - << "Function parameters are named " - << func->params.Map([](const auto& param) { return param->name_hint(); }); - TVM_FFI_ICHECK_EQ(it->second.size(), 1) - << "Function contains multiple parameters with name \"" << str << "\". " - << "The Relax variables " << it->second << " are all named \"" << str << "\""; - auto var = it->second[0]; - TVM_FFI_ICHECK(!relax_var_remap.count(var)) - << "Remap of variable " << var << " was defined multiple times"; - - return var; - } else if (auto opt_var = obj.as()) { - auto var = opt_var.value(); - TVM_FFI_ICHECK(!relax_var_remap.count(var)) - << "Remap of variable " << var << " was defined multiple times"; - TVM_FFI_ICHECK(var_set.count(var.get())) - << "Function does not use Relax variable " << var << " as a parameter. " - << "Function parameters are " << func->params; - return var; - } else { - TVM_FFI_THROW(InternalError) - << "Expected bound parameter to be a relax::Var, " - << " or a string that uniquely identifies a relax::Var param within the function. " - << "However, received object " << obj << " of type " << obj.GetTypeKey(); - } - }; - auto normalize_value = [&](Var key, ffi::Any obj) -> relax::Expr { - if (auto opt = obj.as()) { - return opt.value(); - } else if (auto opt = obj.as()) { - const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName, key->name_hint()); - return Constant(opt.value(), StructInfo(), span); - } else { - TVM_FFI_THROW(InternalError) - << "Cannot coerce object of type " << obj.GetTypeKey() << " into relax expression"; - } - }; - - for (const auto& [key, value] : untyped_params) { - relax_var_remap.Set(normalize_key(key), normalize_value(normalize_key(key), value)); - } - - arith::Analyzer analyzer; - ffi::Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); - - return {relax_var_remap, symbolic_var_map}; -} - -/*! - * \brief Bind params to function by using name with span name - * \param func Relax function - * \param params params dict - * \return Function - */ -Function FunctionBindNamedParams(Function func, - const ffi::Map& untyped_params) { - auto [bind_dict, symbolic_var_map] = NormalizeNamedBindings(func, untyped_params); - - Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); - return Downcast(bound_expr); -} - -/*! - * \brief Bind params to a specific function in a module with span name - * \param m The module - * \param func_name The name of the specific function - * \param param The param dict - * \return The module after binding params. - */ -IRModule BindNamedParam(IRModule m, ffi::String func_name, - ffi::Map bind_params) { - IRModuleNode* new_module = m.CopyOnWrite(); - ffi::Map functions = m->functions; - for (const auto& func_pr : functions) { - if (const auto* relax_f = func_pr.second.as()) { - if (relax_f->GetLinkageType() == LinkageType::kExternal) { - // Use global_symbol if it's external linkage - ffi::Optional gsymbol = - relax_f->GetAttr(tvm::attr::kGlobalSymbol); - if (gsymbol.has_value() && gsymbol.value() == func_name) { - Function f_after_bind = - FunctionBindNamedParams(ffi::GetRef(relax_f), bind_params); - new_module->Update(func_pr.first, f_after_bind); - } - } else { - // Use global var's name_hint if it's internal linkage - if (func_pr.first->name_hint == func_name) { - Function f_after_bind = - FunctionBindNamedParams(ffi::GetRef(relax_f), bind_params); - new_module->Update(func_pr.first, f_after_bind); - } - } - } - } - return ffi::GetRef(new_module); -} - -namespace transform { - -Pass BindNamedParams(ffi::String func_name, ffi::Map params) { - auto pass_func = [=](IRModule mod, PassContext pc) { - return BindNamedParam(std::move(mod), func_name, params); - }; - return CreateModulePass(pass_func, 0, "BindNamedParams", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.transform.BindNamedParams", BindNamedParams); -} - -} // namespace transform - -} // namespace relax -} // namespace tvm diff --git a/src/contrib/msc/core/transform/bind_shape.cc b/src/contrib/msc/core/transform/bind_shape.cc deleted file mode 100644 index 4a196e0501f3..000000000000 --- a/src/contrib/msc/core/transform/bind_shape.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/transform/fuse_shape.cc - * \brief Pass for fuse ShapeExpr. - */ - -#include -#include -#include -#include - -#include "../../../../relax/transform/utils.h" - -namespace tvm { -namespace relax { - -/*! - * \brief Bind ShapeExpr to Reshape - */ -class ShapeBinder : public ExprMutator { - public: - explicit ShapeBinder(IRModule ctx_module, const ffi::String& entry_name) - : ExprMutator(ctx_module) { - mod_ = ctx_module; - entry_name_ = entry_name; - } - - IRModule Bind() { - // update global functions - GlobalVar main_var; - for (const auto& [gv, func] : mod_->functions) { - if (gv->name_hint == entry_name_) { - main_var = gv; - continue; - } - if (func->IsInstance()) { - ffi::Array new_params; - for (const auto& p : Downcast(func)->params) { - auto struct_info = GetStructInfo(p); - if (struct_info->IsInstance()) { - continue; - } - new_params.push_back(p); - } - if (new_params.size() == Downcast(func)->params.size()) { - continue; - } - const auto& new_func = Downcast(VisitExpr(func)); - auto updated_func = Function(new_params, new_func->body, new_func->ret_struct_info, - new_func->is_pure, new_func->attrs, new_func->span); - builder_->UpdateFunction(gv, updated_func); - } - } - // update main - TVM_FFI_ICHECK(main_var.defined()) << "Can not find entry func " << entry_name_; - const auto& new_func = Downcast(VisitExpr(mod_->Lookup(entry_name_))); - builder_->UpdateFunction(main_var, new_func); - return builder_->GetContextIRModule(); - } - - void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - ffi::Array new_args; - for (const auto& a : call_node->args) { - auto struct_info = GetStructInfo(a); - if (a->IsInstance() && struct_info->IsInstance()) { - continue; - } - if (call_node->op->IsInstance() && a->IsInstance()) { - continue; - } - new_args.push_back(a); - } - if (new_args.size() == call_node->args.size()) { - ExprMutator::VisitBinding_(binding, call_node); - } else if (const auto* op_node = call_node->op.as()) { - TVM_FFI_ICHECK(op_node->name == "relax.reshape" || op_node->name == "relax.image.resize2d") - << "Expect ShapeExpr consumer as reshape or image.resize2d, get " - << ffi::GetRef(call_node); - const auto& opt_shape = Downcast(GetStructInfo(call_node->args[1]))->values; - TVM_FFI_ICHECK(opt_shape.defined()) << "Expected shape defined, get " << call_node->args[1]; - new_args.push_back(ShapeExpr(opt_shape.value())); - const auto& new_call = - Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); - ReEmitBinding(binding, builder_->Normalize(new_call)); - } else if (const auto* gv_node = call_node->op.as()) { - const auto& func_info = Downcast(gv_node->struct_info_); - ffi::Array params_info; - for (const auto& a : new_args) { - TVM_FFI_ICHECK(a->struct_info_.defined()) - << "Global func argument without defined struct info " << a; - params_info.push_back(Downcast(a->struct_info_.value())); - } - call_node->op->struct_info_ = - FuncStructInfo(params_info, func_info->ret, func_info->purity, func_info->span); - const auto& new_call = - Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); - ReEmitBinding(binding, builder_->Normalize(new_call)); - } else { - LOG_FATAL << "Unexpected shape consumer " << ffi::GetRef(call_node); - } - } - - private: - IRModule mod_; - ffi::String entry_name_; -}; - -IRModule BindShape(IRModule mod, const ffi::String& entry_name) { - return ShapeBinder(mod, entry_name).Bind(); -} - -namespace transform { - -Pass BindShape(const ffi::String& entry_name) { - auto pass_func = [=](IRModule m, PassContext pc) { return relax::BindShape(m, entry_name); }; - return CreateModulePass(pass_func, 0, "BindShape", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.transform.BindShape", BindShape); -} - -} // namespace transform -} // namespace relax -} // namespace tvm diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc deleted file mode 100644 index 08d8a995f8cb..000000000000 --- a/src/contrib/msc/core/transform/fuse_tuple.cc +++ /dev/null @@ -1,243 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/transform/fuse_tuple.cc - * \brief Pass for fuse ShapeExpr. - */ - -#include -#include -#include -#include -#include - -#include "../../../../relax/transform/utils.h" -#include "../utils.h" - -namespace tvm { -namespace relax { - -using namespace tvm::contrib::msc; - -/*! - * \brief Fuse Tuple and TupleGetItem to BYOC - */ -class TupleFuser : public ExprMutator { - public: - explicit TupleFuser(IRModule ctx_module, const ffi::String& target, const ffi::String& entry_name) - : ExprMutator(ctx_module) { - mod_ = ctx_module; - target_ = target + "."; - entry_name_ = entry_name; - } - - IRModule Fuse() { - GlobalVar main_var; - for (const auto& [gv, func] : mod_->functions) { - if (gv->name_hint == entry_name_) { - main_var = gv; - } else { - const auto& name_opt = func->GetAttr(attr::kComposite); - if (name_opt.has_value() && StringUtils::StartsWith(name_opt.value(), target_)) { - target_funcs_.Set(gv, Downcast(func)); - } - } - } - // update main - TVM_FFI_ICHECK(main_var.defined()) << "Can not find entry func " << entry_name_; - const auto& new_func = Downcast(VisitExpr(mod_->Lookup(entry_name_))); - builder_->UpdateFunction(main_var, new_func); - return builder_->GetContextIRModule(); - } - - void VisitBinding_(const VarBindingNode* binding, const CallNode* val) final { - bool has_tuple_arg = false; - if (target_funcs_.count(val->op)) { - ffi::Array new_args; - for (size_t i = 0; i < val->args.size(); i++) { - const auto& arg = val->args[i]; - if (arg->IsInstance()) { - ffi::String tuple_name; - const auto& name_opt = target_funcs_[val->op]->GetAttr(msc_attr::kUnique); - if (name_opt.has_value()) { - if (val->args.size() == 1) { - tuple_name = name_opt.value() + "_input"; - } else { - tuple_name = name_opt.value() + "_inputs." + std::to_string(i); - } - } - const auto& func_call = AddFunc(arg, tuple_name); - const auto& tuple_out = builder_->Emit(func_call); - TVM_FFI_ICHECK(target_funcs_.count(func_call->op)) - << "Can not find target func " << func_call->op; - target_funcs_.Set(tuple_out, target_funcs_[func_call->op]); - has_tuple_arg = true; - new_args.push_back(tuple_out); - } else { - new_args.push_back(arg); - } - if (has_tuple_arg) { - const auto& new_call = Call(val->op, new_args, val->attrs, val->sinfo_args, val->span); - ReEmitBinding(binding, builder_->Normalize(new_call)); - } - } - target_funcs_.Set(binding->var, target_funcs_[val->op]); - } - if (!has_tuple_arg) { - ExprMutator::VisitBinding_(binding, val); - } - } - - void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final { - bool on_target = true; - for (const auto& f : val->fields) { - if (!target_funcs_.count(f)) { - on_target = false; - break; - } - } - if (on_target) { - ReEmitFunc(binding, ffi::GetRef(val)); - } else { - ExprMutator::VisitBinding_(binding, val); - } - } - - void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { - if (target_funcs_.count(val->tuple)) { - ReEmitFunc(binding, ffi::GetRef(val)); - } else { - ExprMutator::VisitBinding_(binding, val); - } - } - - private: - Call AddFunc(const Expr& expr, const ffi::String tuple_name = "") { - builder_->BeginDataflowBlock(); - ffi::Array inputs; - if (const auto* v_node = expr.as()) { - inputs = v_node->fields; - } else if (const auto* g_node = expr.as()) { - inputs = {g_node->tuple}; - } else { - LOG_FATAL << "Unexpceted expr " << expr; - } - ffi::Array func_inputs; - ffi::Array call_inputs; - ffi::Array params; - ffi::Map added_params; - for (size_t i = 0; i < inputs.size(); i++) { - if (inputs[i]->IsInstance()) { - func_inputs.push_back(inputs[i]); - continue; - } - if (!added_params.count(inputs[i])) { - const auto& name = ffi::String("param_" + std::to_string(i)); - const auto& var = Var(std::move(name), GetStructInfo(inputs[i])); - added_params.Set(inputs[i], var); - } - call_inputs.push_back(inputs[i]); - func_inputs.push_back(added_params[inputs[i]]); - params.push_back(added_params[inputs[i]]); - } - - Expr out_expr; - ffi::String func_name; - Span expr_span = expr->span; - if (!expr_span.defined()) { - TVM_FFI_ICHECK(tuple_name.size() > 0) << "Missing tuple for " << expr; - expr_span = SpanUtils::CreateWithAttr(msc_attr::kName, tuple_name); - } - if (expr->IsInstance()) { - out_expr = Tuple(func_inputs, expr_span); - func_name = "tuple"; - } else if (const auto* g_node = expr.as()) { - out_expr = TupleGetItem(func_inputs[0], g_node->index, expr_span); - func_name = "get_item"; - } else { - LOG_FATAL << "Unexpceted expr " << expr; - } - - const auto& output = builder_->EmitOutput(out_expr); - BindingBlock new_block = builder_->EndBlock(); - Expr body = builder_->Normalize(output); - body = builder_->Normalize(SeqExpr({new_block}, body)); - - ffi::Map func_attrs; - func_attrs.Set(attr::kPrimitive, true); - func_attrs.Set(attr::kComposite, target_ + func_name); - func_attrs.Set(msc_attr::kUnique, SpanUtils::GetAttr(expr_span, msc_attr::kName)); - - Function function = Function(/*params=*/params, // - /*body=*/body, // - /*ret_struct_info=*/std::nullopt, // - /*is_pure=*/true, // - /*attrs=*/DictAttrs(func_attrs)); - ffi::Array free_vars = - FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; }); - if (!free_vars.empty()) { - params.push_back(Var("tir_vars", ShapeStructInfo(free_vars))); - function = Function(/*params=*/params, // - /*body=*/body, // - /*ret_struct_info=*/std::nullopt, // - /*is_pure=*/true, // - /*attrs=*/DictAttrs(func_attrs)); - } - function = SymbolicVarRenewMutator::Renew(function); - GlobalVar gv = builder_->AddFunction(function, "fused_" + func_name); - target_funcs_.Set(gv, function); - return Call(gv, call_inputs); - } - - void ReEmitFunc(const VarBindingNode* binding, const Expr& expr) { - const auto& func_call = AddFunc(expr); - ReEmitBinding(binding, builder_->Normalize(func_call)); - TVM_FFI_ICHECK(target_funcs_.count(func_call->op)) - << "Can not find target func " << func_call->op; - target_funcs_.Set(binding->var, target_funcs_[func_call->op]); - } - - IRModule mod_; - ffi::String target_; - ffi::String entry_name_; - ffi::Map target_funcs_; -}; - -IRModule FuseTuple(IRModule mod, const ffi::String& target, const ffi::String& entry_name) { - return TupleFuser(mod, target, entry_name).Fuse(); -} - -namespace transform { - -Pass FuseTuple(const ffi::String& target, const ffi::String& entry_name) { - auto pass_func = [=](IRModule m, PassContext pc) { - return relax::FuseTuple(m, target, entry_name); - }; - return CreateModulePass(pass_func, 0, "FuseTuple", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.transform.FuseTuple", FuseTuple); -} - -} // namespace transform -} // namespace relax -} // namespace tvm diff --git a/src/contrib/msc/core/transform/inline_params.cc b/src/contrib/msc/core/transform/inline_params.cc deleted file mode 100644 index 14f8c7896649..000000000000 --- a/src/contrib/msc/core/transform/inline_params.cc +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/transform/inline_params.cc - * \brief Pass for inline Exprs. - */ - -#include -#include -#include -#include - -#include "../../../../relax/transform/utils.h" -#include "../utils.h" - -namespace tvm { -namespace relax { - -using namespace tvm::contrib::msc; - -/*! - * \brief Inline the exprs - */ -class ParamsInliner : public ExprMutator { - public: - explicit ParamsInliner(IRModule ctx_module, const ffi::String& entry_name) - : ExprMutator(ctx_module) { - mod_ = ctx_module; - entry_name_ = entry_name; - } - - IRModule Bind() { - // update global functions - GlobalVar main_var; - for (const auto& [gv, func] : mod_->functions) { - if (gv->name_hint == entry_name_) { - main_var = gv; - continue; - } - if (func->IsInstance()) { - ffi::Array new_params; - ffi::Array attrs; - for (const auto& p : Downcast(func)->params) { - auto struct_info = GetStructInfo(p); - if (struct_info->IsInstance()) { - continue; - } - if (struct_info->IsInstance()) { - const auto& optype_opt = func->GetAttr(msc_attr::kOptype); - TVM_FFI_ICHECK(optype_opt.has_value()) - << "Can not find attr " << msc_attr::kOptype << " form extern func"; - extern_types_.Set(p, optype_opt.value()); - continue; - } - if (const auto* tuple_info = struct_info.as()) { - ffi::Array new_fields; - for (const auto& i : tuple_info->fields) { - if (i->IsInstance()) { - new_fields.push_back(i); - } else if (const auto& p_info = i.as()) { - TVM_FFI_ICHECK(p_info->value.defined()) - << "PrimStructInfo with undefined prim value " << i; - attrs.push_back(StringUtils::ToString(p_info->value.value())); - } - } - if (new_fields.size() < tuple_info->fields.size()) { - p->struct_info_ = TupleStructInfo(new_fields, tuple_info->span); - } - } - new_params.push_back(p); - } - if (new_params.size() == Downcast(func)->params.size()) { - continue; - } - const auto& new_func = Downcast(VisitExpr(func)); - ffi::Map func_attrs = new_func->attrs->dict; - if (attrs.size() > 0) { - func_attrs.Set(msc_attr::kOpattrs, attrs); - } - auto updated_func = Function(new_params, new_func->body, new_func->ret_struct_info, - new_func->is_pure, DictAttrs(func_attrs), new_func->span); - builder_->UpdateFunction(gv, updated_func); - } - } - // update main - TVM_FFI_ICHECK(main_var.defined()) << "Can not find entry func " << entry_name_; - const auto& new_func = Downcast(VisitExpr(mod_->Lookup(entry_name_))); - builder_->UpdateFunction(main_var, new_func); - return builder_->GetContextIRModule(); - } - - void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - ffi::Array new_args; - bool has_inline = false; - for (const auto& a : call_node->args) { - auto struct_info = GetStructInfo(a); - if (a->IsInstance() && struct_info->IsInstance()) { - TVM_FFI_ICHECK(extern_types_.count(a)) << "Can not find extern type of " << a; - new_args.push_back(ExternFunc(extern_types_[a])); - has_inline = true; - } else if (call_node->op->IsInstance() && a->IsInstance()) { - has_inline = true; - } else if (a->IsInstance() && struct_info->IsInstance()) { - const auto& shape_opt = Downcast(GetStructInfo(a))->values; - TVM_FFI_ICHECK(shape_opt.defined()) << "Expected shape defined, get " << a; - new_args.push_back(ShapeExpr(shape_opt.value())); - has_inline = true; - } else if (call_node->op->IsInstance() && a->IsInstance()) { - has_inline = true; - } else if (call_node->op->IsInstance() && a->IsInstance()) { - const auto& tuple = Downcast(a); - ffi::Array new_fields; - ffi::Array new_infos; - - for (const auto& f : tuple->fields) { - if (f->IsInstance()) { - new_fields.push_back(f); - new_infos.push_back(GetStructInfo(f)); - } - } - if (new_fields.size() == tuple->fields.size()) { - new_args.push_back(a); - } else { - const auto& new_tuple = Tuple(new_fields, tuple->span); - new_tuple->struct_info_ = TupleStructInfo(new_infos); - new_args.push_back(new_tuple); - } - } else { - new_args.push_back(a); - } - } - if (!has_inline) { - ExprMutator::VisitBinding_(binding, call_node); - } else if (call_node->op->IsInstance()) { - const auto& new_call = - Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); - ReEmitBinding(binding, builder_->Normalize(new_call)); - } else if (const auto* gv_node = call_node->op.as()) { - const auto& func_info = Downcast(gv_node->struct_info_); - ffi::Array params_info; - for (const auto& a : new_args) { - TVM_FFI_ICHECK(a->struct_info_.defined()) - << "Global func argument without defined struct info " << a; - params_info.push_back(Downcast(a->struct_info_.value())); - } - call_node->op->struct_info_ = - FuncStructInfo(params_info, func_info->ret, func_info->purity, func_info->span); - const auto& new_call = - Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); - ReEmitBinding(binding, builder_->Normalize(new_call)); - } else { - LOG_FATAL << "Unexpected shape consumer " << ffi::GetRef(call_node); - } - } - - private: - IRModule mod_; - ffi::String entry_name_; - ffi::Map extern_types_; -}; - -IRModule InlineParams(IRModule mod, const ffi::String& entry_name) { - return ParamsInliner(mod, entry_name).Bind(); -} - -namespace transform { - -Pass InlineParams(const ffi::String& entry_name) { - auto pass_func = [=](IRModule m, PassContext pc) { return relax::InlineParams(m, entry_name); }; - return CreateModulePass(pass_func, 0, "InlineParams", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.transform.InlineParams", InlineParams); -} - -} // namespace transform -} // namespace relax -} // namespace tvm diff --git a/src/contrib/msc/core/transform/layout_utils.cc b/src/contrib/msc/core/transform/layout_utils.cc deleted file mode 100644 index e5fdfabe4daa..000000000000 --- a/src/contrib/msc/core/transform/layout_utils.cc +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/transform/layout_utils.cc - */ -#include "layout_utils.h" - -#include -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -NLayout LayoutUtils::InferNLayout(const Expr& expr, const VarLayoutMap& var_layout_map) { - if (expr->IsInstance() && var_layout_map.count(Downcast(expr))) { - return tvm::relax::GetNLayout(var_layout_map, expr); - } - return GetNLayout(expr); -} - -LayoutDecision LayoutUtils::InferLayoutDecision(const Expr& expr, - const VarLayoutMap& var_layout_map) { - const auto& nlayout = InferNLayout(expr, var_layout_map); - TVM_FFI_ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr; - return nlayout.LeafValue(); -} - -LayoutDecision LayoutUtils::InferLayoutDecisionAt(const Expr& expr, - const VarLayoutMap& var_layout_map, - size_t index) { - const auto& nlayouts = InferNLayout(expr, var_layout_map); - if (nlayouts.IsLeaf()) { - return index == 0 ? nlayouts.LeafValue() : LayoutDecision(""); - } - const auto& nlayout = nlayouts.NestedArray()[0]; - TVM_FFI_ICHECK(nlayout.IsLeaf()) << "Cannot get output layout for " << expr; - return nlayout.LeafValue(); -} - -bool LayoutUtils::LayoutInfered(const Expr& expr) { - const ffi::String& layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); - return layout.size() > 0; -} - -bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) { - const ffi::String& saved_layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); - const auto& sinfo = GetStructInfo(expr); - if (sinfo->IsInstance() || sinfo->IsInstance()) { - if (!layout.IsLeaf()) { - return false; - } - const auto& l_layout = layout.LeafValue()->layout; - if (!l_layout.defined()) { - return false; - } - if (saved_layout == l_layout.name()) { - return false; - } - expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kLayout, l_layout.name()); - } else if (sinfo->IsInstance()) { - if (layout.IsLeaf()) { - return false; - } - ffi::String layout_str; - ffi::Array nested_layouts = layout.NestedArray(); - for (size_t i = 0; i < nested_layouts.size(); i++) { - if (!nested_layouts[i].IsLeaf()) { - return false; - } - const auto& l_layout = nested_layouts[i].LeafValue()->layout; - if (!l_layout.defined()) { - return false; - } - layout_str = layout_str + l_layout.name() + (i < nested_layouts.size() - 1 ? "," : ""); - } - if (saved_layout == layout_str) { - return false; - } - expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kLayout, layout_str); - } - return true; -} - -const NLayout LayoutUtils::GetNLayout(const Expr& expr) { - if (!LayoutInfered(expr)) { - return LayoutDecision(""); - } - auto sinfo = GetStructInfo(expr); - if (sinfo->IsInstance()) { - return LayoutDecision(SpanUtils::GetAttr(expr->span, msc_attr::kLayout)); - } - if (sinfo->IsInstance()) { - ffi::String layout_str = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); - std::vector output_layout; - for (const auto& l : StringUtils::Split(layout_str, ",")) { - output_layout.push_back(LayoutDecision(l)); - } - return NLayout(output_layout); - } - return LayoutDecision(""); -} - -const LayoutDecision LayoutUtils::GetLayoutDecision(const Expr& expr) { - NLayout nlayout = GetNLayout(expr); - TVM_FFI_ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr; - return nlayout.LeafValue(); -} - -bool LayoutUtils::HasUnknownDimTensor(const NLayout& nlayout) { - bool find = false; - auto fvisit = [&](const LayoutDecision& layout) { - find = find | (NLayoutEqual()(layout, LayoutDecision::InitUnknownDim())); - }; - ForEachLeaf(nlayout, fvisit); - return find; -} - -bool LayoutUtils::HasUnknownDimTensor(const ffi::Array& args) { - for (const auto& arg : args) { - if (IsNestedTensor(arg)) { - if (HasUnknownDimTensor(GetNLayout(arg))) { - return true; - } - } - } - return false; -} - -const LayoutDecision LayoutUtils::ExpandLayout(const LayoutDecision& src_layout, - const std::vector& expand_axes) { - if (!src_layout->layout.defined()) { - return src_layout; - } - // sort expand axes - std::vector axes = expand_axes; - std::sort(std::begin(axes), std::end(axes)); - std::string new_layout = src_layout.name(); - TVM_FFI_ICHECK_EQ(new_layout.size(), src_layout->layout.ndim()) - << "Only support normal layout, get " << src_layout->layout; - std::set used_axes; - for (size_t i = 0; i < src_layout->layout.ndim(); i++) { - used_axes.insert(src_layout->layout[i].name()); - } - std::vector prefer_axes{"N", "C", "H", "W", "D"}; - for (const auto& a : axes) { - bool use_prefer = false; - if (used_axes.size() < prefer_axes.size()) { - use_prefer = - std::all_of(prefer_axes.begin(), prefer_axes.begin() + used_axes.size(), - [&used_axes](const std::string& axis) { return used_axes.count(axis); }); - } - std::string new_axis; - char cur_axis = 'A'; - if (use_prefer) { - new_axis = prefer_axes[used_axes.size()]; - } else { - while (used_axes.count(std::string(1, cur_axis))) { - cur_axis += 1; - } - new_axis = std::string(1, cur_axis); - } - used_axes.insert(new_axis); - new_layout = new_layout.insert(a, new_axis); - } - return LayoutDecision(new_layout); -} - -const LayoutDecision LayoutUtils::ReduceLayout(const LayoutDecision& src_layout, - const std::vector& reduce_axes) { - if (!src_layout->layout.defined()) { - return src_layout; - } - std::set reduce_axes_set; - for (const auto& a : reduce_axes) { - reduce_axes_set.insert(a); - } - std::string new_layout = ""; - for (size_t i = 0; i < src_layout->layout.ndim(); i++) { - if (reduce_axes_set.count(i)) { - continue; - } - new_layout += src_layout->layout[i].name(); - } - return LayoutDecision(new_layout); -} - -const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout, - const ffi::Array& axes) { - ffi::String layout_str; - for (const auto& a : axes) { - layout_str = layout_str + src_layout->layout[a->value].name(); - } - return LayoutDecision(layout_str); -} - -const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout, - const std::vector& axes) { - ffi::String layout_str; - for (const auto& a : axes) { - layout_str = layout_str + src_layout->layout[a].name(); - } - return LayoutDecision(layout_str); -} - -int LayoutUtils::InferBatchDim(const LayoutDecision& layout) { - if (!layout->layout.defined()) { - return -1; - } - for (size_t i = 0; i < layout->layout.ndim(); i++) { - if (layout->layout[i].name() == "N") { - return static_cast(i); - } - } - return -1; -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/transform/layout_utils.h b/src/contrib/msc/core/transform/layout_utils.h deleted file mode 100644 index 88bcc5703589..000000000000 --- a/src/contrib/msc/core/transform/layout_utils.h +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/transform/layout_utils.h - * \brief Common utilities for layout. - */ -#ifndef TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_ -#define TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_ - -#include -#include - -#include - -#include "../../../../relax/transform/infer_layout_utils.h" -#include "../../../../relax/transform/utils.h" -#include "../utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using Expr = tvm::RelaxExpr; -using namespace tvm::relax; - -/*! - * \brief Utils for Layout. - */ -class LayoutUtils { - public: - /*! - * \brief Infer NLayout. - * \return The NLayout. - */ - TVM_DLL static NLayout InferNLayout(const Expr& expr, const VarLayoutMap& var_layout_map); - - /*! - * \brief Infer LayoutDecision. - * \return The LayoutDecision. - */ - TVM_DLL static LayoutDecision InferLayoutDecision(const Expr& expr, - const VarLayoutMap& var_layout_map); - - /*! - * \brief Infer LayoutDecision at given pos. - * \return The LayoutDecision. - */ - TVM_DLL static LayoutDecision InferLayoutDecisionAt(const Expr& expr, - const VarLayoutMap& var_layout_map, - size_t index = 0); - - /*! - * \brief Check if the layout is infered. - * \return Whether the layout is infered. - */ - TVM_DLL static bool LayoutInfered(const Expr& expr); - - /*! - * \brief Set the layout to span - * \return Whether the layout is setted. - */ - TVM_DLL static bool SetLayout(const Expr& expr, const NLayout& layout); - - /*! - * \brief Get the layout from span - * \return The NLayout. - */ - TVM_DLL static const NLayout GetNLayout(const Expr& expr); - - /*! - * \brief Get the layout desion from span - * \return The LayoutDecision. - */ - TVM_DLL static const LayoutDecision GetLayoutDecision(const Expr& expr); - - /*! - * \brief Check if the layout has unknown dim tensor. - * \return Whether the layout has unknown dim tensor. - */ - TVM_DLL static bool HasUnknownDimTensor(const NLayout& nlayout); - - /*! - * \brief Check if the args has unknown dim tensor. - * \return Whether the args has unknown dim tensor. - */ - TVM_DLL static bool HasUnknownDimTensor(const ffi::Array& args); - - /*! - * \brief Insert axes to the Layout - * \return The new layout. - */ - TVM_DLL static const LayoutDecision ExpandLayout(const LayoutDecision& src_layout, - const std::vector& expand_axes); - - /*! - * \brief Delete axes from the Layout - * \return The new layout. - */ - TVM_DLL static const LayoutDecision ReduceLayout(const LayoutDecision& src_layout, - const std::vector& reduce_axes); - /*! - * \brief Permute axes from the Layout - * \return The new layout. - */ - TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout, - const ffi::Array& axes); - TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout, - const std::vector& axes); - - /*! - * \brief Infer batch dim from the Layout - * \return The batch dim. - */ - TVM_DLL static int InferBatchDim(const LayoutDecision& layout); -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_TRANSFORM_LAYOUT_UTILS_H_ diff --git a/src/contrib/msc/core/transform/rewrite_utils.cc b/src/contrib/msc/core/transform/rewrite_utils.cc deleted file mode 100644 index a20e7d5ac3b0..000000000000 --- a/src/contrib/msc/core/transform/rewrite_utils.cc +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/transform/rewrite_utils.cc - */ -#include "rewrite_utils.h" - -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -Var RewriteUtils::ReEmit(BlockBuilder builder, const ffi::String& name, const Expr& expr) { - expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name); - return builder->Emit(expr, name); -} - -Var RewriteUtils::MakeCall(BlockBuilder builder, const ffi::String& name, Expr op, - ffi::Array args, Attrs attrs) { - const auto& call = Call(op, args, attrs); - return ReEmit(builder, name, call); -} - -Expr RewriteUtils::MakeConstant(BlockBuilder builder, const ffi::String& name, double value, - const DataType& dtype, size_t ndim) { - const auto& data = support::FloatImmToTensor(FloatImm(dtype, value)); - Span span = SpanUtils::CreateWithAttr(msc_attr::kName, name); - const auto& constant = Constant(data, std::nullopt, span); - if (ndim == 0) { - return constant; - } - static const Op& reshape_op = Op::Get("relax.reshape"); - ffi::Array exp_shape(ndim, Integer(1)); - return MakeCall(builder, name + "_exp", reshape_op, {constant, ShapeExpr(exp_shape)}); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/transform/rewrite_utils.h b/src/contrib/msc/core/transform/rewrite_utils.h deleted file mode 100644 index b5dc5e4f2a64..000000000000 --- a/src/contrib/msc/core/transform/rewrite_utils.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/transform/rewrite_utils.h - * \brief Common utilities for rewrite. - */ -#ifndef TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_ -#define TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_ - -#include -#include - -#include - -#include "../../../../relax/transform/utils.h" -#include "../../../../support/scalars.h" -#include "../utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using Expr = tvm::RelaxExpr; -using namespace tvm::relax; - -/*! - * \brief Utils for Layout. - */ -class RewriteUtils { - public: - /*! - * \brief Emit call with span name. - * \return The emitted var. - */ - TVM_DLL static Var ReEmit(BlockBuilder builder, const ffi::String& name, const Expr& expr); - - /*! - * \brief Make and emit a call binding with span. - * \return The emitted var. - */ - TVM_DLL static Var MakeCall(BlockBuilder builder, const ffi::String& name, Expr op, - ffi::Array args, Attrs attrs = Attrs()); - - /*! - * \brief Make and emit a (shaped)constant with span. - * \return The constant/reshape. - */ - TVM_DLL static Expr MakeConstant(BlockBuilder builder, const ffi::String& name, double value, - const DataType& dtype, size_t ndim = 0); -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_ diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc b/src/contrib/msc/core/transform/set_byoc_attrs.cc deleted file mode 100644 index c459483481c7..000000000000 --- a/src/contrib/msc/core/transform/set_byoc_attrs.cc +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/transform/set_byoc_attrs.cc - * \brief Pass for fuse ShapeExpr. - */ - -#include -#include -#include -#include -#include - -#include "../../../../relax/transform/utils.h" -#include "../utils.h" - -namespace tvm { -namespace relax { - -using namespace tvm::contrib::msc; - -/*! - * \brief Fuse Tuple and TupleGetItem to BYOC - */ -class ByocNameSetter : public ExprMutator { - public: - explicit ByocNameSetter(IRModule ctx_module, const ffi::String& target, - const ffi::String& entry_name) - : ExprMutator(ctx_module) { - mod_ = ctx_module; - target_ = target; - entry_name_ = entry_name; - } - - IRModule SetNames() { - size_t func_cnt = 0; - for (const auto& [gv, func] : mod_->functions) { - if (gv->name_hint == entry_name_) { - continue; - } - const auto& name_opt = func->GetAttr(attr::kCodegen); - if (name_opt.has_value() && name_opt.value() == target_) { - const ffi::String& func_name = target_ + "_" + std::to_string(func_cnt); - const auto& new_func = Downcast(VisitExpr(func)); - builder_->UpdateFunction(gv, WithAttr(new_func, msc_attr::kUnique, func_name)); - func_cnt += 1; - } - } - return builder_->GetContextIRModule(); - } - - void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final { - local_funcs_.Set(binding->var, ffi::GetRef(val)); - ExprMutator::VisitBinding_(binding, val); - } - - void VisitBinding_(const VarBindingNode* binding, const CallNode* val) final { - ExprMutator::VisitBinding_(binding, val); - if (val->op->IsInstance()) { - TVM_FFI_ICHECK(local_funcs_.count(val->op)) << "Can not find local func " << val->op; - const auto& name_opt = local_funcs_[val->op]->GetAttr(msc_attr::kUnique); - if (name_opt.has_value()) { - val->span = SpanUtils::SetAttr(val->span, "name", name_opt.value()); - } - } - } - - private: - IRModule mod_; - ffi::String target_; - ffi::String entry_name_; - ffi::Map new_funcs_; - ffi::Map local_funcs_; -}; - -IRModule SetBYOCAttrs(IRModule mod, const ffi::String& target, const ffi::String& entry_name) { - return ByocNameSetter(mod, target, entry_name).SetNames(); -} - -namespace transform { - -Pass SetBYOCAttrs(const ffi::String& target, const ffi::String& entry_name) { - auto pass_func = [=](IRModule m, PassContext pc) { - return relax::SetBYOCAttrs(m, target, entry_name); - }; - return CreateModulePass(pass_func, 0, "SetBYOCAttrs", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.transform.SetBYOCAttrs", SetBYOCAttrs); -} - -} // namespace transform -} // namespace relax -} // namespace tvm diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc deleted file mode 100644 index 75273350afb4..000000000000 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ /dev/null @@ -1,1374 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/transform/set_expr_layout.cc - * \brief Pass for setting layout for expr and constant. - */ - -#include -#include -#include -#include - -#include "../utils.h" -#include "layout_utils.h" - -namespace tvm { -namespace relax { - -using namespace tvm::contrib::msc; - -std::tuple AccumulateMatch(const ffi::Array& input_shape, - const ffi::Array& output_shape, - size_t in_start, size_t out_start) { - // find input position in_pos and output position out_pos - // cumsum(in_shape[in_start:in_pos])==cumsum(out_shape[out_start:out_pos]) - std::vector in_shape, out_shape; - for (const auto& s : input_shape) { - in_shape.push_back(Downcast(s)->value); - } - for (const auto& s : output_shape) { - out_shape.push_back(Downcast(s)->value); - } - int64_t in_size = static_cast(in_shape.size()); - int64_t out_size = static_cast(out_shape.size()); - int64_t in_pos = in_start; - int64_t out_pos = out_start; - int64_t in_accumulate = in_shape[in_pos]; - int64_t out_accumulate = out_shape[out_pos]; - while (in_accumulate != out_accumulate) { - if (in_accumulate > out_accumulate) { - out_pos += 1; - if (out_pos >= out_size) { - return std::make_tuple(-1, -1); - } - out_accumulate *= out_shape[out_pos]; - } else { - in_pos += 1; - if (in_pos >= in_size) { - return std::make_tuple(-1, -1); - } - in_accumulate *= in_shape[in_pos]; - } - } - if (in_accumulate != out_accumulate) { - return std::make_tuple(-1, -1); - } - // append tailing - if (in_pos >= 0) { - while (in_pos < in_size - 1 && in_shape[in_pos + 1] == 1) { - in_pos++; - } - while (out_pos < out_size - 1 && out_shape[out_pos + 1] == 1) { - out_pos++; - } - } - return std::make_tuple(in_pos - in_start, out_pos - out_start); -} - -std::tuple, std::vector> InferReshapeAxes( - const ffi::Array& input_shape, const ffi::Array& output_shape, - int batch_dim) { - std::vector expand_axes, reduce_axes; - size_t in_start = 0; - while (in_start < input_shape.size()) { - size_t out_start = in_start + expand_axes.size() - reduce_axes.size(); - int64_t in_dist, out_dist; - std::tie(in_dist, out_dist) = AccumulateMatch(input_shape, output_shape, in_start, out_start); - if (in_dist == -1) { - return std::make_tuple(std::vector(), std::vector()); - } - if (out_dist >= in_dist) { - for (size_t i = 0; i < static_cast(out_dist - in_dist); i++) { - if (batch_dim >= 0 && (out_start + i) == static_cast(batch_dim)) { - expand_axes.push_back(out_start + i + 1); - } else { - expand_axes.push_back(out_start + i); - } - } - } else { - for (size_t i = 0; i < static_cast(in_dist - out_dist); i++) { - if (batch_dim >= 0 && (in_start + i) == static_cast(batch_dim)) { - reduce_axes.push_back(in_start + i + 1); - } else { - reduce_axes.push_back(in_start + i); - } - } - } - in_start += in_dist + 1; - } - if (input_shape.size() + expand_axes.size() - reduce_axes.size() != output_shape.size()) { - return std::make_tuple(std::vector(), std::vector()); - } - return std::make_tuple(expand_axes, reduce_axes); -} - -// Forward and Backward infer -InferLayoutOutput MSCInferLayoutConv( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision data_layout, kernel_layout, out_layout; - const ffi::String& op_name = Downcast(call->op)->name; - if (op_name == "relax.nn.conv1d") { - const auto* attrs = call->attrs.as(); - data_layout = LayoutDecision(attrs->data_layout); - kernel_layout = LayoutDecision(attrs->kernel_layout); - out_layout = LayoutDecision(attrs->out_layout); - } else if (op_name == "relax.nn.conv2d") { - const auto* attrs = call->attrs.as(); - data_layout = LayoutDecision(attrs->data_layout); - kernel_layout = LayoutDecision(attrs->kernel_layout); - out_layout = LayoutDecision(attrs->out_layout); - } else if (op_name == "relax.nn.conv2d_transpose") { - const auto* attrs = call->attrs.as(); - data_layout = LayoutDecision(attrs->data_layout); - kernel_layout = LayoutDecision(attrs->kernel_layout); - out_layout = LayoutDecision(attrs->out_layout); - } - return InferLayoutOutput({data_layout, kernel_layout}, {out_layout}, Attrs()); -} - -InferLayoutOutput MSCInferLayoutPool2d( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision layout, out_layout; - const ffi::String& op_name = Downcast(call->op)->name; - if (op_name == "relax.nn.adaptive_avg_pool2d") { - const auto* attrs = call->attrs.as(); - layout = LayoutDecision(attrs->layout); - out_layout = LayoutDecision(attrs->out_layout); - } else { - const auto* attrs = call->attrs.as(); - layout = LayoutDecision(attrs->layout); - out_layout = LayoutDecision(attrs->out_layout); - } - return InferLayoutOutput({layout}, {out_layout}, Attrs()); -} - -InferLayoutOutput MSCInferLayoutResize2d( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - const auto* attrs = call->attrs.as(); - const auto& data_layout = LayoutDecision(attrs->layout); - const auto& shape_layout = LayoutDecision("O"); - return InferLayoutOutput({data_layout, shape_layout}, {data_layout}, Attrs()); -} - -// Forward Infer -InferLayoutOutput ForwardInferLayoutCommon( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - ffi::Array input_layouts; - LayoutDecision layout_hint; - for (const auto& arg : call->args) { - const auto& in_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); - if (in_layout->layout.defined()) { - layout_hint = in_layout; - } - input_layouts.push_back(in_layout); - } - if (!layout_hint.defined()) { - return InferLayoutOutput(); - } - const auto& sinfo = GetStructInfo(call); - if (sinfo->IsInstance()) { - return InferLayoutOutput(input_layouts, {layout_hint}, Attrs()); - } - ffi::Array output_layouts; - if (const auto* tuple_sinfo = sinfo.as()) { - for (size_t i = 0; i < tuple_sinfo->fields.size(); i++) { - output_layouts.push_back(layout_hint); - } - return InferLayoutOutput(input_layouts, {output_layouts}, Attrs()); - } - return InferLayoutOutput(); -} - -InferLayoutOutput ForwardInferLayoutBroadcast( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - ffi::Array input_layouts; - LayoutDecision layout_hint; - for (const auto& arg : call->args) { - const auto& in_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); - if (in_layout->layout.defined()) { - if (!layout_hint.defined() || layout_hint->layout.ndim() < in_layout->layout.ndim()) { - layout_hint = in_layout; - } - } - input_layouts.push_back(in_layout); - } - if (!layout_hint.defined()) { - return InferLayoutOutput(); - } - const auto& sinfo = GetStructInfo(call); - if (sinfo->IsInstance()) { - return InferLayoutOutput(input_layouts, {layout_hint}, Attrs()); - } - return InferLayoutOutput(); -} - -InferLayoutOutput ForwardInferLayoutInplace( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); -} - -InferLayoutOutput ForwardInferLayoutBinary( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - const auto& output = ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); - if (!output.defined()) { - return output; - } - std::vector input_layouts; - for (size_t i = 0; i < call->args.size(); i++) { - const auto& sinfo = GetStructInfo(call->args[i]); - if (const auto* t_info = sinfo.as()) { - if (t_info->ndim == 0) { - input_layouts.push_back(LayoutDecision("")); - } else if (t_info->ndim == 1) { - const auto& ref_layout = output->output_layouts[0].LeafValue()->layout; - input_layouts.push_back(LayoutDecision(ref_layout[ref_layout.ndim() - 1].name())); - } else { - input_layouts.push_back(output->input_layouts[i]); - } - } else { - TVM_FFI_THROW(InternalError) << "Binary input should be tensor, get " << sinfo->GetTypeKey(); - } - } - return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); -} - -InferLayoutOutput ForwardInferLayoutArgMaxMin( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - if (!input_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto* attrs = call->attrs.as(); - if (attrs->keepdims) { - return InferLayoutOutput({input_layout}, {input_layout}, Attrs()); - } - if (!attrs->axis.has_value()) { - return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs()); - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - std::vector axes; - axes.push_back(CommonUtils::GetIndex(attrs->axis.value(), input_shape.size())); - LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, axes); - return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput ForwardInferLayoutBatchNorm( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - LayoutDecision in_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - if (!in_layout->layout.defined()) { - if (input_shape.size() == 4) { - in_layout = LayoutDecision("NCHW"); - } else if (input_shape.size() == 3) { - in_layout = LayoutDecision("NCD"); - } - } - LayoutDecision g_layout = LayoutDecision("O"); - return InferLayoutOutput({in_layout, g_layout, g_layout, g_layout, g_layout}, - {{in_layout, g_layout, g_layout}}, Attrs()); -} - -InferLayoutOutput ForkwardInferLayoutExpandDims( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - if (!input_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - const auto* attrs = call->attrs.as(); - std::vector expand_axes; - for (const auto& s : attrs->axis) { - expand_axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size())); - } - LayoutDecision output_layout = LayoutUtils::ExpandLayout(input_layout, expand_axes); - return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput ForwardInferLayoutNormalize( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - LayoutDecision in_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - if (!in_layout->layout.defined()) { - if (input_shape.size() == 4) { - in_layout = LayoutDecision("NCHW"); - } else if (input_shape.size() == 3) { - in_layout = LayoutDecision("NCD"); - } - } - LayoutDecision g_layout = LayoutDecision("O"); - return InferLayoutOutput({in_layout, g_layout, g_layout}, {in_layout}, Attrs()); -} - -InferLayoutOutput ForwardInferLayoutMatmul( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - const auto& a_shape = ExprUtils::GetShape(call->args[0]); - const auto& b_shape = ExprUtils::GetShape(call->args[1]); - if (a_shape.size() == 0) { - return InferLayoutOutput(); - } - LayoutDecision a_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - if (!a_layout->layout.defined()) { - if (a_shape.size() == 4) { - a_layout = LayoutDecision("NCHW"); - } else if (a_shape.size() == 3) { - a_layout = LayoutDecision("NCD"); - } else if (a_shape.size() == 2) { - a_layout = LayoutDecision("NC"); - } - } - size_t start = a_layout->layout.ndim() - b_shape.size(); - ffi::String pre_layout; - for (size_t i = start; i < a_layout->layout.ndim() - 2; i++) { - pre_layout = pre_layout + a_layout->layout[i].name(); - } - LayoutDecision b_layout = LayoutDecision(pre_layout + "IO"); - return InferLayoutOutput({a_layout, b_layout}, {a_layout}, Attrs()); -} - -InferLayoutOutput ForwardInferLayoutPermute( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - if (!input_layout->layout.defined()) { - return InferLayoutOutput(); - } - std::vector permute_axes; - const auto* attrs = call->attrs.as(); - if (!attrs->axes.defined()) { - for (size_t i = input_layout->layout.ndim(); i > 0; i--) { - permute_axes.push_back(i - 1); - } - } else { - for (const auto& a : attrs->axes.value()) { - permute_axes.push_back(a->value); - } - } - LayoutDecision output_layout = LayoutUtils::PermuteLayout(input_layout, permute_axes); - return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput ForwardInferLayoutReduceAxis( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - if (!input_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto* attrs = call->attrs.as(); - if (attrs->keepdims) { - return InferLayoutOutput({input_layout}, {input_layout}, Attrs()); - } - if (!attrs->axis.defined()) { - return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs()); - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - std::vector axes; - for (const auto& s : attrs->axis.value()) { - axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size())); - } - LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, axes); - return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput ForwardInferLayoutReshape( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - if (!input_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& output_shape = ExprUtils::GetShape(call); - if (input_shape.size() == 0 || output_shape.size() == 0) { - return InferLayoutOutput(); - } - LayoutDecision output_layout = input_layout; - if (input_shape.size() != output_shape.size()) { - int batch_dim = LayoutUtils::InferBatchDim(input_layout); - std::vector expand_axes, reduce_axes; - std::tie(expand_axes, reduce_axes) = InferReshapeAxes(input_shape, output_shape, batch_dim); - if (reduce_axes.size() == 0 && expand_axes.size() == 0) { - return InferLayoutOutput(); - } - if (reduce_axes.size() > 0) { - output_layout = LayoutUtils::ReduceLayout(output_layout, reduce_axes); - } - if (expand_axes.size() > 0) { - output_layout = LayoutUtils::ExpandLayout(output_layout, expand_axes); - } - } - return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); -} - -InferLayoutOutput ForwardInferLayoutSqueeze( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - if (!input_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - const auto* attrs = call->attrs.as(); - std::vector reduce_axes; - if (attrs->axis.defined()) { - for (const auto& s : attrs->axis.value()) { - size_t v_index = CommonUtils::GetIndex(s->value, input_shape.size()); - if (Downcast(input_shape[v_index])->value == 1) { - reduce_axes.push_back(v_index); - } - } - } else { - for (size_t i = 0; i < input_shape.size(); i++) { - if (Downcast(input_shape[i])->value == 1) { - reduce_axes.push_back(i); - } - } - } - LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, reduce_axes); - return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput ForwardInferLayoutTake( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - LayoutDecision indices_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& output_shape = ExprUtils::GetShape(call); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - if (input_layout->layout.defined()) { - if (input_shape.size() == output_shape.size()) { - return InferLayoutOutput({input_layout, indices_layout}, {input_layout}, Attrs()); - } - LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, std::vector{0}); - return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); - } - if (indices_layout->layout.defined()) { - std::vector expand_axes; - for (size_t i = indices_layout->layout.ndim(); i < output_shape.size(); i++) { - expand_axes.push_back(i); - } - LayoutDecision output_layout; - if (expand_axes.size() == 0) { - output_layout = indices_layout; - } else { - output_layout = LayoutUtils::ExpandLayout(indices_layout, expand_axes); - } - return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); - } - return InferLayoutOutput(); -} - -InferLayoutOutput ForwardInferLayoutPlugin( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - if (!call->args[0]->IsInstance()) { - return InferLayoutOutput(); - } - const auto& name = Downcast(call->args[0])->global_symbol; - const auto pf = tvm::ffi::Function::GetGlobal("msc.plugin.op.InferLayout" + name); - if (!pf.has_value()) { - return InferLayoutOutput(); - } - const auto& args = Downcast(call->args[1]); - return (*pf)(args->fields, var_layout_map).cast(); -} - -// nn ops -TVM_REGISTER_OP("relax.nn.avg_pool2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.batch_norm") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBatchNorm); -TVM_REGISTER_OP("relax.nn.conv1d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); -TVM_REGISTER_OP("relax.nn.conv2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); -TVM_REGISTER_OP("relax.nn.conv2d_transpose") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); -TVM_REGISTER_OP("relax.nn.dropout") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutCommon); -TVM_REGISTER_OP("relax.nn.group_norm") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); -TVM_REGISTER_OP("relax.nn.layer_norm") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); -TVM_REGISTER_OP("relax.nn.max_pool2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); - -// reduce axis ops -TVM_REGISTER_OP("relax.argmax") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutArgMaxMin); -TVM_REGISTER_OP("relax.argmin") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutArgMaxMin); -TVM_REGISTER_OP("relax.max") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); -TVM_REGISTER_OP("relax.min") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); -TVM_REGISTER_OP("relax.mean") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); -TVM_REGISTER_OP("relax.sum") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); -TVM_REGISTER_OP("relax.prod") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); -TVM_REGISTER_OP("relax.std") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); - -// binary ops -TVM_REGISTER_OP("relax.add") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.divide") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.floor_divide") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.multiply") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.power") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.subtract") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.equal") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.greater") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.greater_equal") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.less") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.less_equal") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.not_equal") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.maximum") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.minimum") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.logical_and") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.logical_or") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.logical_xor") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.bitwise_and") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.bitwise_or") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); -TVM_REGISTER_OP("relax.bitwise_xor") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); - -// math ops -TVM_REGISTER_OP("relax.expand_dims") - .set_attr("FMSCForwardInferLayout", ForkwardInferLayoutExpandDims); -TVM_REGISTER_OP("relax.matmul") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutMatmul); -TVM_REGISTER_OP("relax.permute_dims") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutPermute); -TVM_REGISTER_OP("relax.reshape") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReshape); -TVM_REGISTER_OP("relax.squeeze") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutSqueeze); -TVM_REGISTER_OP("relax.take") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutTake); -TVM_REGISTER_OP("relax.image.resize2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutResize2d); - -// plugin op -TVM_REGISTER_OP("relax.call_dps_packed") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutPlugin); - -// Backward Infer -InferLayoutOutput BackwardInferLayoutCommon( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - NLayout output_layout = LayoutUtils::InferNLayout(call, var_layout_map); - LayoutDecision layout_hint; - if (output_layout.IsLeaf()) { - layout_hint = output_layout.LeafValue(); - } else { - for (const auto& l : output_layout.NestedArray()) { - if (l.IsLeaf() && l.LeafValue()->layout.defined()) { - layout_hint = l.LeafValue(); - } - } - } - if (!layout_hint->layout.defined()) { - return InferLayoutOutput(); - } - ffi::Array input_layouts; - for (const auto& arg : call->args) { - const auto& saved_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); - if (saved_layout->layout.defined()) { - input_layouts.push_back(saved_layout); - } else { - input_layouts.push_back(layout_hint); - } - } - return InferLayoutOutput(input_layouts, {output_layout}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutBinary( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - const auto& output = BackwardInferLayoutCommon(call, desired_layouts, var_layout_map); - if (!output.defined()) { - return output; - } - std::vector input_layouts; - for (size_t i = 0; i < call->args.size(); i++) { - const auto& sinfo = GetStructInfo(call->args[i]); - if (const auto* t_info = sinfo.as()) { - if (t_info->ndim == 0) { - input_layouts.push_back(LayoutDecision("")); - } else if (t_info->ndim == 1) { - const auto& ref_layout = output->output_layouts[0].LeafValue()->layout; - input_layouts.push_back(LayoutDecision(ref_layout[ref_layout.ndim() - 1].name())); - } else { - input_layouts.push_back(output->input_layouts[i]); - } - } else { - TVM_FFI_THROW(InternalError) << "Binary input should be tensor, get " << sinfo->GetTypeKey(); - } - } - return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutInplace( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - return BackwardInferLayoutCommon(call, desired_layouts, var_layout_map); -} - -InferLayoutOutput BackwardInferLayoutArgMaxMin( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto* attrs = call->attrs.as(); - if (attrs->keepdims) { - return InferLayoutOutput({output_layout}, {output_layout}, Attrs()); - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - std::vector axes; - axes.push_back(CommonUtils::GetIndex(Downcast(attrs->axis)->value, input_shape.size())); - LayoutDecision input_layout = LayoutUtils::ExpandLayout(output_layout, axes); - return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutBatchNorm( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecisionAt(call, var_layout_map, 0); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - LayoutDecision g_layout = LayoutDecision("O"); - return InferLayoutOutput({output_layout, g_layout, g_layout, g_layout, g_layout}, - {{output_layout, g_layout, g_layout}}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutExpandDims( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - const auto* attrs = call->attrs.as(); - std::vector expand_axes; - for (const auto& s : attrs->axis) { - expand_axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size())); - } - LayoutDecision input_layout = LayoutUtils::ReduceLayout(output_layout, expand_axes); - return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutNormalize( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecisionAt(call, var_layout_map, 0); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - LayoutDecision g_layout = LayoutDecision("O"); - return InferLayoutOutput({output_layout, g_layout, g_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutMatmul( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto& b_shape = ExprUtils::GetShape(call->args[1]); - if (b_shape.size() == 0) { - return InferLayoutOutput(); - } - size_t start = output_layout->layout.ndim() - b_shape.size(); - ffi::String pre_layout; - for (size_t i = start; i < output_layout->layout.ndim() - 2; i++) { - pre_layout = pre_layout + output_layout->layout[i].name(); - } - LayoutDecision b_layout = LayoutDecision(pre_layout + "IO"); - return InferLayoutOutput({output_layout, b_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutPermute( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - std::vector permute_axes; - const auto* attrs = call->attrs.as(); - if (!attrs->axes.defined()) { - for (size_t i = output_layout->layout.ndim(); i > 0; i--) { - permute_axes.push_back(i - 1); - } - } else { - std::vector attr_axes; - for (const auto& s : attrs->axes.value()) { - attr_axes.push_back(s->value); - } - for (size_t i = 0; i < output_layout->layout.ndim(); i++) { - int pos = ArrayUtils::IndexOf(attr_axes, static_cast(i)); - if (pos >= 0) { - permute_axes.push_back(pos); - } else { - permute_axes.push_back(i); - } - } - } - LayoutDecision input_layout = LayoutUtils::PermuteLayout(output_layout, permute_axes); - return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutReduceAxis( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto* attrs = call->attrs.as(); - if (attrs->keepdims) { - return InferLayoutOutput({output_layout}, {output_layout}, Attrs()); - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - std::vector axes; - for (const auto& s : attrs->axis.value()) { - axes.push_back(CommonUtils::GetIndex(s->value, input_shape.size())); - } - LayoutDecision input_layout = LayoutUtils::ExpandLayout(output_layout, axes); - return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutReshape( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& output_shape = ExprUtils::GetShape(call); - if (input_shape.size() == 0 || output_shape.size() == 0) { - return InferLayoutOutput(); - } - LayoutDecision input_layout = output_layout; - if (input_shape.size() != output_shape.size()) { - int batch_dim = LayoutUtils::InferBatchDim(output_layout); - std::vector expand_axes, reduce_axes; - std::tie(expand_axes, reduce_axes) = InferReshapeAxes(input_shape, output_shape, batch_dim); - if (reduce_axes.size() == 0 && expand_axes.size() == 0) { - return InferLayoutOutput(); - } - if (expand_axes.size() > 0) { - input_layout = LayoutUtils::ReduceLayout(input_layout, expand_axes); - } - if (reduce_axes.size() > 0) { - input_layout = LayoutUtils::ExpandLayout(input_layout, reduce_axes); - } - } - return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutSqueeze( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - const auto* attrs = call->attrs.as(); - std::vector reduce_axes; - if (attrs->axis.defined()) { - for (const auto& s : attrs->axis.value()) { - size_t v_index = CommonUtils::GetIndex(s->value, input_shape.size()); - if (Downcast(input_shape[v_index])->value == 1) { - reduce_axes.push_back(v_index); - } - } - } else { - for (size_t i = 0; i < input_shape.size(); i++) { - if (Downcast(input_shape[i])->value == 1) { - reduce_axes.push_back(i); - } - } - } - LayoutDecision input_layout = LayoutUtils::ExpandLayout(output_layout, reduce_axes); - return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutTake( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); - LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); - LayoutDecision indices_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& output_shape = ExprUtils::GetShape(call); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - if (input_shape.size() == 0) { - return InferLayoutOutput(); - } - if (!indices_layout.defined()) { - indices_layout = LayoutUtils::ReduceLayout(output_layout, std::vector{0}); - } - if (input_shape.size() == output_shape.size()) { - return InferLayoutOutput({output_layout, indices_layout}, {output_layout}, Attrs()); - } - if (!input_layout.defined()) { - input_layout = LayoutUtils::ExpandLayout(output_layout, std::vector{0}); - } - return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); -} - -InferLayoutOutput BackwardInferLayoutTupleInputs( - const Call& call, const ffi::Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); - if (!output_layout->layout.defined()) { - return InferLayoutOutput(); - } - std::vector input_layouts; - if (const auto* t_node = GetStructInfo(call->args[0]).as()) { - for (size_t i = 0; i < t_node->fields.size(); i++) { - input_layouts.push_back(output_layout); - } - } else { - LOG_FATAL << "Expected input as tuple, get " << call->args[0]; - } - return InferLayoutOutput(input_layouts, {output_layout}, Attrs()); -} - -// nn ops -TVM_REGISTER_OP("relax.nn.avg_pool2d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.batch_norm") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBatchNorm); -TVM_REGISTER_OP("relax.nn.conv1d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); -TVM_REGISTER_OP("relax.nn.conv2d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); -TVM_REGISTER_OP("relax.nn.conv2d_transpose") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); -TVM_REGISTER_OP("relax.nn.group_norm") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); -TVM_REGISTER_OP("relax.nn.layer_norm") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); -TVM_REGISTER_OP("relax.nn.max_pool2d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); - -// reduce axis ops -TVM_REGISTER_OP("relax.argmax") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutArgMaxMin); -TVM_REGISTER_OP("relax.argmin") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutArgMaxMin); -TVM_REGISTER_OP("relax.max") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); -TVM_REGISTER_OP("relax.min") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); -TVM_REGISTER_OP("relax.mean") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); -TVM_REGISTER_OP("relax.sum") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); -TVM_REGISTER_OP("relax.prod") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); -TVM_REGISTER_OP("relax.std") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); - -// binary ops -TVM_REGISTER_OP("relax.add") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.divide") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.floor_divide") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.multiply") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.power") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.subtract") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.equal") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.greater") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.greater_equal") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.less") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.less_equal") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.not_equal") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.maximum") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.minimum") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.logical_and") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.logical_or") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.logical_xor") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.bitwise_and") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.bitwise_or") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); -TVM_REGISTER_OP("relax.bitwise_xor") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBinary); - -// math ops -TVM_REGISTER_OP("relax.concat") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutTupleInputs); -TVM_REGISTER_OP("relax.expand_dims") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutExpandDims); -TVM_REGISTER_OP("relax.matmul") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutMatmul); -TVM_REGISTER_OP("relax.permute_dims") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutPermute); -TVM_REGISTER_OP("relax.reshape") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReshape); -TVM_REGISTER_OP("relax.squeeze") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutSqueeze); -TVM_REGISTER_OP("relax.take") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutTake); -TVM_REGISTER_OP("relax.image.resize2d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutResize2d); - -class LayoutInfer : public ExprVisitor { - public: - explicit LayoutInfer(const IRModule& ref_module) : ref_module_(ref_module) { Reset(); } - - void Reset() { - infered_ = false; - var_map_.clear(); - ordered_exprs_.clear(); - } - - void RecordExpr(const Var& var, const Expr& expr) { - var_map_.Set(var, expr); - ordered_exprs_.push_back(expr); - } - - Expr Infer(const Expr& expr) { - Reset(); - ForwardInfer(expr); - BackwardInfer(); - return expr; - } - - void ForwardInfer(const Expr& expr) { ExprVisitor::VisitExpr(expr); } - - void BackwardInfer() { - for (size_t e_idx = ordered_exprs_.size(); e_idx > 0; e_idx--) { - const Expr& expr = ordered_exprs_[e_idx - 1]; - if (expr->IsInstance()) { - continue; - } - if (expr->IsInstance()) { - continue; - } - if (expr->IsInstance()) { - continue; - } - if (!expr->IsInstance()) { - continue; - } - const Call& call = Downcast(expr); - if (const auto* v_node = call->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - BackwardInferFunc(func, call); - continue; - } else if (call->op->IsInstance() && local_funcs_.count(call->op)) { - BackwardInferFunc(local_funcs_[call->op], call); - continue; - } - size_t infered_num = 0; - for (const auto& arg : call->args) { - if (IsArgInfered(arg)) { - infered_num++; - } - } - if (call->args.size() == 0 || infered_num == call->args.size() || - !call->op->IsInstance() || LayoutUtils::HasUnknownDimTensor(call->args)) { - continue; - } - const OpNode* op_node = call->op.as(); - if (op_node == nullptr) { - continue; - } - // Infer by op_node - Op op = Downcast(ffi::GetRef(op_node)); - InferLayoutOutput infered_layout; - const auto& msc_infer_map = Op::GetAttrMap("FMSCBackwardInferLayout"); - try { - if (msc_infer_map.count(op)) { - FRelaxInferLayout f = msc_infer_map[op]; - infered_layout = - f(call, ffi::Map>(), var_layout_map_); - } else { - infered_layout = BackwardInferLayoutCommon( - call, ffi::Map>(), var_layout_map_); - } - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to backward infer layout " << expr << " : " << err.what(); - infered_layout = InferLayoutOutput(); - } - try { - if (infered_layout.defined()) { - SetInputLayouts(call, infered_layout->input_layouts); - } - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to backward set inputs layout for " << call << " : " << err.what(); - } - } - } - - void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - ExprVisitor::VisitBinding_(binding, call_node); - const auto& call = ffi::GetRef(call_node); - if (const auto* v_node = call->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - RecordExpr(binding->var, call); - ForwardInferFunc(func, call, binding->var); - } else if (call->op->IsInstance() && local_funcs_.count(call->op)) { - RecordExpr(binding->var, call); - ForwardInferFunc(local_funcs_[call->op], call, binding->var); - } else { - // infer call - bool infer_outputs = true; - RecordExpr(binding->var, call); - if (LayoutUtils::LayoutInfered(call)) { - infer_outputs = false; - } - if (call->args.size() == 0 || !call->op->IsInstance() || - LayoutUtils::HasUnknownDimTensor(call->args)) { - infer_outputs = false; - } - const OpNode* op_node = call->op.as(); - if (op_node == nullptr) { - infer_outputs = false; - } - if (infer_outputs) { - // infer layouts - Op op = Downcast(ffi::GetRef(op_node)); - InferLayoutOutput infered_layout; - const auto& msc_infer_map = Op::GetAttrMap("FMSCForwardInferLayout"); - const auto& relax_infer_map = Op::GetAttrMap("FRelaxInferLayout"); - bool set_inputs = true; - try { - if (msc_infer_map.count(op)) { - FRelaxInferLayout f = msc_infer_map[op]; - infered_layout = - f(call, ffi::Map>(), var_layout_map_); - } else if (!relax_infer_map.count(op)) { - infered_layout = ForwardInferLayoutCommon( - call, ffi::Map>(), var_layout_map_); - } - if (relax_infer_map.count(op) && !infered_layout.defined()) { - FRelaxInferLayout f = relax_infer_map[op]; - infered_layout = - f(call, ffi::Map>(), var_layout_map_); - set_inputs = false; - } - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to forward infer layout for " << binding->var << " : " - << binding->value << ", reason: " << err.what(); - infered_layout = InferLayoutOutput(); - } - if (infered_layout.defined() && infered_layout->output_layouts.size() == 1) { - try { - SetExprLayout(binding->var, infered_layout->output_layouts[0]); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to forward set output layout for " << binding->var << " : " - << binding->value << ", reason: " << err.what(); - } - } - if (set_inputs && infered_layout.defined()) { - try { - SetInputLayouts(call, infered_layout->input_layouts); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to forward set inputs layout for " << call << " : " - << err.what(); - } - } - } - } - } - - void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final { - local_funcs_.Set(binding->var, ffi::GetRef(val)); - } - - void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final { - ExprVisitor::VisitBinding_(binding, val); - RecordExpr(binding->var, ffi::GetRef(val)); - if (IsNestedTensor(binding->var)) { - ffi::Array input_layouts; - for (const auto& field : val->fields) { - input_layouts.push_back(LayoutUtils::InferLayoutDecision(field, var_layout_map_)); - } - SetExprLayout(binding->var, input_layouts); - } - } - - void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { - ExprVisitor::VisitBinding_(binding, val); - RecordExpr(binding->var, ffi::GetRef(val)); - const auto& out_layout = LayoutUtils::InferLayoutDecisionAt( - ffi::GetRef(val)->tuple, var_layout_map_, val->index); - SetExprLayout(binding->var, out_layout); - } - - void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) final { - ExprVisitor::VisitBinding_(binding, val); - RecordExpr(binding->var, ffi::GetRef(val)); - SetExprLayout(binding->var, LayoutDecision("O")); - } - - bool infered() { return infered_; } - - private: - bool IsArgInfered(const Expr& arg) { - if (arg->IsInstance() && var_map_.count(Downcast(arg))) { - if (LayoutUtils::LayoutInfered(var_map_[Downcast(arg)]) > 0) { - return true; - } - } else if (const auto* tuple_node = arg.as()) { - for (const auto& field : tuple_node->fields) { - if (!IsArgInfered(field)) { - return false; - } - } - return true; - } else if (LayoutUtils::LayoutInfered(arg)) { - return true; - } - return false; - } - - void SetExprLayout(const Expr& expr, const NLayout& layout) { - if (expr->IsInstance()) { - const auto& var = Downcast(expr); - var_layout_map_[var] = layout; - if (LayoutUtils::SetLayout(var, layout)) { - infered_ = true; - } - if (var_map_.count(var) && LayoutUtils::SetLayout(var_map_[var], layout)) { - infered_ = true; - } - } else if (LayoutUtils::SetLayout(expr, layout)) { - infered_ = true; - } - } - - void SetInputLayouts(const Call& call, const ffi::Array& input_layouts) { - if (input_layouts.size() == call->args.size()) { - for (size_t i = 0; i < input_layouts.size(); i++) { - SetExprLayout(call->args[i], input_layouts[i]); - } - } - } - - void ForwardInferFunc(const Function& func, const Call& call, const Var& ret) { - for (size_t i = 0; i < call->args.size(); i++) { - if (call->args[i]->IsInstance() && - var_layout_map_.count(Downcast(call->args[i]))) { - SetExprLayout(func->params[i], var_layout_map_[Downcast(call->args[i])]); - } - } - ForwardInfer(func); - for (size_t i = 0; i < func->params.size(); i++) { - if (var_layout_map_.count(func->params[i])) { - SetExprLayout(call->args[i], var_layout_map_[func->params[i]]); - } - } - if (const auto* b_node = func->body.as()) { - if (b_node->body->IsInstance() && - var_layout_map_.count(Downcast(b_node->body))) { - SetExprLayout(ret, var_layout_map_[Downcast(b_node->body)]); - } - } else { - TVM_FFI_THROW(InternalError) << "Function body should be SeqExpr, get " << func->body; - } - } - - void BackwardInferFunc(const Function& func, const Call& call) { - for (size_t i = 0; i < func->params.size(); i++) { - if (var_layout_map_.count(func->params[i])) { - const auto& param_layout = var_layout_map_[func->params[i]]; - SetExprLayout(call->args[i], param_layout); - if (call->args[i]->IsInstance() && var_map_.count(Downcast(call->args[i]))) { - const auto& producer = var_map_[Downcast(call->args[i])]; - if (producer->IsInstance() && - local_funcs_.count(Downcast(producer)->op)) { - const auto& caller = local_funcs_[Downcast(producer)->op]; - if (const auto* b_node = caller->body.as()) { - if (b_node->body->IsInstance() && - var_map_.count(Downcast(b_node->body))) { - SetExprLayout(b_node->body, param_layout); - } - } else { - TVM_FFI_THROW(InternalError) << "Caller body should be SeqExpr, get " << caller->body; - } - } - } - } - } - } - - IRModule ref_module_; - bool infered_; - ffi::Map var_map_; - ffi::Array ordered_exprs_; - std::unordered_map var_layout_map_; - ffi::Map local_funcs_; -}; // class LayoutInfer - -class LayoutChecker : public ExprVisitor { - public: - LayoutChecker() { missing_num_ = 0; } - - void Check(const Expr& expr) { - ExprVisitor::VisitExpr(expr); - TVM_FFI_ICHECK_EQ(missing_num_, 0) << "Some layout is missing"; - } - - void VisitExpr_(const CallNode* call) final { - ExprVisitor::VisitExpr_(call); - if (!LayoutUtils::LayoutInfered(ffi::GetRef(call))) { - missing_num_++; - } - } - - void VisitExpr_(const ConstantNode* cn) final { - ExprVisitor::VisitExpr_(cn); - if (!LayoutUtils::LayoutInfered(ffi::GetRef(cn))) { - missing_num_++; - } - } - - private: - size_t missing_num_; -}; // class LayoutChecker - -void SetExprLayout(const IRModule& ref_module, const Expr& func, bool allow_missing) { - auto layout_infer = LayoutInfer(ref_module); - auto new_func = layout_infer.Infer(func); - if (!allow_missing) { - LayoutChecker().Check(new_func); - } -} - -namespace transform { - -Pass SetExprLayout(bool allow_missing, const ffi::String& entry_name) { - auto pass_func = [=](IRModule m, PassContext pc) { - relax::SetExprLayout(m, m->Lookup(entry_name), allow_missing); - return m; - }; - return CreateModulePass(pass_func, 0, "SetExprLayout", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.transform.SetExprLayout", SetExprLayout); -} - -} // namespace transform -} // namespace relax -} // namespace tvm diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc deleted file mode 100644 index 73c8f9d8e879..000000000000 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/transform/set_expr_name.cc - * \brief Pass for setting name for call and constant. - */ - -#include -#include -#include -#include - -#include - -#include "../utils.h" - -namespace tvm { -using namespace tvm::contrib::msc; - -namespace relax { - -class FuncNameGetter : public ExprVisitor { - public: - explicit FuncNameGetter(const ffi::Array& arg_names) : arg_names_(arg_names) {} - - /*! \brief Get the attributes from prim value as ffi::Map*/ - ffi::String HintName(const Expr& expr) { - name_ = ""; - ExprVisitor::VisitExpr(expr); - return name_; - } - - void VisitBinding_(const VarBindingNode* binding, const CallNode* val) { - if (name_.size() == 0) { - name_ = SpanUtils::GetAttr(val->span, msc_attr::kName); - } - if (name_.size() == 0) { - ExprVisitor::VisitBinding_(binding, val); - } - } - - void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { - if (name_.size() == 0) { - name_ = SpanUtils::GetAttr(val->span, msc_attr::kName); - } - if (name_.size() == 0) { - ExprVisitor::VisitBinding_(binding, val); - } - } - - void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { - if (name_.size() == 0 && arg_names_[0].size() > 0) { - name_ = arg_names_[0] + "." + std::to_string(val->index); - } - if (name_.size() == 0) { - ExprVisitor::VisitBinding_(binding, val); - } - } - - private: - ffi::String name_; - ffi::Array arg_names_; -}; - -/*! - * \brief Name setter for Relax - */ -class RelaxExprNameSetter : public ExprVisitor { - public: - explicit RelaxExprNameSetter(const IRModule& ref_module, const ffi::String& target, - const ffi::Map& var_names) - : ref_module_(ref_module), target_{target}, var_names_{var_names} {} - - void VisitBindingBlock(const BindingBlock& block) final { - ffi::String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); - if (block_name.size() == 0) { - block_name = "block"; - } - const ffi::String& prefix = StringUtils::Join(block_stack_, "."); - if (setted_blocks_.count(prefix + "." + block_name)) { - int cnt = 1; - while (setted_blocks_.count(prefix + "." + block_name + "_" + std::to_string(cnt))) { - cnt++; - } - block_name = block_name + "_" + std::to_string(cnt); - } - setted_blocks_.insert(prefix + "." + block_name); - block_stack_.push_back(block_name); - const ffi::String& unique_name = StringUtils::Join(block_stack_, "."); - block->span = SpanUtils::SetAttr(block->span, msc_attr::kName, unique_name); - ExprVisitor::VisitBindingBlock(block); - block_stack_.pop_back(); - } - - void VisitExpr_(const ConstantNode* val) { - ExprVisitor::VisitExpr_(val); - const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "const"); - if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { - val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); - } - expr_names_.Set(ffi::GetRef(val), unique_name); - } - - void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) { - ExprVisitor::VisitBinding_(binding, val); - const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "const"); - if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { - val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); - } - expr_names_.Set(binding->var, unique_name); - } - - void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) { - ExprVisitor::VisitBinding_(binding, val); - const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "shape"); - if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { - val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); - } - expr_names_.Set(binding->var, unique_name); - } - - void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { - ExprVisitor::VisitBinding_(binding, val); - const ffi::String& unique_name = GetUniqueName(ffi::GetRef(val), "tuple"); - if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { - val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); - } - expr_names_.Set(binding->var, unique_name); - } - - void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { - ExprVisitor::VisitBinding_(binding, val); - ffi::String unique_name; - if (expr_names_.count(val->tuple)) { - unique_name = expr_names_[val->tuple] + "." + std::to_string(val->index); - } else if (const auto* v_node = val->tuple.as()) { - unique_name = v_node->name_hint() + "." + std::to_string(val->index); - } - if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { - val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); - } - expr_names_.Set(binding->var, unique_name); - } - - void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - ExprVisitor::VisitBinding_(binding, val); - const auto& name_opt = val->GetAttr(attr::kComposite); - if (name_opt.has_value()) { - local_funcs_.Set(binding->var, ffi::GetRef(val)); - } - } - - void VisitBinding_(const VarBindingNode* binding, const CallNode* val) { - ExprVisitor::VisitBinding_(binding, val); - ffi::String name_hint, optype; - bool use_unique = true; - if (var_names_.count(binding->var->name_hint())) { - name_hint = var_names_[binding->var->name_hint()]; - } else if (const auto* op_node = val->op.as()) { - const std::string& op_name = op_node->name; - if (op_name == "relax.call_dps_packed" && val->args[0]->IsInstance()) { - const auto& func = Downcast(val->args[0]); - name_hint = func->global_symbol; - optype = func->global_symbol; - const ffi::String& input_name = GetUniqueName(val->args[1], "plugin_inputs"); - if (input_name != SpanUtils::GetAttr(val->args[1]->span, msc_attr::kName)) { - val->args[1]->span = SpanUtils::SetAttr(val->args[1]->span, msc_attr::kName, input_name); - } - } else { - int rpos = op_name.rfind("."); - name_hint = op_name.substr(rpos + 1); - optype = StringUtils::Replace(op_node->name, "relax.", ""); - } - } else if (const auto* v_node = val->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - ExprVisitor::VisitExpr(func); - optype = GetFuncType(func); - name_hint = GetFuncName(ffi::GetRef(val), func); - use_unique = false; - } else if (local_funcs_.count(val->op)) { - ExprVisitor::VisitExpr(local_funcs_[val->op]); - optype = GetFuncType(local_funcs_[val->op]); - name_hint = GetFuncName(ffi::GetRef(val), local_funcs_[val->op]); - use_unique = false; - } - if (name_hint.size() > 0) { - // set name - const ffi::String& unique_name = - use_unique ? GetUniqueName(ffi::GetRef(val), name_hint) : name_hint; - if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { - val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); - } - // set constant consumer && shared_ref - ffi::Array input_types; - try { - input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to GetInputTypes for " << ffi::GetRef(val) << " : " - << err.what(); - throw err; - } - for (size_t i = 0; i < input_types.size(); i++) { - if (input_types[i] == "input") { - continue; - } - if (const auto* c_node = val->args[i].as()) { - const ffi::String& const_name = SpanUtils::GetAttr(c_node->span, msc_attr::kName); - if (constant_consumers_.count(const_name)) { - val->span = SpanUtils::SetAttr(val->span, msc_attr::kSharedRef, - constant_consumers_[const_name]); - } else { - constant_consumers_.Set(const_name, unique_name); - } - } - } - expr_names_.Set(binding->var, unique_name); - } - } - - private: - const ffi::String GetUniqueName(const Expr& expr, const ffi::String& name_hint) { - ffi::String expr_name = SpanUtils::GetAttr(expr->span, msc_attr::kName); - if (expr_name.size() == 0) { - expr_name = name_hint; - } - if (!setted_names_.count(expr_name)) { - setted_names_.Set(expr_name, expr); - return expr_name; - } - if (setted_names_[expr_name] == expr) { - return expr_name; - } - int cnt = 1; - while (setted_names_.count(expr_name + "_" + std::to_string(cnt)) && - setted_names_[expr_name + "_" + std::to_string(cnt)] != expr) { - cnt++; - } - expr_name = expr_name + "_" + std::to_string(cnt); - if (!setted_names_.count(expr_name)) { - setted_names_.Set(expr_name, expr); - } - return expr_name; - } - - const ffi::String GetFuncType(const Function& func) { - ffi::String optype; - const auto& comp_opt = func->GetAttr(attr::kComposite); - const auto& code_opt = func->GetAttr(attr::kCodegen); - if (comp_opt.has_value()) { - optype = comp_opt.value(); - } else if (code_opt.has_value()) { - optype = code_opt.value(); - } else { - optype = "extern_func"; - } - if (target_.size() > 0) { - optype = StringUtils::Replace(optype, target_ + ".", ""); - } - return optype; - } - - const ffi::String GetFuncName(const Call& call, const Function& func) { - ffi::String name; - // get from unique - const auto& name_opt = func->GetAttr(msc_attr::kUnique); - if (name_opt.has_value()) { - return name_opt.value(); - } - // get from exprs in the func - ffi::Array arg_names; - for (const auto& a : call->args) { - arg_names.push_back(expr_names_.count(a) ? expr_names_[a] : ""); - } - name = FuncNameGetter(arg_names).HintName(local_funcs_[call->op]); - if (name.size() > 0) { - return name; - } - const auto& optype = GetFuncType(func); - if (optype == "extern_func") { - name = Downcast(call->op)->name_hint(); - } else { - name = optype; - } - return GetUniqueName(call, name); - } - - ffi::Map setted_names_; - ffi::Map constant_consumers_; - std::set setted_blocks_; - ffi::Array block_stack_; - ffi::Map expr_names_; - ffi::Map local_funcs_; - IRModule ref_module_; - ffi::String target_; - ffi::Map var_names_; -}; // class ExprNameSetter - -void SetRelaxExprName(const IRModule& ref_module, const Expr& e, const ffi::String& target, - const ffi::Map& var_names) { - RelaxExprNameSetter(ref_module, target, var_names).VisitExpr(e); -} - -namespace transform { - -Pass SetRelaxExprName(const ffi::String& entry_name, const ffi::String& target, - const ffi::Map& var_names) { - auto pass_func = [=](IRModule m, PassContext pc) { - relax::SetRelaxExprName(m, m->Lookup(entry_name), target, var_names); - return m; - }; - return CreateModulePass(pass_func, 0, "SetRelaxExprName", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.transform.SetRelaxExprName", SetRelaxExprName); -} - -} // namespace transform -} // namespace relax -} // namespace tvm diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc deleted file mode 100644 index cf25392c8973..000000000000 --- a/src/contrib/msc/core/utils.cc +++ /dev/null @@ -1,562 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/utils.cc - */ - -#include "utils.h" - -#include - -#include -#include -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::relax; - -size_t CommonUtils::GetIndex(int index, size_t max_size) { - size_t v_index; - if (index < 0) { - v_index = index + max_size; - } else { - v_index = index; - } - TVM_FFI_ICHECK_LT(v_index, max_size) << "Index " << index << " out of range " << max_size; - return v_index; -} - -std::vector CommonUtils::GetIndices(const std::vector& indices, size_t max_size) { - std::vector v_indices; - for (const auto& i : indices) { - v_indices.push_back(GetIndex(i, max_size)); - } - return v_indices; -} - -int CommonUtils::CompareVersion(const std::vector& given_version, - const std::vector& target_version) { - if (given_version.size() == 0 || target_version.size() == 0) { - return 0; - } - TVM_FFI_ICHECK_EQ(given_version.size(), 3) << "Version should be in format major,minor,patch"; - TVM_FFI_ICHECK_EQ(target_version.size(), 3) - << "Target version should be in format major,minor,patch"; - for (size_t i = 0; i < 3; i++) { - if (given_version[i] > target_version[i]) { - return 1; - } else if (given_version[i] < target_version[i]) { - return -1; - } - } - return 0; -} - -int CommonUtils::CompareVersion(const ffi::Array& given_version, - const ffi::Array& target_version) { - std::vector int_given_version; - std::vector int_target_version; - for (const auto& v : given_version) { - int_given_version.push_back(static_cast(v->value)); - } - for (const auto& v : target_version) { - int_target_version.push_back(static_cast(v->value)); - } - return CompareVersion(int_given_version, int_target_version); -} - -const ffi::String CommonUtils::ToAttrKey(const ffi::String& key) { - if (key == "name") { - return msc_attr::kName; - } - if (key == "optype") { - return msc_attr::kOptype; - } - if (key == "op_attrs") { - return msc_attr::kOpattrs; - } - if (key == "layout") { - return msc_attr::kLayout; - } - if (key == "shared_ref") { - return msc_attr::kSharedRef; - } - if (key == "unique") { - return msc_attr::kUnique; - } - if (key == "input_layouts") { - return msc_attr::kInputLayouts; - } - if (key == "consumer_type") { - return msc_attr::kConsumerType; - } - LOG_FATAL << "Unexpected key " << key; - TVM_FFI_UNREACHABLE(); -} - -bool StringUtils::Contains(const ffi::String& src_string, const ffi::String& sub_string) { - if (src_string.size() == 0) { - return false; - } - if (sub_string.size() == 0) { - return false; - } - - const std::string& src_cstring = src_string; - const std::string& sub_cstring = sub_string; - int pos = src_cstring.find(sub_cstring); - return pos >= 0; -} - -bool StringUtils::StartsWith(const ffi::String& src_string, const ffi::String& sub_string) { - if (src_string.size() == 0) { - return false; - } - if (sub_string.size() == 0) { - return false; - } - const std::string& src_cstring = src_string; - const std::string& sub_cstring = sub_string; - int pos = src_cstring.find(sub_cstring); - return pos == 0; -} - -bool StringUtils::EndsWith(const ffi::String& src_string, const ffi::String& sub_string) { - if (src_string.size() == 0) { - return false; - } - if (sub_string.size() == 0) { - return false; - } - const std::string& src_cstring = src_string; - const std::string& sub_cstring = sub_string; - int pos = src_cstring.rfind(sub_cstring); - if (pos < 0) { - return false; - } - return static_cast(pos) == src_cstring.size() - sub_cstring.size(); -} - -const ffi::Array StringUtils::Split(const ffi::String& src_string, - const ffi::String& sep) { - ffi::Array sub_strings; - if (src_string.size() == 0) { - return sub_strings; - } - std::string src_cstring = src_string; - const std::string& csep = sep; - int pos = src_cstring.find(csep); - while (pos >= 0) { - if (pos > 0) { - sub_strings.push_back(src_cstring.substr(0, pos)); - } - src_cstring = src_cstring.substr(pos + csep.size()); - pos = src_cstring.find(csep); - } - if (src_cstring.size() > 0) { - sub_strings.push_back(src_cstring); - } - return sub_strings; -} - -const ffi::String StringUtils::Join(const ffi::Array& sub_strings, - const ffi::String& joint) { - ffi::String join_str = ""; - for (size_t i = 0; i < sub_strings.size(); i++) { - join_str = join_str + sub_strings[i] + (i == sub_strings.size() - 1 ? "" : joint); - } - return join_str; -} - -const ffi::String StringUtils::Join(const std::vector& sub_strings, - const std::string& joint) { - ffi::Array new_strings; - for (const auto& s : sub_strings) { - new_strings.push_back(s); - } - return Join(new_strings, joint); -} - -const ffi::String StringUtils::Replace(const ffi::String& src_string, const ffi::String& old_str, - const ffi::String& new_str) { - ffi::String new_string; - const auto& sub_strings = Split(src_string, old_str); - for (size_t i = 0; i < sub_strings.size(); i++) { - new_string = new_string + sub_strings[i] + (i == sub_strings.size() - 1 ? "" : new_str); - } - return new_string; -} - -const std::tuple StringUtils::SplitOnce(const ffi::String& src_string, - const ffi::String& sep, - bool from_left) { - if (src_string.size() == 0) { - return std::make_tuple(ffi::String(), ffi::String()); - } - std::string src_cstring = src_string; - const std::string& csep = sep; - int pos = from_left ? src_cstring.find(csep) : src_cstring.rfind(csep); - if (pos >= 0) { - return std::make_tuple(src_cstring.substr(0, pos), src_cstring.substr(pos + csep.size())); - } - return std::make_tuple(src_string, ffi::String()); -} - -const ffi::Array StringUtils::GetClosures(const ffi::String& src_string, - const ffi::String& left, - const ffi::String& right) { - ffi::Array tokens; - if (src_string.size() == 0) { - return tokens; - } - ffi::String token = "start"; - ffi::String left_str = src_string; - while (token.size() > 0) { - std::tie(token, left_str) = StringUtils::SplitOnce(left_str, left); - if (left_str.size() > 0) { - std::tie(token, left_str) = StringUtils::SplitOnce(left_str, right); - } else { - token = ""; - } - if (token.size() > 0) { - tokens.push_back(token); - } - } - return tokens; -} - -const ffi::String StringUtils::GetClosureOnce(const ffi::String& src_string, - const ffi::String& left, const ffi::String& right, - bool from_left) { - if (src_string.size() == 0) { - return ""; - } - ffi::String val = std::get<1>(SplitOnce(src_string, left, from_left)); - if (val.size() > 0) { - val = std::get<0>(StringUtils::SplitOnce(val, right, from_left)); - } - return val; -} - -const ffi::String StringUtils::Upper(const ffi::String& src_string) { - std::string str = std::string(src_string); - std::transform(str.begin(), str.end(), str.begin(), ::toupper); - return str; -} - -const ffi::String StringUtils::Lower(const ffi::String& src_string) { - std::string str = std::string(src_string); - std::transform(str.begin(), str.end(), str.begin(), ::tolower); - return str; -} - -const ffi::String StringUtils::ToString(const ffi::Any& obj) { - ffi::String obj_string; - if (obj == nullptr) { - obj_string = ""; - } else if (auto opt_str = obj.as()) { - obj_string = *opt_str; - } else if (const auto* n = obj.as()) { - obj_string = std::to_string(n->value); - } else if (obj.type_index() == kTVMFFIInt) { - obj_string = std::to_string(obj.cast()); - } else if (const auto* n = obj.as()) { - obj_string = std::to_string(n->value); - } else if (const auto* n = obj.as()) { - for (size_t i = 0; i < n->size(); i++) { - obj_string = obj_string + ToString((*n)[i]); - if (n->size() == 1 || i < n->size() - 1) { - obj_string = obj_string + ","; - } - } - } else if (const auto* n = obj.as()) { - obj_string = ToString(n->value); - } else if (const auto* n = obj.as()) { - obj_string = ToString(n->fields); - } else { - std::ostringstream obj_des; - obj_des << obj; - obj_string = obj_des.str(); - } - return obj_string; -} - -bool ArrayUtils::CompareArrays(const ffi::Array& left, - const ffi::Array& right, int size) { - if (left.size() == right.size() && left.size() == 0) { - return true; - } - if (size == -1 && left.size() != right.size()) { - return false; - } - if (left.size() == 0 || right.size() == 0) { - return false; - } - size = left.size(); - TVM_FFI_ICHECK_GT(size, 0) << "Positive size should be given, get " << size; - if (size > static_cast(left.size()) || size > static_cast(right.size())) { - return false; - } - for (size_t i = 0; i < static_cast(size); i++) { - if (left[i] != right[i]) { - return false; - } - } - return true; -} - -PrimExpr ArrayUtils::Accumulate(const ffi::Array& array, int pos) { - size_t t_pos = pos < 0 ? array.size() + pos + 1 : pos; - PrimExpr accumulate = Integer(1); - for (size_t i = 0; i < t_pos; i++) { - accumulate = accumulate * array[i]; - } - return accumulate; -} - -bool ArrayUtils::Broadcastable(const ffi::Array& lhs, const ffi::Array& rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - for (size_t i = 0; i < lhs.size(); i++) { - const auto& lp = lhs[i]; - const auto& rp = rhs[i]; - if (lp->IsInstance() && rp->IsInstance()) { - continue; - } - if (lp->IsInstance() && rp->IsInstance() && - Downcast(lp)->value == Downcast(rp)->value) { - continue; - } - if (lp->IsInstance() && Downcast(lp)->value == 1) { - continue; - } - return false; - } - return true; -} - -const Span SpanUtils::SetAttr(const Span& span, const ffi::String& key, const ffi::String& value) { - if (value.size() == 0) { - return span; - } - ffi::String new_source; - ffi::Array tokens{"<" + key + ">", ""}; - if (span.defined() && span->source_name.defined()) { - const ffi::String& source_str = span->source_name->name; - ffi::String left = std::get<0>(StringUtils::SplitOnce(source_str, tokens[0])); - ffi::String right = std::get<1>(StringUtils::SplitOnce(source_str, tokens[1])); - if (StringUtils::Contains(source_str, tokens[0]) && - StringUtils::Contains(source_str, tokens[1])) { - new_source = left + tokens[0] + value + tokens[1] + right; - } else { - new_source = source_str + tokens[0] + value + tokens[1]; - } - } else { - new_source = tokens[0] + value + tokens[1]; - } - if (span.defined()) { - return Span(SourceName::Get(new_source), span->line, span->end_line, span->column, - span->end_column); - } - return Span(SourceName::Get(new_source), 0, 0, 0, 0); -} - -ffi::String SpanUtils::GetAttr(const Span& span, const ffi::String& key) { - if (span.defined() && span->source_name.defined()) { - ffi::Array tokens{"<" + key + ">", ""}; - return StringUtils::GetClosureOnce(span->source_name->name, tokens[0], tokens[1]); - } - return ""; -} - -const ffi::Map SpanUtils::GetAttrs(const Span& span) { - ffi::Map attrs; - for (const auto& key : StringUtils::GetClosures(span->source_name->name, "")) { - attrs.Set(key, GetAttr(span, key)); - } - return attrs; -} - -const Span SpanUtils::CreateWithAttr(const ffi::String& key, const ffi::String& value) { - return SetAttr(Span(), key, value); -} - -const ffi::Array ExprUtils::GetInputTypes(const ffi::String& optype, size_t inputs_num, - bool as_relax) { - ffi::Array input_types; - if (as_relax && (optype == "broadcast_to" || optype == "reshape")) { - input_types.push_back("input"); - if (inputs_num > 1) { - input_types.push_back("shape"); - } - } else if (optype == "clip" && as_relax) { - input_types.push_back("input"); - if (inputs_num > 1) { - input_types.push_back("min"); - input_types.push_back("max"); - } - } else if (optype == "full" && as_relax) { - input_types.push_back("shape"); - input_types.push_back("input"); - } else if (optype == "strided_slice") { - input_types.push_back("input"); - if (inputs_num > 1) { - input_types.push_back("axes"); - input_types.push_back("begin"); - input_types.push_back("end"); - input_types.push_back("strides"); - } - } else if (optype == "triu") { - input_types.push_back("input"); - input_types.push_back("k"); - } else if (optype == "tril" || optype == "trilu") { - input_types.push_back("input"); - input_types.push_back("k"); - } else if (optype == "image.resize2d" && as_relax) { - input_types.push_back("input"); - if (inputs_num > 1) { - input_types.push_back("size"); - } - } else if (optype == "nn.conv1d" || optype == "nn.conv2d" || optype == "nn.conv3d") { - input_types.push_back("input"); - if (inputs_num > 1) { - input_types.push_back("weight"); - } - } else if (optype == "nn.batch_norm") { - input_types.push_back("input"); - if (inputs_num > 1) { - input_types.push_back("gamma"); - input_types.push_back("beta"); - input_types.push_back("mean"); - input_types.push_back("var"); - } - } else if (optype == "nn.layer_norm" || optype == "nn.group_norm") { - input_types.push_back("input"); - if (inputs_num > 1) { - input_types.push_back("gamma"); - input_types.push_back("beta"); - } - } else if (optype == "msc.linear") { - input_types.push_back("input"); - if (inputs_num > 1) { - input_types.push_back("weight"); - } - } else if (optype == "msc.conv1d_bias" || optype == "msc.conv2d_bias") { - input_types.push_back("input"); - if (inputs_num > 1) { - input_types.push_back("weight"); - input_types.push_back("bias"); - } - if (as_relax && inputs_num > 3) { - input_types.push_back("expand_bias"); - } - } else if (optype == "msc.linear_bias") { - input_types.push_back("input"); - if (inputs_num > 1) { - input_types.push_back("weight"); - input_types.push_back("bias"); - } - } else if (optype == "msc.embedding" && inputs_num == 2) { - input_types.push_back("input"); - input_types.push_back("weight"); - } else if (optype == "msc.embedding" && inputs_num == 4) { - input_types.push_back("input"); - input_types.push_back("reduce_in"); - input_types.push_back("weight"); - input_types.push_back("expand_out"); - } else if (optype == "msc.gelu") { - input_types.push_back("input"); - input_types.push_back("factor_1"); - input_types.push_back("factor_2"); - input_types.push_back("factor_3"); - } else { - for (size_t i = 0; i < inputs_num; i++) { - input_types.push_back("input"); - } - } - TVM_FFI_ICHECK_EQ(input_types.size(), inputs_num) - << "Optype " << optype << " get input types " << input_types << " and inputs_num " - << inputs_num << " mismatch"; - return input_types; -} - -const ffi::Array ExprUtils::GetInputTypes(const Call& call) { - const ffi::String& optype = StringUtils::Replace(Downcast(call->op)->name, "relax.", ""); - return GetInputTypes(optype, call->args.size(), true); -} - -const ffi::String ExprUtils::GetSpanName(const Expr& expr, const ffi::String& suffix) { - const auto& name = SpanUtils::GetAttr(expr->span, msc_attr::kName); - if (suffix.size() > 0) { - return name + "_" + suffix; - } - return name; -} - -const ffi::Array ExprUtils::GetShape(const TensorStructInfo& sinfo, bool as_int) { - const auto& shape_opt = sinfo->GetShape(); - if (!shape_opt.defined()) { - return ffi::Array(); - } - if (as_int) { - ffi::Array shape; - for (const auto& s : shape_opt.value()) { - shape.push_back(s->IsInstance() ? s : Integer(-1)); - } - return shape; - } - return shape_opt.value(); -} - -const ffi::Array ExprUtils::GetShape(const Expr& expr, bool as_int) { - return GetShape(Downcast(GetStructInfo(expr)), as_int); -} - -const DataType ExprUtils::GetDataType(const Expr& expr) { - return Downcast(GetStructInfo(expr))->dtype; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("msc.core.SpanGetAttr", SpanUtils::GetAttr) - .def("msc.core.SpanGetAttrs", SpanUtils::GetAttrs) - .def("msc.core.SpanCreateWithAttr", - [](const ffi::String& key, const ffi::String& value) -> Span { - return SpanUtils::CreateWithAttr(key, value); - }) - .def("msc.core.SpanSetAttr", - [](const Span& span, const ffi::String& key, const ffi::String& value) -> Span { - return SpanUtils::SetAttr(span, key, value); - }) - .def("msc.core.CompareVersion", - [](const ffi::Array& given_version, - const ffi::Array& target_version) -> Integer { - return Integer(CommonUtils::CompareVersion(given_version, target_version)); - }) - .def("msc.core.ToAttrKey", - [](const ffi::String& key) -> ffi::String { return CommonUtils::ToAttrKey(key); }); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h deleted file mode 100644 index de6294bb45be..000000000000 --- a/src/contrib/msc/core/utils.h +++ /dev/null @@ -1,400 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/core/utils.h - * \brief Common utilities for msc. - */ -#ifndef TVM_CONTRIB_MSC_CORE_UTILS_H_ -#define TVM_CONTRIB_MSC_CORE_UTILS_H_ - -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::relax; -using Expr = tvm::RelaxExpr; - -namespace msc_attr { -/*! \brief Mark the name for the expr. */ -constexpr const char* kName = "Name"; -/*! \brief Mark the optype for the expr. */ -constexpr const char* kOptype = "Optype"; -/*! \brief Mark the optype for the expr. */ -constexpr const char* kOpattrs = "Opattrs"; -/*! \brief Mark the layout for the expr. */ -constexpr const char* kLayout = "Layout"; -/*! \brief Mark the share reference for the expr. */ -constexpr const char* kSharedRef = "SharedRef"; - -/*! \brief Mark the unique name for the func. */ -constexpr const char* kUnique = "Unique"; -/*! \brief Mark the input layout for the func. */ -constexpr const char* kInputLayouts = "InputLayouts"; -/*! \brief Mark the consumer type for the func. */ -constexpr const char* kConsumerType = "ConsumerType"; -} // namespace msc_attr - -/*! - * \brief Utils for Common. - */ -class CommonUtils { - public: - /*! - * \brief Check if the index is in range. - * \return The valid index. - */ - TVM_DLL static size_t GetIndex(int index, size_t max_size); - - /*! - * \brief Check if the index is in range. - * \return The valid indices. - */ - TVM_DLL static std::vector GetIndices(const std::vector& indices, size_t max_size); - - /*! - * \brief Compare version with version in config - * 0 for same version, 1 for greater version, -1 for less version - */ - TVM_DLL static int CompareVersion(const std::vector& given_version, - const std::vector& target_version); - TVM_DLL static int CompareVersion(const ffi::Array& given_version, - const ffi::Array& target_version); - /*! - * \brief Get attr key. - * \return The attr key. - */ - TVM_DLL static const ffi::String ToAttrKey(const ffi::String& key); -}; - -/*! - * \brief Utils for String. - */ -class StringUtils { - public: - /*! - * \brief Check if the ffi::String contains a substring. - * \return Whether substring is contained. - */ - TVM_DLL static bool Contains(const ffi::String& src_string, const ffi::String& sub_string); - - /*! - * \brief Check if the ffi::String starts with a substring. - * \return Whether string starts with substring. - */ - TVM_DLL static bool StartsWith(const ffi::String& src_string, const ffi::String& sub_string); - - /*! - * \brief Check if the ffi::String ens with a substring. - * \return Whether string endswith substring. - */ - TVM_DLL static bool EndsWith(const ffi::String& src_string, const ffi::String& sub_string); - - /*! - * \brief Split the ffi::String into sub Strings. - * \return The SubStrings. - */ - TVM_DLL static const ffi::Array Split(const ffi::String& src_string, - const ffi::String& sep); - - /*! - * \brief Join the SubStrings into String. - * \return The String. - */ - TVM_DLL static const ffi::String Join(const ffi::Array& sub_strings, - const ffi::String& joint); - TVM_DLL static const ffi::String Join(const std::vector& sub_strings, - const std::string& joint); - - /*! - * \brief Replace the substring old to new in String. - * \return The replaced String. - */ - TVM_DLL static const ffi::String Replace(const ffi::String& src_string, - const ffi::String& old_str, const ffi::String& new_str); - - /*! - * \brief Split the ffi::String into two sub Strings, only split by the frist seq. - * \return The SubStrings. - */ - TVM_DLL static const std::tuple SplitOnce(const ffi::String& src_string, - const ffi::String& sep, - bool from_left = false); - - /*! - * \brief Get the tokens between left and right. - * \return The Tokens. - */ - TVM_DLL static const ffi::Array GetClosures(const ffi::String& src_string, - const ffi::String& left, - const ffi::String& right); - - /*! - * \brief Get the first token between left and right. - * \return The Token. - */ - TVM_DLL static const ffi::String GetClosureOnce(const ffi::String& src_string, - const ffi::String& left, const ffi::String& right, - bool from_left = true); - - /*! - * \brief Change string to upper. - * \return The String. - */ - TVM_DLL static const ffi::String Upper(const ffi::String& src_string); - - /*! - * \brief Change string to lower. - * \return The String. - */ - TVM_DLL static const ffi::String Lower(const ffi::String& src_string); - - /*! - * \brief Change Object to String. - * \return The String. - */ - TVM_DLL static const ffi::String ToString(const ffi::Any& obj); -}; - -/*! - * \brief Utils for Array. - */ -class ArrayUtils { - public: - /*! - * \brief Replace the element old to new in Array. - * \return The replaced Array. - */ - template - TVM_DLL static const ffi::Array Replace(const ffi::Array& src_array, const T& old_ele, - const T& new_ele) { - ffi::Array new_array; - for (const auto& a : src_array) { - if (a == old_ele) { - new_array.push_back(new_ele); - } else { - new_array.push_back(a); - } - } - return new_array; - } - - /*! - * \brief Find the index of element. - * \return The index, -1 if not found. - */ - template - TVM_DLL static int IndexOf(const std::vector& array, const T& ele) { - for (size_t i = 0; i < array.size(); i++) { - if (array[i] == ele) { - return i; - } - } - return -1; - } - - /*! - * \brief Downcast elements in the array. - * \return The downcasted array - */ - template - TVM_DLL static const ffi::Array Cast(const ffi::Array& src_array) { - ffi::Array new_array; - for (const auto& s : src_array) { - new_array.push_back(Downcast(s)); - } - return new_array; - } - - /*! - * \brief Product elements in the arrays. - * \return The producted array - */ - template - TVM_DLL static const ffi::Array> Product(const ffi::Array>& arrays) { - ffi::Array> p_arrays; - if (arrays.size() == 1) { - for (const auto& a : arrays[0]) { - p_arrays.push_back(ffi::Array{a}); - } - return p_arrays; - } - ffi::Array> sub_arrays; - for (size_t i = 0; i < arrays.size() - 1; i++) { - sub_arrays.push_back(arrays[i]); - } - for (const auto& p_array : Product(sub_arrays)) { - for (const auto& a : arrays[arrays.size() - 1]) { - ffi::Array sub_array = p_array; - sub_array.push_back(a); - p_arrays.push_back(sub_array); - } - } - return p_arrays; - } - - /*! - * \brief Compare ffi::String arrays. - * \return Whether two array are same. - */ - TVM_DLL static bool CompareArrays(const ffi::Array& left, - const ffi::Array& right, int size = -1); - /*! - * \brief Accumulate array. - * \return The accumulate result - */ - TVM_DLL static PrimExpr Accumulate(const ffi::Array& array, int pos = -1); - - /*! - * \brief Check if lhs array is broadcastable to rhs. - * \return broadcastable - */ - TVM_DLL static bool Broadcastable(const ffi::Array& lhs, - const ffi::Array& rhs); -}; - -/*! - * \brief Utils for Span. - */ -class SpanUtils { - public: - /*! - * \brief Set value to the Span. - * \return The new Span. - */ - TVM_DLL static const Span SetAttr(const Span& span, const ffi::String& key, - const ffi::String& value); - - /*! - * \brief Get the value in value from the Span. - * \return The value String. - */ - TVM_DLL static ffi::String GetAttr(const Span& span, const ffi::String& key); - - /*! - * \brief Get all the key:value in format value from the Span. - * \return The Attrs Map. - */ - TVM_DLL static const ffi::Map GetAttrs(const Span& span); - - /*! - * \brief Create a span with value. - * \return The created Span. - */ - TVM_DLL static const Span CreateWithAttr(const ffi::String& key, const ffi::String& value); -}; - -/*! - * \brief Utils for Expr. - */ -class ExprUtils { - public: - /*! - * \brief Get the input types of call. - * \return The input types. - */ - TVM_DLL static const ffi::Array GetInputTypes(const ffi::String& optype, - size_t inputs_num, bool as_relax); - - /*! - * \brief Get the input types of call. - * \return The input types. - */ - TVM_DLL static const ffi::Array GetInputTypes(const Call& call); - - /*! - * \brief Get the scalar value of ndarray. - * \return The scalar value. - */ - template - TVM_DLL static const T GetScalar(const runtime::Tensor& array, size_t i = 0) { - if (array->dtype.code == kDLInt) { - if (array->dtype.bits == 8) { - return T(reinterpret_cast(array->data)[i]); - } else if (array->dtype.bits == 16) { - return T(reinterpret_cast(array->data)[i]); - } else if (array->dtype.bits == 32) { - return T(reinterpret_cast(array->data)[i]); - } else if (array->dtype.bits == 64) { - return T(reinterpret_cast(array->data)[i]); - } - } else if (array->dtype.code == kDLUInt) { - if (array->dtype.bits == 1) { // bool - return T(reinterpret_cast(array->data)[i]); - } else if (array->dtype.bits == 8) { - return T(reinterpret_cast(array->data)[i]); - } else if (array->dtype.bits == 16) { - return T(reinterpret_cast(array->data)[i]); - } else if (array->dtype.bits == 32) { - return T(reinterpret_cast(array->data)[i]); - } else if (array->dtype.bits == 64) { - return T(reinterpret_cast(array->data)[i]); - } - } else if (array->dtype.code == kDLFloat) { - if (array->dtype.bits == 32) { - return T(reinterpret_cast(array->data)[i]); - } else if (array->dtype.bits == 64) { - return T(reinterpret_cast(array->data)[i]); - } - } - TVM_FFI_THROW(InternalError) << "Failed to get scalar from array " << array; - } - - /*! - * \brief Get the scalar value of relax constant. - * \return The scalar value. - */ - template - TVM_DLL static const T GetScalar(const Constant& constant, size_t i = 0) { - return GetScalar(constant->data, i); - } - - /*! - * \brief Get name in span. - * \return The name. - */ - TVM_DLL static const ffi::String GetSpanName(const Expr& expr, const ffi::String& suffix = ""); - - /*! - * \brief Get shape of expr. - * \return The shape. - */ - TVM_DLL static const ffi::Array GetShape(const TensorStructInfo& sinfo, - bool as_int = true); - TVM_DLL static const ffi::Array GetShape(const Expr& expr, bool as_int = true); - - /*! - * \brief Get dtype of expr. - * \return The shape. - */ - TVM_DLL static const DataType GetDataType(const Expr& expr); -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_CORE_UTILS_H_ diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc deleted file mode 100644 index 45b7bdbc341b..000000000000 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorflow/codegen.cc - */ -#include "codegen.h" - -#include - -namespace tvm { -namespace contrib { -namespace msc { - -void TensorflowCodeGen::CodeGenHeader() { - PyCodeGen::CodeGenHeader(); - stack_.line("from tensorflow.python import ops") - .line("from tvm.contrib.msc.framework.tensorflow import tf_v1"); -} - -void TensorflowCodeGen::CodeGenHelper() { - PyCodeGen::CodeGenHelper(); - stack_.func_def("get_variable", TensorType()) - .func_arg("name", "str") - .func_arg("shape", "List[int]") - .func_arg("dtype", "str") - .func_arg("weights", "Dict[str, tvm.runtime.Tensor]") - .func_start() - .cond_if("name in weights") - .func_call("tf_v1.get_variable", "var") - .call_arg("name") - .inplace_start("asnumpy", DocUtils::ToDoc("initializer"), - DocUtils::ToIndex("weights", "name")) - .inplace_end() - .cond_else() - .func_call("tf_v1.get_variable", "var") - .call_arg("name") - .call_arg("shape") - .call_arg("dtype") - .cond_end() - .func_end("var"); -} - -void TensorflowCodeGen::CodeGenGraph() { - stack_.func_def(graph()->name, "List[tf_v1.Tensor]"); - for (const auto& i : graph()->GetInputs()) { - const auto& pair = graph()->FindProducerAndIdx(i); - stack_.func_arg(IdxOutputBase(pair.first, pair.second), "tf_v1.Tensor"); - } - stack_.func_arg("weights", "Dict[str, tvm.runtime.Tensor]").func_start(); - // define weights - stack_.comment("Define the weights"); - for (const auto& n : graph()->node_names) { - const auto& node = graph()->FindNode(n); - if (node->optype == "nn.batch_norm") { - continue; - } - for (const auto& pair : node->weights) { - stack_.func_call("get_variable", IdxWeightBase(node, pair.first, false)) - .call_arg(DocUtils::ToStr(pair.second->name)) - .call_arg(DocUtils::ToList(pair.second->shape, true)) - .call_arg(DocUtils::ToStr(pair.second->DTypeName())) - .call_arg("weights"); - } - } - // define ops - stack_.comment("Define the ops"); - for (const auto& n : graph()->node_names) { - const auto& node = graph()->FindNode(n); - if (node->optype == "input") { - continue; - } - CodeGenNode(node, config()->use_tools); - } - ffi::Array idx_outputs; - for (const auto& o : graph()->GetOutputs()) { - const auto& pair = graph()->FindProducerAndIdx(o); - idx_outputs.push_back(IdxOutputBase(pair.first, pair.second)); - } - if (idx_outputs.size() == 1) { - stack_.assign("outputs", idx_outputs[0]); - } else { - stack_.assign("outputs", DocUtils::ToList(idx_outputs)); - } - stack_.func_end("outputs"); -} - -void TensorflowCodeGen::CodeGenInference() { - stack_.comment("Load weights") - .scope_start("open(\"" + graph()->name + "_params.bin\", \"rb\")", "f") - .func_call("tvm.runtime.load_param_dict", "params") - .func_call("read", "", "f") - .pop_nest() - .scope_end() - .comment("Build Graph") - .scope_start("tf_v1.Graph().as_default()"); - for (const auto& i : graph()->GetInputs()) { - const auto& producer = graph()->FindProducer(i); - stack_.func_call("tf_v1.placeholder", IdxNodeBase(producer)) - .call_arg(DocUtils::ToStr(i->DTypeName())) - .call_arg(DocUtils::ToList(i->shape)) - .call_arg(DocUtils::ToStr(i->alias)); - } - stack_.func_call(graph()->name, "outs"); - for (const auto& i : graph()->GetInputs()) { - const auto& producer = graph()->FindProducer(i); - stack_.call_arg(IdxNodeBase(producer)); - } - stack_.call_arg("params").assign("feed_dict", "{}"); - for (const auto& i : graph()->GetInputs()) { - const auto& producer = graph()->FindProducer(i); - stack_.assign(DocUtils::ToIndex("feed_dict", IdxNodeBase(producer)), - DocUtils::ToIndex("inputs", DocUtils::ToStr(i->alias))); - } - stack_.scope_start("tf_v1.Session()", "sess") - .func_call("run", "", "sess") - .func_call("ops.variables.global_variables_initializer") - .pop_nest() - .func_call("run", "outputs", "sess") - .call_arg("outs") - .call_arg("feed_dict", "feed_dict") - .scope_end() - .scope_end(); -} - -const ffi::Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { - const auto& ops_map = GetTFV1OpCodes(); - auto it = ops_map->find(node->optype); - TVM_FFI_ICHECK(it != ops_map->end()) - << "Unsupported tensorflow op(" << node->optype << "): " << node; - it->second->Config(node, config(), prims()); - try { - return it->second->GetDocs(); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what(); - throw err; - } -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("msc.framework.tensorflow.GetTensorflowSources", - [](const MSCGraph& graph, const ffi::String& codegen_config, - const ffi::String& print_config) -> ffi::Map { - TensorflowCodeGen codegen = TensorflowCodeGen(graph, codegen_config); - codegen.Init(); - return codegen.GetSources(print_config); - }); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/framework/tensorflow/codegen.h b/src/contrib/msc/framework/tensorflow/codegen.h deleted file mode 100644 index 5052c11004d2..000000000000 --- a/src/contrib/msc/framework/tensorflow/codegen.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorflow/codegen.h - * \brief Tensorflow codegen for MSCGraph. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_CODEGEN_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_CODEGEN_H_ - -#include - -#include "../../core/codegen/base_codegen.h" -#include "../../core/codegen/py_codegen.h" -#include "codegen_utils.h" -#include "tf_v1_opcode.h" - -namespace tvm { -namespace contrib { -namespace msc { - -class TensorflowCodeGen : public PyCodeGen { - public: - /*! - * \brief The constructor of TensorflowCodeGen - * \param graph the graph to be generated. - * \param config the options for codegen. - */ - explicit TensorflowCodeGen(const MSCGraph& graph, const std::string& config = "") - : PyCodeGen(graph, config) {} - - protected: - /*! \brief Stack the docs for the header*/ - void CodeGenHeader() final; - - /*! \brief Stack the docs for the helpers*/ - void CodeGenHelper() final; - - /*! \brief Stack the docs for the graph*/ - void CodeGenGraph() final; - - /*! \brief Stack the docs for the graph inference*/ - void CodeGenInference() final; - - /*! \brief Get the docs for the op*/ - const ffi::Array GetOpCodes(const MSCJoint& node) final; - - /*! \brief Get tensor type of the framework*/ - const ffi::String TensorType() const final { return "tf_v1.Tensor"; } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_CODEGEN_H_ diff --git a/src/contrib/msc/framework/tensorflow/codegen_utils.h b/src/contrib/msc/framework/tensorflow/codegen_utils.h deleted file mode 100644 index bdce8dc0e363..000000000000 --- a/src/contrib/msc/framework/tensorflow/codegen_utils.h +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorflow/codegen_utils.h - * \brief Utils for tensorflow codegen. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_CODEGEN_UTILS_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_CODEGEN_UTILS_H_ - -#include - -#include "../../core/codegen/base_codegen.h" -#include "../../core/codegen/codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief CodeGen helper for tensorrt codegen - */ -class TFV1CodeGenHelper : public BaseCodeGenHelper {}; - -/*! - * \brief CodeGen config for tensorflow codegen - */ -struct TensorflowCodeGenConfig { - CODEGEN_CONFIG_MEMBERS - void Load(ffi::json::Object obj) { CODEGEN_CONFIG_PARSE } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_CODEGEN_UTILS_H_ diff --git a/src/contrib/msc/framework/tensorflow/config.h b/src/contrib/msc/framework/tensorflow/config.h deleted file mode 100644 index 0a1497a1d1d9..000000000000 --- a/src/contrib/msc/framework/tensorflow/config.h +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorflow/config.h - * \brief Tensorflow config for codegen. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_CONFIG_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_CONFIG_H_ - -#include - -#include "../../core/codegen/base_codegen.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief CodeGen config for tensorflow codegen - */ -struct TensorflowCodeGenConfig { - bool is_training{false}; - CODEGEN_CONFIG_MEMBERS - void Load(ffi::json::Object obj) { - if (auto it = obj.find(ffi::String("is_training")); it != obj.end()) { - is_training = (*it).second.cast(); - } - CODEGEN_CONFIG_PARSE - } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_CONFIG_H_ diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc deleted file mode 100644 index 5a603454ae1a..000000000000 --- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc +++ /dev/null @@ -1,624 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc - */ -#include "tf_v1_opcode.h" - -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -const ffi::Array TFV1OpCode::GetDocs() { - stack_.Config(this); - CodeGenBuild(); - return stack_.GetDocs(); -} - -const std::pair> TFV1OpCode::GetPadding( - const ffi::String& strides_key, const ffi::String& kernel_key, const ffi::String& padding_key) { - ffi::String pad_mod = ""; - ffi::Array padding; - std::vector kernel_size; - if (node()->optype == "nn.conv2d" || node()->optype == "msc.conv2d_bias") { - const auto& weight = node()->WeightAt("weight"); - kernel_size.push_back(weight->DimAt("H")->value); - kernel_size.push_back(weight->DimAt("W")->value); - } else if (node()->optype == "nn.avg_pool2d" || node()->optype == "nn.max_pool2d") { - TVM_FFI_ICHECK(node()->GetAttr(kernel_key, &kernel_size)); - } else { - LOG_FATAL << "Unexpected padding node" << node(); - } - const auto& strides = node()->GetTypeArrayAttr(strides_key); - int64_t in_height = node()->InputAt(0)->DimAt("H")->value; - int64_t in_width = node()->InputAt(0)->DimAt("W")->value; - int64_t out_height = node()->OutputAt(0)->DimAt("H")->value; - int64_t out_width = node()->OutputAt(0)->DimAt("W")->value; - int64_t same_height = in_height / strides[0] + (in_height % strides[0] == 0 ? 0 : 1); - int64_t same_width = in_width / strides[1] + (in_width % strides[1] == 0 ? 0 : 1); - int64_t valid_height = (in_height - kernel_size[0] + 1) / strides[0]; - valid_height += (valid_height % strides[0] == 0 ? 0 : 1); - int64_t valid_width = (in_width - kernel_size[1] + 1) / strides[1]; - valid_width += (valid_width % strides[1] == 0 ? 0 : 1); - if (same_height == out_height && same_width == out_width) { - pad_mod = "SAME"; - } else if (valid_height == out_height && valid_width == out_width) { - pad_mod = "VALID"; - } else { - const auto& src_padding = node()->GetTypeArrayAttr(padding_key); - if (node()->optype == "nn.conv2d" || node()->optype == "msc.conv2d_bias" || - node()->optype == "nn.avg_pool2d" || node()->optype == "nn.max_pool2d") { - const auto& out_layout = node()->GetTypeAttr("out_layout"); - if (out_layout == "NHWC") { - padding.push_back("[0, 0]"); - } else if (out_layout == "NCHW") { - padding.push_back("[0, 0]"); - padding.push_back("[0, 0]"); - } else { - LOG_FATAL << "Unexpected layout for padding node" << node(); - } - if (src_padding.size() == 4) { - padding.push_back("[" + std::to_string(src_padding[0]) + ", " + - std::to_string(src_padding[2]) + "]"); - padding.push_back("[" + std::to_string(src_padding[1]) + ", " + - std::to_string(src_padding[3]) + "]"); - } else { - LOG_FATAL << "nn.conv2d/pool2d with unexpected padding " << node(); - } - if (out_layout == "NHWC") { - padding.push_back("[0, 0]"); - } - } else { - LOG_FATAL << "Unexpected padding node" << node(); - } - } - return std::make_pair(pad_mod, padding); -} - -#define TFV1_OP_CODEGEN_METHODS(TypeName) \ - public: \ - TypeName(const ffi::String& func_name) : TFV1OpCode(func_name) {} - -class TFV1ArgMaxMinCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1ArgMaxMinCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .op_arg("axis") - .op_dtype_arg(node()->OutputAt(0)->dtype, "output_type") - .op_name_arg(); - } -}; - -class TFV1AstypeCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1AstypeCodeGen) - - protected: - void CodeGenBuild() final { - if (node()->InputAt(0)->dtype == node()->OutputAt(0)->dtype) { - stack_.op_call("tf_v1.identity").op_input_arg().op_name_arg(); - } else { - stack_.op_call().op_input_arg().op_dtype_arg(node()->OutputAt(0)->dtype).op_name_arg(); - } - } -}; - -class TFV1AxesCodeGen : public TFV1OpCode { - public: - TFV1AxesCodeGen(const ffi::String& func_name, const ffi::String& attr_name) - : TFV1OpCode(func_name) { - attr_name_ = attr_name; - } - - protected: - void CodeGenBuild() final { - const ffi::String& key = node()->HasAttr("axes") ? "axes" : "axis"; - stack_.op_call().op_input_arg().op_list_arg(key, attr_name_).op_name_arg(); - } - - private: - ffi::String attr_name_; -}; - -class TFV1AxisCodeGen : public TFV1OpCode { - public: - TFV1AxisCodeGen(const ffi::String& func_name, const ffi::String& attr_name) - : TFV1OpCode(func_name) { - attr_name_ = attr_name; - } - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().op_arg("axis", attr_name_).op_name_arg(); - } - - private: - ffi::String attr_name_; -}; - -class TFV1BatchnormCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1BatchnormCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .op_arg("scale") - .op_arg("center") - .op_arg("momentum") - .op_arg("epsilon"); - ffi::Array weight_names{"gamma", "beta", "mean", "var"}; - ffi::Array init_names{"gamma", "beta", "moving_mean", "moving_variance"}; - for (size_t i = 0; i < weight_names.size(); i++) { - const auto& w_doc = DocUtils::ToStr(node()->WeightAt(weight_names[i])->name); - stack_.inplace_start("tf_v1.constant_initializer", init_names[i] + "_initializer") - .inplace_start("asnumpy", std::nullopt, DocUtils::ToIndex("weights", w_doc)) - .inplace_end() - .inplace_end(); - } - stack_.op_name_arg(); - } -}; - -class TFV1BroadcastToCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1BroadcastToCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().op_list_arg("shape").op_name_arg(); - } -}; - -class TFV1ClipCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1ClipCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .op_arg("min", "clip_value_min") - .op_arg("max", "clip_value_max") - .op_name_arg(); - } -}; - -class TFV1ConcatCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1ConcatCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_inputs_arg().op_arg("axis").op_name_arg(); } -}; - -class TFV1ConstantCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1ConstantCodeGen) - - protected: - void CodeGenBuild() final { stack_.assign(IdxNode(), IdxWeight("const")); } -}; - -class TFV1ConvCodeGen : public TFV1OpCode { - public: - TFV1ConvCodeGen(const ffi::String& func_name, bool use_bias) : TFV1OpCode(func_name) { - use_bias_ = use_bias; - } - - protected: - void CodeGenBuild() final { - const auto& pair = GetPadding("strides"); - const auto& out_layout = node()->GetTypeAttr("out_layout"); - int64_t groups = node()->GetTypeAttr("groups"); - std::vector strides, dilation; - const auto& attr_strides = node()->GetTypeArrayAttr("strides"); - const auto& attr_dilation = node()->GetTypeArrayAttr("dilation"); - if (out_layout == "NHWC") { - strides = {1, attr_strides[0], attr_strides[1], 1}; - dilation = {1, attr_dilation[0], attr_dilation[1], 1}; - } else if (out_layout == "NCHW") { - strides = {1, 1, attr_strides[0], attr_strides[1]}; - dilation = {1, 1, attr_dilation[0], attr_dilation[1]}; - } else { - LOG_FATAL << "Unexpected layout for padding node" << node(); - } - if (groups == 1) { - stack_.op_call(); - } else if (groups == node()->InputAt(0)->DimAt("C")->value) { - stack_.op_call("ops.nn_ops.depthwise_conv2d_native"); - } else { - LOG_FATAL << "Unexpected conv with groups " << node(); - } - stack_.op_input_arg() - .op_weight_arg("weight") - .call_arg(DocUtils::ToList(strides), "strides") - .call_arg(DocUtils::ToList(dilation), "dilations") - .op_str_arg("data_layout", "data_format"); - if (pair.first.size() > 0) { - stack_.call_arg(DocUtils::ToStr(pair.first), "padding"); - } else if (pair.second.size() > 0) { - stack_.call_arg(DocUtils::ToList(pair.second), "padding"); - } else { - LOG_FATAL << "Can not parse padding for " << node(); - } - stack_.op_name_arg(); - if (use_bias_) { - stack_.op_call("ops.nn_ops.bias_add") - .op_output_arg() - .op_weight_arg("bias") - .op_name_arg("name", node()->name + "_bias"); - } - } - - private: - bool use_bias_; -}; - -class TFV1CreateLikeCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1CreateLikeCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_input_arg().op_str_arg("dtype").op_name_arg(); } -}; - -class TFV1EinsumCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1EinsumCodeGen) - - protected: - void CodeGenBuild() final { - const auto& producer = node()->ProducerOf(0); - stack_.op_call().op_str_arg("subscripts", ""); - if (node()->inputs.size() == 1 && producer->optype == "tuple") { - stack_.call_arg(DocUtils::ToIndex(IdxInput(), 0)); - } else { - stack_.op_inputs_arg(false); - } - stack_.op_name_arg(); - } -}; - -class TFV1FullCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1FullCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_list_arg("shape", "").op_input_arg(0, "value").op_name_arg(); - } -}; - -class TFV1GetItemCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1GetItemCodeGen) - - protected: - void CodeGenBuild() final { - stack_.assign(IdxNode(), IdxInput(node()->GetTypeAttr("index"))); - } -}; - -class TFV1PadCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1PadCodeGen) - - protected: - void CodeGenBuild() final { - ffi::String mode; - const auto& attr_mode = node()->GetTypeAttr("pad_mode"); - if (attr_mode == "constant") { - mode = "CONSTANT"; - } else { - LOG_FATAL << "Unexpected pad mode " << node(); - } - ffi::Array pad_width; - const auto& attr_pad_width = node()->GetTypeArrayAttr("pad_width"); - TVM_FFI_ICHECK(attr_pad_width.size() % 2 == 0) - << "pad_width should be multiple of 2, get " << node(); - for (size_t i = 0; i < attr_pad_width.size(); i += 2) { - const ffi::String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + - std::to_string(attr_pad_width[i + 1]) + "]"; - pad_width.push_back(cur_pad); - } - const auto& val_producer = node()->ProducerOf(1); - TVM_FFI_ICHECK(val_producer->optype == "constant" && val_producer->HasAttr("scalar")); - stack_.op_call() - .op_input_arg() - .call_arg(DocUtils::ToList(pad_width), "paddings") - .call_arg(DocUtils::ToStr(mode), "mode") - .call_arg(val_producer->GetTypeAttr("scalar"), "constant_values") - .op_name_arg(); - } -}; - -class TFV1Pool2dCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1Pool2dCodeGen) - - protected: - void CodeGenBuild() final { - ffi::String pooling_type; - if (node()->optype == "nn.avg_pool2d") { - pooling_type = "AVG"; - } else if (node()->optype == "nn.max_pool2d") { - pooling_type = "MAX"; - } else { - LOG_FATAL << "Unexpected pool2d node " << node(); - } - const auto& pair = GetPadding("strides", "pool_size"); - stack_.op_call() - .op_input_arg() - .op_list_arg("pool_size", "window_shape") - .call_arg(DocUtils::ToStr(pooling_type), "pooling_type") - .op_list_arg("dilation", "dilation_rate") - .op_list_arg("strides"); - if (pair.first.size() > 0) { - stack_.call_arg(DocUtils::ToStr(pair.first), "padding"); - } else if (pair.second.size() > 0) { - stack_.call_arg(DocUtils::ToList(pair.second), "padding"); - } else { - LOG_FATAL << "Can not parse padding for " << node(); - } - stack_.op_name_arg(); - } -}; - -class TFV1PermuteDimsCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1PermuteDimsCodeGen) - - protected: - void CodeGenBuild() final { - std::vector axes; - if (!node()->GetAttr("axes", &axes)) { - for (size_t i = node()->InputAt(0)->Ndim(); i > 0; i--) { - axes.push_back(i - 1); - } - } - stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(axes)).op_name_arg(); - } -}; - -class TFV1ReduceAxisCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1ReduceAxisCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().op_list_arg("axis").op_arg("keepdims").op_name_arg(); - } -}; - -class TFV1ReshapeCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1ReshapeCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().op_list_arg("shape").op_name_arg(); - } -}; - -class TFV1Resize2dCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1Resize2dCodeGen) - - protected: - void CodeGenBuild() final { - ffi::String func_name; - const auto& method = node()->GetTypeAttr("method"); - const auto& coordinate_transformation_mode = - node()->GetTypeAttr("coordinate_transformation_mode"); - bool align_corners = coordinate_transformation_mode == "align_corners"; - if (method == "linear") { - func_name = "tf_v1.image.resize_bilinear"; - } else if (method == "nearest_neighbor") { - func_name = "tf_v1.image.resize_nearest_neighbor"; - } else { - LOG_FATAL << "Unexpected resize with method " << node(); - } - stack_.op_call(func_name) - .op_input_arg() - .op_list_arg("size") - .call_arg(align_corners, "align_corners") - .op_name_arg(); - } -}; - -class TFV1SimpleCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1SimpleCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false).op_name_arg(); } -}; - -class TFV1SplitCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1SplitCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg(); - std::vector indices; - int axis = node()->GetTypeAttr("axis"); - for (size_t i = 0; i < node()->outputs.size(); i++) { - indices.push_back(node()->OutputAt(i)->DimAt(axis)->value); - } - stack_.call_arg(DocUtils::ToList(indices), "num_or_size_splits") - .op_arg("axis") - .op_name_arg(); - } -}; - -class TFV1StridedSliceCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1StridedSliceCodeGen) - - protected: - void CodeGenBuild() final { - std::vector axes; - if (!node()->GetAttr("axes", &axes)) { - for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { - axes.push_back(i); - } - } - stack_.op_call() - .op_input_arg() - .op_list_arg("begin") - .op_list_arg("end") - .op_list_arg("strides") - .op_name_arg(); - } -}; - -class TFV1TakeCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1TakeCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_inputs_arg(false).op_arg("axis").op_name_arg(); - } -}; - -class TFV1TileCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1TileCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().op_list_arg("repeats", "multiples").op_name_arg(); - } -}; - -class TFV1TupleCodeGen : public TFV1OpCode { - TFV1_OP_CODEGEN_METHODS(TFV1TupleCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_inputs_arg(); } -}; - -const std::shared_ptr>> -GetTFV1OpCodes() { - static auto map = - std::make_shared>>(); - if (!map->empty()) return map; - // binary && unary ops - map->emplace("abs", std::make_shared("tf_v1.abs")); - map->emplace("acos", std::make_shared("tf_v1.acos")); - map->emplace("acosh", std::make_shared("tf_v1.acosh")); - map->emplace("add", std::make_shared("tf_v1.add")); - map->emplace("asin", std::make_shared("tf_v1.asin")); - map->emplace("asinh", std::make_shared("tf_v1.asinh")); - map->emplace("atanh", std::make_shared("tf_v1.atanh")); - map->emplace("atan", std::make_shared("tf_v1.atan")); - map->emplace("ceil", std::make_shared("tf_v1.ceil")); - map->emplace("cos", std::make_shared("tf_v1.cos")); - map->emplace("cosh", std::make_shared("tf_v1.cosh")); - map->emplace("divide", std::make_shared("tf_v1.divide")); - map->emplace("equal", std::make_shared("tf_v1.equal")); - map->emplace("erf", std::make_shared("tf_v1.erf")); - map->emplace("exp", std::make_shared("tf_v1.exp")); - map->emplace("floor", std::make_shared("tf_v1.floor")); - map->emplace("floor_divide", std::make_shared("tf_v1.floor_div")); - map->emplace("floor_mod", std::make_shared("tf_v1.floormod")); - map->emplace("greater", std::make_shared("tf_v1.greater")); - map->emplace("greater_equal", std::make_shared("tf_v1.greater_equal")); - map->emplace("isfinite", std::make_shared("tf_v1.is_finite")); - map->emplace("isinf", std::make_shared("tf_v1.is_inf")); - map->emplace("isnan", std::make_shared("tf_v1.is_nan")); - map->emplace("less", std::make_shared("tf_v1.less")); - map->emplace("less_equal", std::make_shared("tf_v1.less_equal")); - map->emplace("log", std::make_shared("tf_v1.log")); - map->emplace("log1p", std::make_shared("tf_v1.log1p")); - map->emplace("logical_and", std::make_shared("tf_v1.logical_and")); - map->emplace("logical_or", std::make_shared("tf_v1.logical_or")); - map->emplace("logical_xor", std::make_shared("tf_v1.logical_xor")); - map->emplace("logical_not", std::make_shared("tf_v1.logical_not")); - map->emplace("maximum", std::make_shared("tf_v1.maximum")); - map->emplace("minimum", std::make_shared("tf_v1.minimum")); - map->emplace("multiply", std::make_shared("tf_v1.multiply")); - map->emplace("negative", std::make_shared("tf_v1.negative")); - map->emplace("not_equal", std::make_shared("tf_v1.not_equal")); - map->emplace("power", std::make_shared("tf_v1.pow")); - map->emplace("round", std::make_shared("tf_v1.round")); - map->emplace("rsqrt", std::make_shared("tf_v1.rsqrt")); - map->emplace("sigmoid", std::make_shared("ops.math_ops.sigmoid")); - map->emplace("sign", std::make_shared("tf_v1.sign")); - map->emplace("sin", std::make_shared("tf_v1.sin")); - map->emplace("sinh", std::make_shared("tf_v1.sinh")); - map->emplace("sqrt", std::make_shared("tf_v1.sqrt")); - map->emplace("subtract", std::make_shared("tf_v1.subtract")); - map->emplace("tan", std::make_shared("tf_v1.tan")); - map->emplace("tanh", std::make_shared("tf_v1.tanh")); - map->emplace("where", std::make_shared("tf_v1.where")); - - // reduce axis ops - map->emplace("max", std::make_shared("tf_v1.reduce_max")); - map->emplace("min", std::make_shared("tf_v1.reduce_min")); - map->emplace("mean", std::make_shared("tf_v1.reduce_mean")); - map->emplace("sum", std::make_shared("tf_v1.reduce_sum")); - map->emplace("prod", std::make_shared("tf_v1.reduce_prod")); - map->emplace("std", std::make_shared("tf_v1.reduce_std")); - - // create ops - map->emplace("constant", std::make_shared("get_variable")); - map->emplace("full", std::make_shared("tf_v1.fill")); - map->emplace("zeros_like", std::make_shared("tf_v1.zeros_like")); - - // axis && axes ops - map->emplace("expand_dims", std::make_shared("tf_v1.expand_dims", "axis")); - map->emplace("nn.log_softmax", std::make_shared("tf_v1.nn.log_softmax", "axis")); - map->emplace("nn.softmax", std::make_shared("tf_v1.nn.softmax", "axis")); - map->emplace("squeeze", std::make_shared("ops.array_ops.squeeze", "axis")); - - // math ops - map->emplace("argmax", std::make_shared("tf_v1.argmax")); - map->emplace("argmin", std::make_shared("tf_v1.argmin")); - map->emplace("astype", std::make_shared("tf_v1.cast")); - map->emplace("broadcast_to", std::make_shared("tf_v1.broadcast_to")); - map->emplace("clip", std::make_shared("tf_v1.clip_by_value")); - map->emplace("concat", std::make_shared("ops.array_ops.concat_v2")); - map->emplace("concatenate", std::make_shared("ops.array_ops.concat_v2")); - map->emplace("einsum", std::make_shared("tf_v1.einsum")); - map->emplace("matmul", std::make_shared("tf_v1.matmul")); - map->emplace("permute_dims", std::make_shared("tf_v1.transpose")); - map->emplace("reshape", std::make_shared("ops.array_ops.reshape")); - map->emplace("split", std::make_shared("tf_v1.split")); - map->emplace("strided_slice", std::make_shared("tf_v1.strided_slice")); - map->emplace("take", std::make_shared("tf_v1.gather")); - map->emplace("tile", std::make_shared("tf_v1.tile")); - - // nn ops - map->emplace("nn.avg_pool2d", std::make_shared("ops.nn_ops.pool")); - map->emplace("nn.batch_norm", - std::make_shared("tf_v1.layers.batch_normalization")); - map->emplace("nn.conv2d", std::make_shared("ops.nn_ops.conv2d", false)); - map->emplace("nn.max_pool2d", std::make_shared("ops.nn_ops.pool")); - map->emplace("nn.pad", std::make_shared("tf_v1.pad")); - map->emplace("nn.relu", std::make_shared("tf_v1.nn.relu")); - - // image ops - map->emplace("image.resize2d", - std::make_shared("tf_v1.image.resize_nearest_neighbor")); - - // special op - map->emplace("get_item", std::make_shared("")); - map->emplace("tuple", std::make_shared("tuple")); - - // msc ops - map->emplace("msc.conv2d", std::make_shared("ops.nn_ops.conv2d", false)); - map->emplace("msc.conv2d_bias", std::make_shared("ops.nn_ops.conv2d", true)); - - return map; -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h deleted file mode 100644 index a744ffc701e4..000000000000 --- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.h +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorflow/tf_v1_opcode.h - * \brief Tensorflow codegen for MSCJoint, use v1 format. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_TF_V1_OPCODE_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_TF_V1_OPCODE_H_ - -#include -#include -#include -#include -#include - -#include "../../core/codegen/base_codegen.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -class TFV1OpCode; -typedef OpCodeStack TFV1OpCodeStack; - -/*! - * \brief CodeGen for tensorflow op - */ -class TFV1OpCode : public BaseOpCode { - public: - /*! - * \brief The constructor of BaseOpDocsifier - * \param func_name the function name for the node. - * \param config the config json for the node. - */ - explicit TFV1OpCode(const ffi::String& func_name) - : BaseOpCode(func_name) {} - - /*! \brief Convert node to docs*/ - const ffi::Array GetDocs() final; - - /*! \brief Get dtype string*/ - const ffi::String DType(const DataType& dtype) final { - return "tf_v1." + BaseOpCode::DType(dtype); - } - - protected: - TFV1OpCodeStack stack_; - - /*! \brief Convert op build*/ - virtual void CodeGenBuild() = 0; - - /*! \brief Get padding mode or array*/ - const std::pair> GetPadding( - const ffi::String& strides_key, const ffi::String& kernel_key = "", - const ffi::String& padding_key = "padding"); -}; - -/*! - * \brief Get the map of available TFV1OpCode, use optype as key - * \return Map of - */ -const std::shared_ptr>> -GetTFV1OpCodes(); - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TENSORFLOW_TF_V1_OPCODE_H_ diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc deleted file mode 100644 index 8cf746f825aa..000000000000 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ /dev/null @@ -1,634 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorrt/codegen.cc - * \brief Codegen related classes. - */ - -#include "codegen.h" - -#include -#include -#include - -#include - -#include "../../core/codegen/codegen_json.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::relax; - -void TensorRTCodeGen::CodeGenClassDeclare() { - stack_.line("#include \"NvInfer.h\"") - .line("#include \"NvInferRuntimeCommon.h\"") - .line("#include \"utils/base.h\"") - .line("#include \"utils/trt_common.h\""); - if (config()->precision == "int8") { - stack_.line("#include \"utils/trt_quantize.h\""); - } - // plugin headers - if (config()->use_plugin) { - std::set plugins; - for (const auto& n : graph()->node_names) { - const auto& node = graph()->FindNode(n); - if (IsPlugin(node->optype) && !plugins.count(node->optype)) { - stack_.line("#include \"plugin/" + node->optype + "_op.h\""); - plugins.insert(node->optype); - } - } - } - stack_.line().line("using namespace nvinfer1;").line(); - StartNamespace(); - // start class declare - stack_.class_def(graph()->name).class_start().scope_start("public:"); - // declare build method - stack_.func_def("Build", "bool") - .func_arg("builder", "TRTPtr&") - .func_arg("network", "TRTPtr&"); - if (CompareVersion(6, 0, 0) >= 0) { - stack_.func_arg("config", "TRTPtr&"); - } - stack_.func_arg("logger", "TRTLogger&").func_start().func_end(); - // define cleanup method - stack_.func_def("CleanUp", "bool") - .func_start() - .for_start("mem", "mWeights") - .func_call("free") - .call_arg("(void*) (mem.second.values)") - .for_end() - .func_end("true"); - // end public scope - stack_.scope_end(); - // private scope - stack_.scope_start("private:").declare("std::map", "mWeights").scope_end(); - // end class declare - stack_.class_end(); - // declare test function - stack_.func_def("test_" + graph()->name, "bool") - .func_arg("engine", "std::shared_ptr&") - .func_arg("reader", "DatasetReader&") - .func_arg("logger", "TRTLogger&") - .func_start() - .func_end(); - EndNamespace(); -} - -void TensorRTCodeGen::CodeGenClassDefine() { - auto malloc_buffer = [this](const MSCTensor& tensor) { - const ffi::String& idx_var = "idx_" + IdxTensor(tensor); - this->stack_ - .func_call("getBindingIndex", DocUtils::ToDeclare("int", idx_var), - DocUtils::ToPtr("engine")) - .call_arg(DocUtils::ToStr(tensor->name)) - .func_call("CHECK") - .func_call("cudaMalloc") - .call_arg(DocUtils::ToIndex("&gpu_buffers", idx_var)) - .call_arg(GetTensorBytes(tensor)) - .pop_nest() - .func_call("malloc", DocUtils::ToIndex("cpu_buffers", idx_var)) - .call_arg(GetTensorBytes(tensor)); - }; - stack_.line("#include \"" + graph()->name + ".h\"").line(); - StartNamespace(); - // start define build method - stack_.func_def(graph()->name + "::Build", "bool") - .func_arg("builder", "TRTPtr&") - .func_arg("network", "TRTPtr&"); - if (CompareVersion(6, 0, 0) >= 0) { - stack_.func_arg("config", "TRTPtr&"); - } - stack_.func_arg("logger", "TRTLogger&").func_start(); - // save codegen before build - if (config()->use_tools) { - const auto pf = tvm::ffi::Function::GetGlobalRequired("msc_tool.codegen_step"); - before_build_codes_ = pf(GetStepCtx(), "before_build", graph()->name, config()->tools_tag) - .cast>(); - } - if (graph()->weight_holders.size() > 0) { - stack_.func_call("TRTUtils::LoadWeights", "mWeights") - .call_arg(DocUtils::ToStr(graph()->name + ".wts")); - } - // build layers - for (const auto& n : graph()->node_names) { - const auto& node = graph()->FindNode(n); - CodeGenNode(node, config()->use_tools); - } - // mark outputs - stack_.comment("Mark outputs"); - for (const auto& o : graph()->GetOutputs()) { - const auto& pair = graph()->FindProducerAndIdx(o); - stack_.func_call("markOutput", std::nullopt, DocUtils::ToPtr("network")) - .call_arg("*" + IdxOutputBase(pair.first, pair.second)); - } - // mark batch_size - stack_.comment("Mark batch size"); - stack_.func_call("createOptimizationProfile", DocUtils::ToDeclare("auto", "profile"), - DocUtils::ToPtr("builder")); - ffi::Array batch_flags{"MIN", "MAX", "OPT"}; - for (const auto& i : graph()->GetInputs()) { - for (const auto& f : batch_flags) { - stack_.func_call("setDimensions", std::nullopt, DocUtils::ToPtr("profile")) - .call_arg(DocUtils::ToStr(i->name)) - .call_arg("OptProfileSelector::k" + f) - .call_arg(ToDims(i->shape)); - } - } - // set max workspace - stack_.comment("Set max worksapce"); - if (CompareVersion(6, 0, 0) >= 0) { - stack_.func_call("setMaxWorkspaceSize", std::nullopt, DocUtils::ToPtr("config")) - .call_arg(config()->max_workspace); - } else { - stack_.func_call("setMaxWorkspaceSize", std::nullopt, DocUtils::ToPtr("builder")) - .call_arg(config()->max_workspace); - } - // set data type - if (config()->precision == "float16") { - stack_.comment("Set network precision") - .cond_if("!builder->platformHasFastFp16()") - .func_call("log", "", "logger") - .call_arg("ILogger::Severity::kINTERNAL_ERROR") - .call_arg(DocUtils::ToStr("platform do not support float16, fallback to float32")) - .cond_else() - .func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) - .call_arg("BuilderFlag::kFP16"); - if (config()->precision_mode == "strict") { - stack_.func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) - .call_arg("BuilderFlag::kSTRICT_TYPES"); - } - stack_.func_call("log", "", "logger") - .call_arg("ILogger::Severity::kINFO") - .call_arg(DocUtils::ToStr("use float16 to build the engine")) - .cond_end(); - } else if (config()->precision == "int8") { - stack_.comment("Set network precision") - .cond_if("!builder->platformHasFastInt8()") - .func_call("log", "", "logger") - .call_arg("ILogger::Severity::kINTERNAL_ERROR") - .call_arg(DocUtils::ToStr("platform do not support int8, fallback to float32")) - .cond_else() - .func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) - .call_arg("BuilderFlag::kINT8"); - if (config()->precision_mode == "strict") { - stack_.func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) - .call_arg("BuilderFlag::kSTRICT_TYPES"); - } else if (config()->precision_mode == "prefer") { - stack_.func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) - .call_arg("BuilderFlag::kPREFER_PRECISION_CONSTRAINTS"); - } else if (config()->precision_mode == "obey") { - stack_.func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) - .call_arg("BuilderFlag::kOBEY_PRECISION_CONSTRAINTS"); - } - stack_.func_call("log", "", "logger") - .call_arg("ILogger::Severity::kINFO") - .call_arg(DocUtils::ToStr("use int8 to build the engine")) - .cond_end(); - } - // save codegen after build - if (config()->use_tools) { - const auto pf = tvm::ffi::Function::GetGlobalRequired("msc_tool.codegen_step"); - after_build_codes_ = pf(GetStepCtx(), "after_build", graph()->name, config()->tools_tag) - .cast>(); - } - // end define build method - stack_.func_end("true"); - // start define test function - stack_.func_def("test_" + graph()->name, "bool") - .func_arg("engine", "std::shared_ptr&") - .func_arg("reader", "DatasetReader&") - .func_arg("logger", "TRTLogger&") - .func_start(); - stack_.comment("Create context") - .func_call("TRTPtr", DocUtils::ToDeclare("auto", "context")) - .func_call("createExecutionContext", std::nullopt, DocUtils::ToPtr("engine")) - .pop_nest(); - ReturnOnFail("context", "Failed to create the context"); - // prepare variables - stack_.declare("bool", "pass", 0, false) - .declare_arg("true") - .declare("cudaStream_t", "stream") - .func_call("CHECK") - .func_call("cudaStreamCreate") - .call_arg("&stream") - .pop_nest(); - // malloc buffers - size_t binding_num = graph()->input_names.size() + graph()->output_names.size(); - stack_.comment("Malloc and copy the buffers") - .declare("void*", "cpu_buffers", binding_num) - .declare("void*", "gpu_buffers", binding_num); - for (const auto& i : graph()->GetInputs()) { - malloc_buffer(i); - } - for (const auto& o : graph()->GetOutputs()) { - malloc_buffer(o); - stack_.declare(CppDType(o->dtype), "output_" + IdxTensor(o), - static_cast(o->GetSize()->value)); - } - // read and test datas - stack_.comment("Read and test datas") - .while_start("reader.ReadNext(cpu_buffers)") - .comment("Memcopy inputs host to device"); - // copy inputs - for (const auto& i : graph()->GetInputs()) { - stack_.func_call("CHECK") - .func_call("cudaMemcpyAsync") - .call_arg(DocUtils::ToIndex("gpu_buffers", "idx_" + IdxTensor(i))) - .call_arg(DocUtils::ToIndex("cpu_buffers", "idx_" + IdxTensor(i))) - .call_arg(GetTensorBytes(i)) - .call_arg("cudaMemcpyHostToDevice") - .call_arg("stream") - .pop_nest(); - } - // enqueue - stack_.func_call("cudaStreamSynchronize") - .call_arg("stream") - .comment("enquque with gpu buffers") - .func_call("enqueueV2", std::nullopt, DocUtils::ToPtr("context")) - .call_arg("gpu_buffers") - .call_arg("stream") - .call_arg("nullptr") - .comment("Memcopy outputs device to host"); - // copy outputs - for (const auto& o : graph()->GetOutputs()) { - stack_.func_call("CHECK") - .func_call("cudaMemcpyAsync") - .call_arg("output_" + IdxTensor(o)) - .call_arg(DocUtils::ToIndex("gpu_buffers", "idx_" + IdxTensor(o))) - .call_arg(GetTensorBytes(o)) - .call_arg("cudaMemcpyDeviceToHost") - .call_arg("stream") - .pop_nest(); - } - stack_.func_call("cudaStreamSynchronize").call_arg("stream"); - // compare outputs - for (const auto& o : graph()->GetOutputs()) { - stack_.func_call("CommonUtils::CompareBuffers", "pass") - .call_arg("(" + CppDType(o->dtype) + "*)cpu_buffers[idx_" + IdxTensor(o) + "]") - .call_arg("output_" + IdxTensor(o)) - .call_arg(o->GetSize()); - ReturnOnFail("pass", "Failed to test the output " + o->name); - } - stack_.while_end(); - // clean up - stack_.comment("Clean up the buffers and stream") - .func_call("cudaStreamDestroy") - .call_arg("stream") - .for_start("i", 0, binding_num) - .func_call("CHECK") - .func_call("cudaFree") - .call_arg(DocUtils::ToIndex("gpu_buffers", "i")) - .pop_nest() - .func_call("free") - .call_arg(DocUtils::ToIndex("cpu_buffers", "i")) - .for_end(); - // end define test method - stack_.func_end("true"); - EndNamespace(); -} - -void TensorRTCodeGen::CodeGenMain() { - stack_.line("#include \"" + graph()->name + ".h\"") - .line() - .line("using namespace nvinfer1;") - .line("using namespace tvm::contrib::msc;") - .line() - .func_def("main", "int") - .func_arg("argc", "int") - .func_arg("argv", "char**") - .func_start() - .declare("TRTLogger", "logger") - .func_call("setLogSeverity", "", "logger"); - if (config()->log_level == 0) { - stack_.call_arg("ILogger::Severity::kINFO"); - } else if (config()->log_level == 1) { - stack_.call_arg("ILogger::Severity::kVERBOSE"); - } else { - stack_.call_arg("ILogger::Severity::kWARNING"); - } - // prepare for build - stack_.comment("Define arguments") - .assign("pass", "true", "bool") - .assign("repeat_num", "1000", "int") - .assign("profile_level", std::to_string(config()->profile_level), "int") - .cond_if("argc > 1") - .assign("profile_level", "atoi(argv[1])") - .cond_end(); - - // start build the engine - stack_.comment("Build engine if not exist") - .cond_if("!FileUtils::FileExist(\"" + graph()->name + ".trt\")"); - // create builder - stack_.comment("Create TensorRT tools") - .func_call("TRTPtr", DocUtils::ToDeclare("auto", "builder")) - .func_call("createInferBuilder") - .call_arg("logger") - .pop_nest(); - ReturnOnFail("builder", "Failed to create builder"); - // create network - if (CompareVersion(6, 0, 0) >= 0) { - stack_ - .assign("flags", - "1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)", - "uint32_t") - .func_call("TRTPtr", DocUtils::ToDeclare("auto", "network")) - .func_call("createNetworkV2", std::nullopt, DocUtils::ToPtr("builder")) - .call_arg("flags") - .pop_nest(); - } else { - stack_.func_call("TRTPtr", DocUtils::ToDeclare("auto", "network")) - .func_call("createNetwork", std::nullopt, DocUtils::ToPtr("builder")) - .pop_nest(); - } - ReturnOnFail("network", "Failed to create network"); - // create config - stack_.func_call("TRTPtr", DocUtils::ToDeclare("auto", "config")) - .func_call("createBuilderConfig", std::nullopt, DocUtils::ToPtr("builder")) - .pop_nest(); - ReturnOnFail("config", "Failed to create config"); - // add codegen before build - for (const auto& l : before_build_codes_) { - stack_.line(l); - } - // build model - stack_.comment("Build model") - .declare(graph()->name, "model") - .func_call("Build", "pass", "model") - .call_arg("builder") - .call_arg("network"); - if (CompareVersion(6, 0, 0) >= 0) { - stack_.call_arg("config"); - } - stack_.call_arg("logger"); - ReturnOnFail("pass", "Failed to build model"); - // add codegen after build - for (const auto& l : after_build_codes_) { - stack_.line(l); - } - // Set profile flag - stack_.comment("Set profile flag") - .declare("ProfilingVerbosity", "profile_verbose") - .cond_if("profile_level == 2") - .assign("profile_verbose", "ProfilingVerbosity::kDETAILED") - .cond_else() - .cond_if("profile_level == 1") - .assign("profile_verbose", "ProfilingVerbosity::kLAYER_NAMES_ONLY") - .cond_else() - .assign("profile_verbose", "ProfilingVerbosity::kNONE") - .cond_end() - .cond_end() - .func_call("setProfilingVerbosity", std::nullopt, DocUtils::ToPtr("config")) - .call_arg("profile_verbose"); - // Serialize engine - stack_.comment("Serialize engine") - .func_call("TRTUtils::SerializeEngineToFile", "pass") - .call_arg(DocUtils::ToStr(graph()->name + ".trt")) - .call_arg("builder") - .call_arg("network"); - if (CompareVersion(6, 0, 0) >= 0) { - stack_.call_arg("config"); - } - stack_.call_arg("logger"); - ReturnOnFail("pass", "Failed to serialize the engine"); - // end build the engine - stack_.cond_end(); - // start deserialize engine - stack_.comment("Deserialize engine") - .declare("std::shared_ptr", "engine") - .func_call("TRTUtils::DeserializeEngineFromFile", "pass") - .call_arg(DocUtils::ToStr(graph()->name + ".trt")) - .call_arg("engine") - .call_arg("logger"); - ReturnOnFail("pass", "Failed to deserialize the engine"); - // dump info by inspector - stack_.comment("Dump info by inspector") - .cond_if("profile_level > 0") - .func_call("TRTPtr", DocUtils::ToDeclare("auto", "inspector")) - .func_call("createEngineInspector", std::nullopt, DocUtils::ToPtr("engine")) - .pop_nest() - .func_call("getEngineInformation", DocUtils::ToDeclare("std::string", "result"), - DocUtils::ToPtr("inspector")) - .call_arg("LayerInformationFormat::kJSON") - .declare("std::ofstream", "os") - .declare_arg(DocUtils::ToStr(graph()->name + "_info.json")) - .declare_arg("std::ofstream::trunc") - .line("os << result << std::flush;") - .cond_end(); - // test engine - if (config()->test_iter > 0) { - stack_.comment("Prepare dataset") - .declare("DatasetReader", "reader") - .declare_arg(DocUtils::ToStr(config()->dataset)) - .declare_arg(config()->test_iter); - stack_.comment("Test engine by datas") - .func_call("test_" + graph()->name, "pass") - .call_arg("engine") - .call_arg("reader") - .call_arg("logger"); - } - ReturnOnFail("pass", "Failed to test the engine"); - stack_.func_end("pass ? 0 : 1"); -} - -void TensorRTCodeGen::CodeGenCmake() { - stack_.line("cmake_minimum_required(VERSION " + config()->cmake_version + " FATAL_ERROR)") - .line("project(" + graph()->name + ")") - .line("find_package(CUDA)") - .line() - .line("find_path(TRT_INCLUDE_DIR NvInfer.h HINTS " + config()->tensorrt_root + - " PATH_SUFFIXES include)") - .line("find_library(TRT_LIBS nvinfer HINTS " + config()->tensorrt_root + - " PATH_SUFFIXES lib)") - .line( - "message(STATUS \"Build project with TRT_INCLUDE_DIR ${TRT_INCLUDE_DIR} and " - "TRT_LIBS " - "${TRT_LIBS}\")") - .line() - .line("add_definitions(-DTRT_MAJOR=" + std::to_string(config()->version[0]) + ")") - .line("add_definitions(-DTRT_MINOR=" + std::to_string(config()->version[1]) + ")") - .line("add_definitions(-DTRT_PATCH=" + std::to_string(config()->version[2]) + ")") - .line(); - if (config()->use_plugin) { - stack_.line("add_definitions(-DPLUGIN_SUPPORT_TENSORRT)").line(); - } - ffi::String link_libs = " ${TRT_LIBS}"; - if (config()->extern_libs.size() > 0) { - stack_.line("set(EXTERN_LIBS " + StringUtils::Join(config()->extern_libs, " ") + ")"); - link_libs = link_libs + " ${EXTERN_LIBS}"; - } - stack_.line("file(GLOB_RECURSE TRT_SRCS *.cc)") - .line("cuda_add_executable(" + graph()->name + " ${TRT_SRCS})") - .line("target_include_directories(" + graph()->name + " PUBLIC ${TRT_INCLUDE_DIR})") - .line("target_link_libraries(" + graph()->name + link_libs + ")"); -} - -const ffi::String TensorRTCodeGen::IdxTensor(const MSCTensor& tensor) { - const auto& pair = graph()->FindProducerAndIdx(tensor); - const ffi::String& prefix = "tensor_" + std::to_string(pair.first->index); - if (pair.first->outputs.size() > 1) { - return prefix + "_" + std::to_string(pair.second); - } - return prefix; -} - -const ffi::String TensorRTCodeGen::CppDType(const DataType& dtype) { - const ffi::String& dtype_name = - CppCodeGen::DType(dtype); - if (dtype_name == "int32") { - return "int"; - } - if (dtype_name == "int64") { - return "int64_t"; - } - if (dtype_name == "float32") { - return "float"; - } - if (dtype_name == "float64") { - return "double"; - } - return dtype_name; -} - -const ffi::String TensorRTCodeGen::GetTensorBytes(const MSCTensor& tensor) { - return std::to_string(tensor->GetSize()->value) + " * sizeof(" + CppDType(tensor->dtype) + ")"; -} - -void TensorRTCodeGen::ReturnOnFail(const ffi::String& flag, const ffi::String& err) { - stack_.cond_if("!" + flag) - .func_call("logger.log") - .call_arg("ILogger::Severity::kERROR") - .call_arg(DocUtils::ToStr(err)) - .line("return -1;") - .cond_end(); -} - -template -const ffi::String TensorRTCodeGen::ToDims(const std::vector& dims, bool use_ndim) { - if (dims.size() == 2 && !use_ndim) { - return "DimsHW{" + std::to_string(dims[0]) + "," + std::to_string(dims[1]) + "}"; - } - ffi::String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; - for (size_t i = 0; i < dims.size(); i++) { - dims_str = dims_str + std::to_string(dims[i]) + (i < dims.size() - 1 ? "," : ""); - } - dims_str = dims_str + "}})"; - return dims_str; -} - -const ffi::String TensorRTCodeGen::ToDims(const ffi::Array& dims, bool use_ndim) { - std::vector int_dims; - for (const auto& d : dims) { - int_dims.push_back(d->value); - } - return ToDims(int_dims, use_ndim); -} - -const ffi::Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { - const auto& ops_map = GetTensorRTOpCodes(); - auto it = ops_map->find(GetOpType(node)); - TVM_FFI_ICHECK(it != ops_map->end()) - << "Unsupported tensorrt op(" << node->optype << "): " << node; - it->second->Config(node, config(), prims()); - try { - return it->second->GetDocs(); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what(); - throw err; - } -} - -const ffi::Map TensorRTCodeGen::GetTensorCtx(const MSCTensor& tensor) { - ffi::Map tensor_ctx; - tensor_ctx.Set("ctx", "network"); - for (const auto& pair : - CppCodeGen::GetTensorCtx(tensor)) { - tensor_ctx.Set(pair.first, pair.second); - } - return tensor_ctx; -} - -const ffi::Map TensorRTCodeGen::GetStepCtx() { - ffi::Map step_ctx; - step_ctx.Set("network", "network"); - step_ctx.Set("config", "config"); - step_ctx.Set("builder", "builder"); - for (const auto& pair : CppCodeGen::GetStepCtx()) { - step_ctx.Set(pair.first, pair.second); - } - return step_ctx; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("msc.framework.tensorrt.GetTensorRTSources", - [](const MSCGraph& graph, const ffi::String& codegen_config, - const ffi::String& print_config) -> ffi::Map { - TensorRTCodeGen codegen = TensorRTCodeGen(graph, codegen_config); - codegen.Init(); - return codegen.GetSources(print_config); - }) - .def("msc.framework.tensorrt.GetTensorRTRoot", []() -> ffi::String { -#ifdef TENSORRT_ROOT_DIR - return TENSORRT_ROOT_DIR; -#else - return ""; -#endif - }); -} - -/*! - * \brief Create runtime modules for MSC TensorRT. - * \param functions The extern functions to be compiled via TensorRT - * \return Runtime modules. - */ -ffi::Array MSCTensorRTCompiler(ffi::Array functions, - ffi::Map target_option, - ffi::Map constant_names) { - ffi::Array compiled_functions; - for (const auto& func : functions) { - VLOG(1) << "MSC.TensorRT partition:" << std::endl << func; - const auto& name_opt = func->GetAttr(msc_attr::kUnique); - TVM_FFI_ICHECK(name_opt.has_value()) << "Can not find " << msc_attr::kUnique << " from attrs"; - const auto& name = name_opt.value(); - std::string func_name = GetExtSymbol(func); - TVM_FFI_ICHECK(target_option.count(name)) << "Can not find target option for " << name; - const auto& options = Downcast(target_option[name]); - MSCJSONSerializer serializer(constant_names, options); - serializer.serialize(func); - std::string graph_json = serializer.GetJSON(); - const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.msc_tensorrt_runtime_create"); - VLOG(1) << "Creating msc_tensorrt ffi::Module for '" << func_name << "'"; - compiled_functions.push_back( - pf(func_name, graph_json, serializer.GetConstantNames()).cast()); - } - return compiled_functions; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.ext.msc_tensorrt", MSCTensorRTCompiler); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/framework/tensorrt/codegen.h b/src/contrib/msc/framework/tensorrt/codegen.h deleted file mode 100644 index 87b4c330e40b..000000000000 --- a/src/contrib/msc/framework/tensorrt/codegen.h +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorrt/codegen.h - * \brief Relax codegen for MSCGraph. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_H_ - -#include -#include - -#include "../../core/codegen/base_codegen.h" -#include "../../core/codegen/cpp_codegen.h" -#include "codegen_utils.h" -#include "tensorrt_opcode.h" - -namespace tvm { -namespace contrib { -namespace msc { - -class TensorRTCodeGen : public CppCodeGen { - public: - /*! - * \brief The constructor of TensorRTCodeGen - * \param graph the graph to be generated. - * \param config the options for codegen. - */ - explicit TensorRTCodeGen(const MSCGraph& graph, const std::string& config = "") - : CppCodeGen(graph, config) {} - - /*! \brief Stack the docs for the class declare*/ - void CodeGenClassDeclare() final; - - /*! \brief Stack the docs for the class define*/ - void CodeGenClassDefine() final; - - /*! \brief Stack the docs for the main func*/ - void CodeGenMain() final; - - /*! \brief Stack the docs for the class define*/ - void CodeGenCmake() final; - - protected: - /*! \brief Get the docs for the op*/ - const ffi::Array GetOpCodes(const MSCJoint& node) final; - - /*! \brief Get the tensor context for codegen_tensor*/ - const ffi::Map GetTensorCtx(const MSCTensor& tensor) final; - - /*! \brief Get the step context for codegen_step*/ - const ffi::Map GetStepCtx() final; - - /*! \brief Generate return on fail codes*/ - void ReturnOnFail(const ffi::String& flag, const ffi::String& err); - - /*! \brief Get the index tensor*/ - const ffi::String IdxTensor(const MSCTensor& tensor); - - /*! \brief Get the dtype from the datatype*/ - const ffi::String CppDType(const DataType& dtype); - - /*! \brief Generate describe for tensor bytes*/ - const ffi::String GetTensorBytes(const MSCTensor& tensor); - - /*! \brief Get the tensorrt dims from dims*/ - template - const ffi::String ToDims(const std::vector& dims, bool use_ndim = true); - const ffi::String ToDims(const ffi::Array& dims, bool use_ndim = true); - - private: - ffi::Array before_build_codes_; - ffi::Array after_build_codes_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_H_ diff --git a/src/contrib/msc/framework/tensorrt/codegen_utils.h b/src/contrib/msc/framework/tensorrt/codegen_utils.h deleted file mode 100644 index df68ec66ab1d..000000000000 --- a/src/contrib/msc/framework/tensorrt/codegen_utils.h +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorrt/codegen_utils.h - * \brief TensorRT config for codegen. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_UTILS_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_UTILS_H_ - -#include -#include - -#include "../../core/codegen/base_codegen.h" -#include "../../core/codegen/codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief CodeGen helper for tensorrt codegen - */ -class TensorRTCodeGenHelper : public BaseCodeGenHelper { - public: - /*! \brief Get describe for default node input*/ - const ffi::String IdxInputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0, - const ffi::String& suffix = "", bool process = false) final { - const auto& pair = node->ProducerAndIdxOf(idx); - if (pair.first->optype == "input") { - return "*" + IdxNodeBase(pair.first, prefix, suffix); - } - if (pair.first->optype == "tuple" || pair.first->optype == "get_item") { - return "*" + IdxNodeBase(pair.first, prefix, suffix); - } - return "*" + IdxOutputBase(pair.first, prefix, pair.second, suffix); - } - - /*! \brief Get describe for default node output*/ - const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0, - const ffi::String& suffix = "", bool mark_exit = false) final { - if (node->optype == "argmax" || node->optype == "argmin") { - TVM_FFI_ICHECK_EQ(idx, 0) << "argmax and argmin only has 1 output, get " << idx; - return IdxNodeBase(node, prefix, suffix) + "->getOutput(1)"; - } - if (node->optype == "tuple") { - return IdxNodeBase(node, prefix, suffix) + "[" + std::to_string(idx) + "]"; - } - if (node->optype == "get_item") { - TVM_FFI_ICHECK_EQ(idx, 0) << "get item only has 1 output, get " << idx; - return IdxNodeBase(node, prefix, suffix); - } - return IdxNodeBase(node, prefix, suffix) + "->getOutput(" + std::to_string(idx) + ")"; - } - - /*! \brief Get describe for default node weight*/ - const ffi::String IdxWeightBase(const MSCJoint& node, const ffi::String& wtype, - const ffi::String& suffix = "", bool process = false) final { - return "mWeights[\"" + node->WeightAt(wtype)->name + "\"]"; - } -}; - -/*! - * \brief CodeGen config for tensorrt codegen - */ -struct TensorRTCodeGenConfig { - int log_level{0}; - int profile_level{0}; - int test_iter{0}; - size_t max_workspace{1 << 20}; - std::string cmake_version{"3.5"}; - std::string dataset{"Dataset"}; - std::string range_file{""}; - std::string precision{"float32"}; - std::string precision_mode{"strict"}; - std::string tensorrt_root{"/usr/local/cuda"}; - std::vector extern_libs; - CODEGEN_CONFIG_MEMBERS - void Load(ffi::json::Object obj) { - if (auto it = obj.find(ffi::String("log_level")); it != obj.end()) { - log_level = static_cast((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("profile_level")); it != obj.end()) { - profile_level = static_cast((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("test_iter")); it != obj.end()) { - test_iter = static_cast((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("max_workspace")); it != obj.end()) { - max_workspace = static_cast((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("cmake_version")); it != obj.end()) { - cmake_version = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("dataset")); it != obj.end()) { - dataset = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("range_file")); it != obj.end()) { - range_file = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("precision")); it != obj.end()) { - precision = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("precision_mode")); it != obj.end()) { - precision_mode = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("tensorrt_root")); it != obj.end()) { - tensorrt_root = std::string((*it).second.cast()); - } - if (auto it = obj.find(ffi::String("extern_libs")); it != obj.end()) { - auto arr = (*it).second.cast<::tvm::ffi::json::Array>(); - extern_libs.clear(); - extern_libs.reserve(arr.size()); - for (const auto& elem : arr) { - extern_libs.push_back(std::string(elem.cast())); - } - } - CODEGEN_CONFIG_PARSE - } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_CODEGEN_UTILS_H_ diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc deleted file mode 100644 index f2f5baaa8277..000000000000 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc +++ /dev/null @@ -1,842 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc - */ -#include "tensorrt_opcode.h" - -#include -#include - -#include "../../core/utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -const ffi::Array TensorRTOpCode::GetDocs() { - stack_.Config(this); - CodeGenBuild(); - if (node()->optype == "tuple") { - for (size_t i = 0; i < node()->outputs.size(); i++) { - stack_.func_call("setName", std::nullopt, DocUtils::ToPtr(IdxOutput(i))) - .call_arg(DocUtils::ToStr(node()->OutputAt(i)->name)); - } - } else if (node()->optype == "get_item") { - stack_.func_call("setName", std::nullopt, DocUtils::ToPtr(IdxNode())) - .call_arg(DocUtils::ToStr(node()->OutputAt(0)->name)); - } else if (node()->optype != "input") { - SetLayerByValue("Name", DocUtils::ToStr(node()->name)); - for (size_t i = 0; i < node()->outputs.size(); i++) { - stack_.func_call("setName", std::nullopt, DocUtils::ToPtr(IdxOutput(i))) - .call_arg(DocUtils::ToStr(node()->OutputAt(i)->name)); - } - } - return stack_.GetDocs(); -} - -void TensorRTOpCode::SetPadding(const ffi::String& key) { - const auto& padding = node()->GetTypeArrayAttr("padding"); - if (padding.size() == 1) { - SetLayerByDimsValue("Padding", std::vector{padding[0], padding[0]}, false); - } else if (padding.size() == 2) { - SetLayerByDimsValue("PrePadding", padding, false); - SetLayerByDimsValue("PostPadding", padding, false); - } else if (padding.size() == 4) { - SetLayerByDimsValue("PrePadding", std::vector{padding[0], padding[1]}, false); - SetLayerByDimsValue("PostPadding", std::vector{padding[2], padding[3]}, false); - } else { - LOG_FATAL << "Unexpected padding size" << padding.size(); - } -} - -const ffi::String TensorRTOpCode::DeclareInputs(bool simplify) { - const ffi::String& inputs_ref = "inputs_" + std::to_string(node()->index); - if (node()->parents.size() == 1 && simplify) { - const auto& idx_input = StringUtils::Replace(IdxInput(), "*", ""); - stack_.declare("std::vector", inputs_ref + "_vec") - .declare_arg(node()->inputs.size()) - .declare_arg(idx_input); - } else { - stack_.declare("std::vector", inputs_ref + "_vec", 0, false); - for (size_t i = 0; i < node()->inputs.size(); i++) { - const auto& idx_input = StringUtils::Replace(IdxInput(i), "*", ""); - stack_.declare_arg(idx_input); - } - } - stack_.assign(inputs_ref, inputs_ref + "_vec.data()", "ITensor**"); - return inputs_ref; -} - -const ffi::String TensorRTOpCode::DType(const DataType& dtype) { - const ffi::String& dtype_name = - BaseOpCode::DType(dtype); - ffi::String dtype_enum; - if (dtype_name == "int8") { - dtype_enum = "DataType::kINT8"; - } else if (dtype_name == "int32") { - dtype_enum = "DataType::kINT32"; - } else if (dtype_name == "int64") { - dtype_enum = "DataType::kINT32"; - } else if (dtype_name == "float16") { - dtype_enum = "DataType::kHALF"; - } else if (dtype_name == "float32") { - dtype_enum = "DataType::kFLOAT"; - } else { - LOG_FATAL << "Unexpected dtype for TensorRT " << dtype_name; - } - return dtype_enum; -} - -template -const ffi::String TensorRTOpCode::ToDims(const std::vector& dims, bool use_ndim) { - if (dims.size() == 2 && !use_ndim) { - return "DimsHW{" + std::to_string(dims[0]) + "," + std::to_string(dims[1]) + "}"; - } - ffi::String dims_str = "Dims({" + std::to_string(dims.size()) + ",{"; - for (size_t i = 0; i < dims.size(); i++) { - dims_str = dims_str + std::to_string(dims[i]) + (i < dims.size() - 1 ? "," : ""); - } - dims_str = dims_str + "}})"; - return dims_str; -} - -const ffi::String TensorRTOpCode::ToDims(const ffi::Array& dims, bool use_ndim) { - std::vector int_dims; - for (const auto& d : dims) { - int_dims.push_back(d->value); - } - return ToDims(int_dims, use_ndim); -} - -const ffi::String TensorRTOpCode::AttrToDims(const ffi::String& key, bool use_ndim) { - const auto& dims = node()->GetTypeArrayAttr(key); - return ToDims(dims, use_ndim); -} - -const size_t TensorRTOpCode::ToReduceAxis(const std::vector& axes, size_t ndim) { - size_t valid_ndim = ndim == 0 ? node()->InputAt(0)->Ndim() : ndim; - size_t reduce_axis = 0; - for (const auto& a : axes) { - reduce_axis += 1 << CommonUtils::GetIndex(a, valid_ndim); - } - return reduce_axis; -} - -const size_t TensorRTOpCode::AttrToReduceAxis(const ffi::String& key, size_t ndim) { - std::vector axes; - if (node()->GetAttr(key, &axes)) { - return ToReduceAxis(axes, ndim); - } - int axis; - TVM_FFI_ICHECK(node()->GetAttr(key, &axis)) << "Can not get axes from attribute key " << key; - return ToReduceAxis(std::vector{axis}, ndim); -} - -const size_t TensorRTOpCode::AttrToAxis(const ffi::String& key, size_t ndim) { - size_t valid_ndim = ndim == 0 ? node()->InputAt(0)->Ndim() : ndim; - int axis = node()->GetTypeAttr(key); - return CommonUtils::GetIndex(axis, valid_ndim); -} - -template -void TensorRTOpCode::SetLayerByAttr(const ffi::String& method, const ffi::String& key) { - stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())).op_arg(key, ""); -} - -template -void TensorRTOpCode::SetLayerByValue(const ffi::String& method, const T& value) { - stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())).call_arg(value); -} - -void TensorRTOpCode::SetLayerByDimsAttr(const ffi::String& method, const ffi::String& key, - bool use_ndim) { - stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) - .call_arg(AttrToDims(key, use_ndim)); -} - -template -void TensorRTOpCode::SetLayerByDimsValue(const ffi::String& method, const std::vector& value, - bool use_ndim) { - stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) - .call_arg(ToDims(value, use_ndim)); -} - -void TensorRTOpCode::SetLayerByDimsValue(const ffi::String& method, - const ffi::Array& value, bool use_ndim) { - stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) - .call_arg(ToDims(value, use_ndim)); -} - -#define TENSORRT_OP_CODEGEN_METHODS(TypeName) \ - public: \ - TypeName(const ffi::String& func_name) : TensorRTOpCode(func_name) {} - -#define TENSORRT_FLAG_OP_CODEGEN_METHODS(TypeName) \ - public: \ - TypeName(const ffi::String& func_name, const ffi::String& symbol) : TensorRTOpCode(func_name) { \ - symbol_ = symbol; \ - } \ - \ - private: \ - ffi::String symbol_; - -class TensorRTActivationCodeGen : public TensorRTOpCode { - public: - explicit TensorRTActivationCodeGen(const ffi::String& symbol) : TensorRTOpCode("Activation") { - symbol_ = symbol; - } - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().call_arg("ActivationType::k" + symbol_); - if (node()->optype == "nn.leaky_relu") { - SetLayerByAttr("Alpha", "alpha"); - } else if (node()->optype == "clip") { - SetLayerByAttr("Alpha", "min"); - SetLayerByAttr("Beta", "max"); - } - } - - private: - ffi::String symbol_; -}; - -class TensorRTAdaptivePool2dCodeGen : public TensorRTOpCode { - public: - TENSORRT_FLAG_OP_CODEGEN_METHODS(TensorRTAdaptivePool2dCodeGen) - - protected: - void CodeGenBuild() final { - const auto& input = node()->InputAt(0); - const auto& output = node()->OutputAt(0); - std::vector in_sizes{input->DimAt("H")->value, input->DimAt("W")->value}; - std::vector out_sizes{output->DimAt("H")->value, output->DimAt("W")->value}; - std::vector stride, kernel; - for (size_t i = 0; i < 2; i++) { - stride.push_back(in_sizes[i] / out_sizes[i]); - kernel.push_back((in_sizes[i] - (out_sizes[i] - 1) * stride[i])); - } - const ffi::String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; - stack_.op_call() - .op_input_arg() - .call_arg("PoolingType::k" + symbol_) - .call_arg(ToDims(kernel, false)); - SetLayerByDimsValue("Stride" + suffix, stride, false); - } -}; - -class TensorRTArgmaxminCodeGen : public TensorRTOpCode { - public: - explicit TensorRTArgmaxminCodeGen(const ffi::String& symbol) : TensorRTOpCode("TopK") { - symbol_ = symbol; - } - - protected: - void CodeGenBuild() final { - TVM_FFI_ICHECK(node()->GetTypeAttr("keepdims")) << "Only support argsort with keepdims"; - stack_.op_call() - .op_input_arg() - .call_arg("TopKOperation::k" + symbol_) - .op_arg("keepdims", "") - .call_arg(AttrToReduceAxis()); - } - - private: - ffi::String symbol_; -}; - -class TensorRTAstypeCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTAstypeCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .func_call("setOutputType", std::nullopt, DocUtils::ToPtr(IdxNode())) - .call_arg(0) - .op_dtype_arg(node()->OutputAt(0)->dtype); - } -}; - -class TensorRTBatchMatmulCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTBatchMatmulCodeGen) - - protected: - void CodeGenBuild() final { - bool trans_a = node()->GetTypeAttr("transpose_a"); - bool trans_b = node()->GetTypeAttr("transpose_b"); - stack_.op_call() - .op_input_arg() - .call_arg(trans_a ? "MatrixOperation::kTRANSPOSE" : "MatrixOperation::kNONE") - .op_input_arg(1) - .call_arg(trans_b ? "MatrixOperation::kTRANSPOSE" : "MatrixOperation::kNONE"); - } -}; - -class TensorRTConcatCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTConcatCodeGen) - - protected: - void CodeGenBuild() final { - const auto& producer = node()->ProducerOf(0); - TVM_FFI_ICHECK(node()->parents.size() == 1 && producer->optype == "tuple") - << "Concat expect parent as tuple, get " << node(); - stack_.op_call().call_arg(IdxNodeBase(producer)).call_arg(producer->inputs.size()); - SetLayerByValue("Axis", AttrToAxis()); - } -}; - -class TensorRTConstantCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTConstantCodeGen) - - protected: - void CodeGenBuild() final { - TVM_FFI_ICHECK(!node()->HasAttr("scalar")) << "Scalar constant is not supported"; - stack_.op_call().call_arg(ToDims(node()->OutputAt(0)->shape)).op_weight_arg("const"); - } -}; - -class TensorRTConvCodeGen : public TensorRTOpCode { - public: - TensorRTConvCodeGen(const ffi::String& func_name, bool use_bias) : TensorRTOpCode(func_name) { - use_bias_ = use_bias; - } - - protected: - void CodeGenBuild() final { - const auto& weight = node()->WeightAt("weight"); - std::vector kernel_size; - for (size_t i = 0; i < weight->Ndim(); i++) { - if (weight->layout[i].name() == "I" || weight->layout[i].name() == "O") { - continue; - } - kernel_size.push_back(weight->DimAt(i)->value); - } - stack_.op_call() - .op_input_arg() - .call_arg(weight->DimAt("O")) - .call_arg(ToDims(kernel_size, false)) - .op_weight_arg("weight"); - if (use_bias_) { - stack_.op_weight_arg("bias"); - } else { - stack_.call_arg("mWeights[\"" + node()->name + ".bias\"]"); - } - const ffi::String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; - SetLayerByDimsAttr("Stride" + suffix, "strides", false); - SetLayerByDimsAttr("Dilation" + suffix, "dilation", false); - SetLayerByAttr("NbGroups", "groups"); - SetPadding(); - } - - private: - bool use_bias_; -}; - -class TensorRTElemwiseCodeGen : public TensorRTOpCode { - public: - explicit TensorRTElemwiseCodeGen(const ffi::String& symbol) : TensorRTOpCode("ElementWise") { - symbol_ = symbol; - } - - protected: - void CodeGenBuild() final { - stack_.op_call().op_inputs_arg(false).call_arg("ElementWiseOperation::k" + symbol_); - } - - private: - ffi::String symbol_; -}; - -class TensorRTGetItemCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTGetItemCodeGen) - - protected: - void CodeGenBuild() final { - int index = node()->GetTypeAttr("index"); - const auto& producer = node()->ProducerOf(0); - stack_.assign(IdxNode(), IdxOutputBase(producer, index), "auto"); - } -}; - -class TensorRTInputCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTInputCodeGen) - - protected: - void CodeGenBuild() final { - const auto& output = node()->OutputAt(0); - stack_.op_call() - .call_arg(DocUtils::ToStr(output->name)) - .op_dtype_arg(output->dtype) - .call_arg(ToDims(output->shape)); - } -}; - -class TensorRTLinearCodeGen : public TensorRTOpCode { - public: - TensorRTLinearCodeGen(const ffi::String& func_name, bool use_bias) : TensorRTOpCode(func_name) { - use_bias_ = use_bias; - } - - protected: - void CodeGenBuild() final { - const auto& weight = node()->WeightAt("weight"); - stack_.op_call().op_input_arg().call_arg(weight->DimAt("O")).op_weight_arg("weight"); - if (use_bias_) { - stack_.op_weight_arg("bias"); - } else { - stack_.call_arg(DocUtils::ToIndex("mWeights", DocUtils::ToStr(node()->name + ".bias"))); - } - } - - private: - bool use_bias_; -}; - -class TensorRTMatmulCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTMatmulCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .call_arg("MatrixOperation::kNONE") - .op_input_arg(1) - .call_arg("MatrixOperation::kNONE"); - } -}; - -class TensorRTPadCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTPadCodeGen) - - protected: - void CodeGenBuild() final { - const auto& pad_width = node()->GetTypeArrayAttr("pad_width"); - TVM_FFI_ICHECK(pad_width.size() % 2 == 0) - << "pad_width should be multiple of 2, get " << node(); - std::vector pre_padding{2, 0}, post_padding{2, 0}; - const auto& input = node()->InputAt(0); - for (size_t i = 0; i < input->Ndim(); i++) { - if (input->layout[i].name() == "H") { - pre_padding[0] = pad_width[i * 2]; - post_padding[0] = pad_width[i * 2 + 1]; - } else if (input->layout[i].name() == "W") { - pre_padding[1] = pad_width[i * 2]; - post_padding[1] = pad_width[i * 2 + 1]; - } - } - stack_.op_call().op_input_arg().call_arg(ToDims(pre_padding)).call_arg(ToDims(post_padding)); - } -}; - -class TensorRTPermuteDimsCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTPermuteDimsCodeGen) - - protected: - void CodeGenBuild() final { - std::vector axes; - if (!node()->GetAttr("axes", &axes)) { - for (size_t i = node()->InputAt(0)->Ndim(); i > 0; i--) { - axes.push_back(i - 1); - } - } - const ffi::String& perm_ref = "perm_" + std::to_string(node()->index); - stack_.op_call().op_input_arg().declare("Permutation", perm_ref); - for (size_t i = 0; i < axes.size(); i++) { - stack_.assign(perm_ref + ".order[" + std::to_string(i) + "]", - CommonUtils::GetIndex(axes[i], node()->InputAt(0)->Ndim())); - } - SetLayerByValue("FirstTranspose", perm_ref); - } -}; - -class TensorRTPool2dCodeGen : public TensorRTOpCode { - public: - explicit TensorRTPool2dCodeGen(const ffi::String& symbol) : TensorRTOpCode("PoolingNd") { - symbol_ = symbol; - } - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .call_arg("PoolingType::k" + symbol_) - .call_arg(AttrToDims("pool_size", false)); - const ffi::String& suffix = CompareVersion(8, 0, 0) >= 0 ? "Nd" : ""; - SetLayerByDimsAttr("Stride" + suffix, "strides", false); - if (node()->GetTypeAttr("ceil_mode")) { - SetLayerByValue("PaddingMode", "PaddingMode::kEXPLICIT_ROUND_UP"); - } - if (node()->optype == "nn.avg_pool2d") { - SetLayerByValue("AverageCountExcludesPadding", false); - } - SetPadding(); - } - - private: - ffi::String symbol_; -}; - -class TensorRTReduceCodeGen : public TensorRTOpCode { - public: - explicit TensorRTReduceCodeGen(const ffi::String& symbol) : TensorRTOpCode("Reduce") { - symbol_ = symbol; - } - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .call_arg("ReduceOperation::k" + symbol_) - .call_arg(AttrToReduceAxis()) - .op_arg("keepdims", ""); - } - - private: - ffi::String symbol_; -}; - -class TensorRTReshapeCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTReshapeCodeGen) - - protected: - void CodeGenBuild() final { - const auto& output = node()->OutputAt(0); - stack_.op_call().op_input_arg(); - SetLayerByDimsValue("ReshapeDimensions", output->shape); - } -}; - -class TensorRTResize2dCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTResize2dCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg(); - const auto& method = node()->GetTypeAttr("method"); - ffi::String resize_mode; - if (method == "linear") { - resize_mode = "LINEAR"; - } else if (method == "nearest_neighbor") { - resize_mode = "NEAREST"; - } else { - LOG_FATAL << "Unexpected resize method " << method; - } - SetLayerByValue("ResizeMode", "ResizeMode::k" + resize_mode); - SetLayerByValue("SelectorForSinglePixel", "ResizeSelector::kFORMULA"); - const auto& transformation_mode = - node()->GetTypeAttr("coordinate_transformation_mode"); - // set transformation - if (transformation_mode == "align_corners") { - SetLayerByValue("CoordinateTransformation", "ResizeCoordinateTransformation::kALIGN_CORNERS"); - } else if (transformation_mode == "asymmetric") { - SetLayerByValue("CoordinateTransformation", "ResizeCoordinateTransformation::kASYMMETRIC"); - } else if (transformation_mode == "tf_half_pixel_for_nn") { - SetLayerByValue("CoordinateTransformation", "ResizeCoordinateTransformation::kHALF_PIXEL"); - } else if (transformation_mode == "pytorch_half_pixel") { - SetLayerByValue("CoordinateTransformation", "ResizeCoordinateTransformation::kHALF_PIXEL"); - } else if (transformation_mode == "half_pixel") { - SetLayerByValue("CoordinateTransformation", "ResizeCoordinateTransformation::kHALF_PIXEL"); - } else { - LOG_FATAL << "Unexpected transformation_mode " << transformation_mode; - } - // set round - const auto& rounding_method = node()->GetTypeAttr("rounding_method"); - if (transformation_mode == "tf_half_pixel_for_nn") { - SetLayerByValue("NearestRounding", "ResizeRoundMode::kCEIL"); - } else if (rounding_method == "floor") { - SetLayerByValue("NearestRounding", "ResizeRoundMode::kFLOOR"); - } else if (rounding_method == "ceil") { - SetLayerByValue("NearestRounding", "ResizeRoundMode::kCEIL"); - } else if (rounding_method == "round_prefer_floor") { - SetLayerByValue("NearestRounding", "ResizeRoundMode::kHALF_DOWN"); - } else if (rounding_method == "round_prefer_ceil") { - SetLayerByValue("NearestRounding", "ResizeRoundMode::kHALF_UP"); - } else if (rounding_method == "round") { - SetLayerByValue("NearestRounding", "ResizeRoundMode::kHALF_UP"); - } else if (rounding_method == "") { - SetLayerByValue("NearestRounding", "ResizeRoundMode::kHALF_UP"); - } else { - LOG_FATAL << "Unexpected rounding_method " << rounding_method; - } - // set output dims - SetLayerByDimsValue("OutputDimensions", node()->OutputAt(0)->shape); - } -}; - -class TensorRTSoftmaxCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTSoftmaxCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg(); - SetLayerByValue("Axes", AttrToReduceAxis()); - } -}; - -class TensorRTSquareCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTSquareCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().op_input_arg().call_arg("ElementWiseOperation::kPROD"); - } -}; - -class TensorRTStridedSliceCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTStridedSliceCodeGen) - - protected: - void CodeGenBuild() final { - std::vector axes; - if (!node()->GetAttr("axes", &axes)) { - for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { - axes.push_back(i); - } - } - std::vector begin(node()->InputAt(0)->Ndim(), 0); - std::vector strides(node()->InputAt(0)->Ndim(), 1); - const auto& attr_begin = node()->GetTypeArrayAttr("begin"); - for (size_t i = 0; i < axes.size(); i++) { - size_t max_dim = static_cast(node()->InputAt(0)->DimAt(axes[i])->value); - begin[axes[i]] = CommonUtils::GetIndex(attr_begin[i], max_dim); - } - std::vector attr_strides; - if (node()->GetAttr("strides", &attr_strides)) { - for (size_t i = 0; i < axes.size(); i++) { - strides[axes[i]] = static_cast(attr_strides[i]); - } - } - stack_.op_call() - .op_input_arg() - .call_arg(ToDims(begin)) - .call_arg(ToDims(node()->OutputAt(0)->shape)) - .call_arg(ToDims(strides)); - } -}; - -class TensorRTTakeCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTTakeCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_inputs_arg(false).call_arg(AttrToAxis()); - if (node()->InputAt(0)->Ndim() == node()->InputAt(1)->Ndim()) { - SetLayerByValue("Mode", "GatherMode::kELEMENT"); - } - } -}; - -class TensorRTTopkCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTTopkCodeGen) - - protected: - void CodeGenBuild() final { - const ffi::String& symbol = node()->GetTypeAttr("largest") ? "MAX" : "MIN"; - stack_.op_call() - .op_input_arg() - .call_arg("TopKOperation::k" + symbol) - .op_arg("k", "") - .call_arg(AttrToReduceAxis()); - } -}; - -class TensorRTTupleCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTTupleCodeGen) - - protected: - void CodeGenBuild() final { - const auto& inputs_ref = DeclareInputs(); - stack_.assign(IdxNode(), inputs_ref, "auto"); - } -}; - -class TensorRTUnaryCodeGen : public TensorRTOpCode { - public: - explicit TensorRTUnaryCodeGen(const ffi::String& symbol) : TensorRTOpCode("Unary") { - symbol_ = symbol; - } - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().call_arg("UnaryOperation::k" + symbol_); - } - - private: - ffi::String symbol_; -}; - -class TensorRTWhereCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTWhereCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false); } -}; - -class TensorRTPluginOpCodeGen : public TensorRTOpCode { - public: - TENSORRT_OP_CODEGEN_METHODS(TensorRTPluginOpCodeGen) - - protected: - void CodeGenBuild() final { - const auto& producer = node()->ParentAt(0); - TVM_FFI_ICHECK(producer->optype == "tuple") - << "Only support tensorrt plugin with tuple, get " << producer; - - const auto& plugin = GetPlugin(node()->optype); - const auto& input_ref = "inputs_" + std::to_string(producer->index); - const ffi::String& func_name = "plugin::" + node()->optype + "DynamicPlugin"; - const ffi::String& plugin_ref = "plugin_" + std::to_string(node()->index); - const ffi::String& layouts_ref = "layouts_" + std::to_string(node()->index); - stack_.declare("std::vector", layouts_ref, 0, false); - for (const auto& i : node()->GetInputs()) { - stack_.declare_arg(DocUtils::ToStr(i->layout.name())); - } - stack_.func_call(func_name, DocUtils::ToDeclare("auto", plugin_ref)) - .call_arg(DocUtils::ToStr(node()->name)); - for (const auto& a : plugin->attrs) { - stack_.call_arg(GetAttrDoc(a->name, a->type)); - } - stack_.call_arg(layouts_ref); - stack_.op_call().call_arg(input_ref).call_arg(plugin->inputs.size()).call_arg(plugin_ref); - } -}; - -const std::shared_ptr>> -GetTensorRTOpCodes() { - static auto map = - std::make_shared>>(); - if (!map->empty()) return map; - // unary ops - map->emplace("abs", std::make_shared("ABS")); - map->emplace("acos", std::make_shared("ACOS")); - map->emplace("acosh", std::make_shared("ACOSH")); - map->emplace("asin", std::make_shared("ASIN")); - map->emplace("asinh", std::make_shared("ASINH")); - map->emplace("atan", std::make_shared("ATAN")); - map->emplace("atanh", std::make_shared("ATANH")); - map->emplace("ceil", std::make_shared("CEIL")); - map->emplace("cos", std::make_shared("COS")); - map->emplace("cosh", std::make_shared("COSH")); - map->emplace("erf", std::make_shared("ERF")); - map->emplace("exp", std::make_shared("EXP")); - map->emplace("floor", std::make_shared("FLOOR")); - map->emplace("log", std::make_shared("LOG")); - map->emplace("negative", std::make_shared("NEG")); - map->emplace("round", std::make_shared("ROUND")); - map->emplace("sin", std::make_shared("SIN")); - map->emplace("sinh", std::make_shared("SINH")); - map->emplace("sqrt", std::make_shared("SQRT")); - map->emplace("tan", std::make_shared("TAN")); - - // elemwise ops - map->emplace("add", std::make_shared("SUM")); - map->emplace("divide", std::make_shared("DIV")); - map->emplace("equal", std::make_shared("EQUAL")); - map->emplace("floor_divide", std::make_shared("FLOOR_DIV")); - map->emplace("greater", std::make_shared("GREATER")); - map->emplace("less", std::make_shared("LESS")); - map->emplace("maximum", std::make_shared("MAX")); - map->emplace("minimum", std::make_shared("MIN")); - map->emplace("multiply", std::make_shared("PROD")); - map->emplace("power", std::make_shared("POW")); - map->emplace("subtract", std::make_shared("SUB")); - - // reduce ops - map->emplace("max", std::make_shared("MAX")); - map->emplace("mean", std::make_shared("AVG")); - map->emplace("min", std::make_shared("MIN")); - map->emplace("sum", std::make_shared("SUM")); - - // math ops - map->emplace("argmax", std::make_shared("MAX")); - map->emplace("argmin", std::make_shared("MIN")); - map->emplace("astype", std::make_shared("Identity")); - map->emplace("concat", std::make_shared("Concatenation")); - map->emplace("expand_dims", std::make_shared("Shuffle")); - map->emplace("matmul", std::make_shared("MatrixMultiply")); - map->emplace("permute_dims", std::make_shared("Shuffle")); - map->emplace("reshape", std::make_shared("Shuffle")); - map->emplace("square", std::make_shared("ElementWise")); - map->emplace("squeeze", std::make_shared("Shuffle")); - map->emplace("strided_slice", std::make_shared("Slice")); - map->emplace("take", std::make_shared("Gather")); - map->emplace("topk", std::make_shared("TopK")); - map->emplace("where", std::make_shared("Select")); - - // create ops - map->emplace("constant", std::make_shared("Constant")); - - // activation ops - map->emplace("clip", std::make_shared("CLIP")); - map->emplace("sigmoid", std::make_shared("SIGMOID")); - map->emplace("tanh", std::make_shared("TANH")); - map->emplace("nn.relu", std::make_shared("RELU")); - map->emplace("nn.leaky_relu", std::make_shared("LEAKY_RELU")); - - // nn ops - map->emplace("nn.adaptive_avg_pool2d", - std::make_shared("PoolingNd", "AVERAGE")); - map->emplace("nn.avg_pool2d", std::make_shared("AVERAGE")); - map->emplace("nn.batch_matmul", std::make_shared("MatrixMultiply")); - map->emplace("nn.conv2d", std::make_shared("ConvolutionNd", false)); - map->emplace("nn.max_pool2d", std::make_shared("MAX")); - map->emplace("nn.pad", std::make_shared("Padding")); - map->emplace("nn.softmax", std::make_shared("SoftMax")); - - // image ops - map->emplace("image.resize2d", std::make_shared("Resize")); - - // special op - map->emplace("input", std::make_shared("Input")); - map->emplace("get_item", std::make_shared("")); - map->emplace("tuple", std::make_shared("")); - map->emplace("plugin", std::make_shared("PluginV2")); - - // msc ops - map->emplace("msc.conv2d_bias", std::make_shared("ConvolutionNd", true)); - map->emplace("msc.linear", std::make_shared("FullyConnected", false)); - map->emplace("msc.linear_bias", std::make_shared("FullyConnected", true)); - - return map; -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h deleted file mode 100644 index ddf7fb1522be..000000000000 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.h +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorrt/tensorrt_opcode.h - * \brief TensorRT codegen for MSCJoint. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_TENSORRT_OPCODE_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_TENSORRT_OPCODE_H_ - -#include -#include -#include -#include - -#include "../../core/codegen/base_codegen.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -class TensorRTOpCode; -typedef OpCodeStack TensorRTOpCodeStack; - -/*! - * \brief CodeGen for relax op - */ -class TensorRTOpCode : public BaseOpCode { - public: - /*! - * \brief The constructor of BaseOpDocsifier - * \param func_name the function name for the node. - * \param config the config json for the node. - */ - explicit TensorRTOpCode(const ffi::String& func_name) - : BaseOpCode(func_name) {} - - /*! \brief Convert node to docs*/ - const ffi::Array GetDocs() final; - - /*! \brief Get func_name for the default node*/ - const ffi::String callee_name() final { - return "network->add" + BaseOpCode::callee_name(); - } - - /*! \brief Get valid return name for the default node*/ - const ffi::String ret_name() final { return "auto " + IdxNode(); } - - /*! \brief Get the dtype from the datatype*/ - const ffi::String DType(const DataType& dtype) final; - - protected: - TensorRTOpCodeStack stack_; - - /*! \brief Convert op build*/ - virtual void CodeGenBuild() = 0; - - /*! \brief Set padding for the layer*/ - void SetPadding(const ffi::String& key = "padding"); - - /*! \brief Declare the inputs*/ - const ffi::String DeclareInputs(bool simplify = true); - - /*! \brief Get the tensorrt dims from dims*/ - template - const ffi::String ToDims(const std::vector& dims, bool use_ndim = true); - const ffi::String ToDims(const ffi::Array& dims, bool use_ndim = true); - - /*! \brief Get the tensorrt dims from attribute*/ - const ffi::String AttrToDims(const ffi::String& key, bool use_ndim = true); - - /*! \brief Get the tensorrt reduce axis from dims*/ - const size_t ToReduceAxis(const std::vector& axes, size_t ndim = 0); - - /*! \brief Get the tensorrt reduce axis from attribute*/ - const size_t AttrToReduceAxis(const ffi::String& key = "axis", size_t ndim = 0); - - /*! \brief Get the attribute axis from attribute*/ - const size_t AttrToAxis(const ffi::String& key = "axis", size_t ndim = 0); - - /*! \brief Set layer by attribute*/ - template - void SetLayerByAttr(const ffi::String& method, const ffi::String& key); - - /*! \brief Set layer by value*/ - template - void SetLayerByValue(const ffi::String& method, const T& value); - - /*! \brief Set layer by dims attribute*/ - void SetLayerByDimsAttr(const ffi::String& method, const ffi::String& key, bool use_ndim = true); - - /*! \brief Set layer by dims value*/ - template - void SetLayerByDimsValue(const ffi::String& method, const std::vector& value, - bool use_ndim = true); - void SetLayerByDimsValue(const ffi::String& method, const ffi::Array& value, - bool use_ndim = true); -}; - -/*! - * \brief Get the map of available TensorRTOpCode, use optype as key - * \return Map of - */ -const std::shared_ptr>> -GetTensorRTOpCodes(); - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TENSORRT_TENSORRT_OPCODE_H_ diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc deleted file mode 100644 index 4b6beb1164ad..000000000000 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ /dev/null @@ -1,927 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tensorrt/transform_tensorrt.cc - * \brief Pass for transform the function to tensorrt. - */ - -#include -#include -#include -#include -#include -#include - -#include "../../../../relax/transform/utils.h" -#include "../../../../support/scalars.h" -#include "../../core/transform/rewrite_utils.h" -#include "../../core/utils.h" - -namespace tvm { -namespace relax { -using namespace tvm::contrib::msc; - -struct TensorRTTransConfig { - // Whether to cast linear to conv - bool linear_to_conv{true}; - std::vector version{0, 0, 0}; - - void Load(ffi::json::Object obj) { - namespace json = ::tvm::ffi::json; - if (auto it = obj.find(ffi::String("linear_to_conv")); it != obj.end()) { - linear_to_conv = (*it).second.cast(); - } - if (auto it = obj.find(ffi::String("version")); it != obj.end()) { - auto arr = (*it).second.cast(); - version.clear(); - version.reserve(arr.size()); - for (const auto& elem : arr) { - version.push_back(static_cast(elem.cast())); - } - } - } -}; - -const TensorRTTransConfig ParseConfig(const ffi::String& config_str) { - TensorRTTransConfig config; - if (config_str.size() > 0) { - namespace json = ::tvm::ffi::json; - config.Load(json::Parse(std::string(config_str)).cast()); - } - return config; -} - -using FRewriteTensorRT = - ffi::TypedFunction& new_calls, const ffi::String& config)>; - -const ffi::Array BroadcastShape(const ffi::Array& src_shape, - const ffi::Array& out_shape) { - size_t diff = out_shape.size() - src_shape.size(); - ffi::Array leading_shape, tailing_shape; - for (size_t i = 0; i < diff; i++) { - leading_shape.push_back(Integer(1)); - } - for (const auto& s : src_shape) { - tailing_shape.push_back(s); - leading_shape.push_back(s); - } - for (size_t i = 0; i < diff; i++) { - tailing_shape.push_back(Integer(1)); - } - if (ArrayUtils::Broadcastable(tailing_shape, out_shape)) { - return tailing_shape; - } - TVM_FFI_ICHECK(ArrayUtils::Broadcastable(leading_shape, out_shape)) - << "Only support elemwise ops with leading or tailing expand"; - return leading_shape; -} - -Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& shape_a = ExprUtils::GetShape(call->args[0]); - const auto& shape_b = ExprUtils::GetShape(call->args[1]); - const auto& shape_out = ExprUtils::GetShape(var); - static const Op& reshape_op = Op::Get("relax.reshape"); - if (shape_a.size() > shape_b.size()) { - const auto& exp_shape = BroadcastShape(shape_b, shape_out); - const auto& expand_b = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_b"), reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); - return Call(call->op, {call->args[0], expand_b}, call->attrs, call->sinfo_args, call->span); - } else if (shape_a.size() < shape_b.size()) { - const auto& exp_shape = BroadcastShape(shape_a, shape_out); - const auto& expand_a = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_a"), reshape_op, - {call->args[0], ShapeExpr(exp_shape)}); - return Call(call->op, {expand_a, call->args[1]}, call->attrs, call->sinfo_args, call->span); - } - return call; -} - -Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - if (new_calls.count(call->args[0]) && - new_calls[call->args[0]]->op == Op::Get("relax.nn.conv1d")) { - const auto& reshape = Downcast(builder->LookupBinding(Downcast(call->args[0]))); - if (reshape->op != Op::Get("relax.reshape")) { - return call; - } - const auto& conv2d = Downcast(builder->LookupBinding(Downcast(reshape->args[0]))); - if (conv2d->op != Op::Get("relax.nn.conv2d")) { - return call; - } - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& bias_shape = ExprUtils::GetShape(call->args[1]); - const auto* conv_attrs = conv2d->attrs.as(); - if (conv_attrs->data_layout == "NCHW") { - // expand bias reshape - ffi::Array exp_bias_shape{bias_shape[0], bias_shape[1], Integer(1), bias_shape[2]}; - static const Op& reshape_op = Op::Get("relax.reshape"); - const auto& exp_bias = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_bias"), reshape_op, - {call->args[1], ShapeExpr(exp_bias_shape)}); - // redirect to conv2d - static const Op& add_op = Op::Get("relax.add"); - const auto& exp_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_add"), - add_op, {reshape->args[0], exp_bias}); - // reduce output - return Call(reshape_op, {exp_add, ShapeExpr(input_shape)}, Attrs(), call->sinfo_args, - call->span); - } else { - LOG_FATAL << "Unexpected data layout " << conv_attrs->data_layout; - } - } - return RewriteElemwise(builder, var, call, new_calls, config); -} - -Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& out_dtype = ExprUtils::GetDataType(var); - const auto* src_attrs = src_call->attrs.as(); - TVM_FFI_ICHECK(out_dtype == DataType::Int(32) || out_dtype == DataType::Int(64)) - << "Unexpected out dtype " << out_dtype; - static const Op& topk_op = Op::Get("relax.topk"); - auto topk_attrs = ffi::make_object(); - topk_attrs->k = 1; - if (src_attrs->axis.has_value()) { - topk_attrs->axis = src_attrs->axis.value(); - } - topk_attrs->largest = call->op == Op::Get("relax.argmax"); - topk_attrs->ret_type = "both"; - topk_attrs->dtype = out_dtype; - // change to topk - const auto& topk = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "topk"), topk_op, - {call->args[0]}, Attrs(topk_attrs)); - const auto& get_name = ExprUtils::GetSpanName(call, ".1"); - const auto& get_item = - TupleGetItem(topk, 1, SpanUtils::CreateWithAttr(msc_attr::kName, get_name)); - if (src_attrs->keepdims) { - return get_item; - } - const auto& get_item_var = builder->Emit(get_item, get_name); - static const Op& reshape_op = Op::Get("relax.reshape"); - const auto& output_shape = ExprUtils::GetShape(var); - return Call(reshape_op, {get_item_var, ShapeExpr(output_shape)}, Attrs(), call->sinfo_args, - call->span); -} - -Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); - const auto* src_attrs = src_call->attrs.as(); - - // define dims - const auto& in_q_shape = ExprUtils::GetShape(call->args[0]); - const auto& in_v_shape = ExprUtils::GetShape(call->args[2]); - const auto& batch_size = in_q_shape[0]; - const auto& seq_len = in_q_shape[1]; - const auto& num_head = in_q_shape[2]; - const auto& head_dim = in_q_shape[3]; - const auto& seq_len_kv = in_v_shape[1]; - const auto& head_dim_v = in_v_shape[3]; - - // create ops - static const Op& permute_dims_op = Op::Get("relax.permute_dims"); - static const Op& reshape_op = Op::Get("relax.reshape"); - static const Op& matmul_op = Op::Get("relax.matmul"); - static const Op& multiply_op = Op::Get("relax.multiply"); - static const Op& add_op = Op::Get("relax.add"); - static const Op& divide_op = Op::Get("relax.divide"); - static const Op& sqrt_op = Op::Get("relax.sqrt"); - static const Op& softmax_op = Op::Get("relax.nn.softmax"); - static const Op& tril_op = Op::Get("relax.tril"); - static const Op& max_op = Op::Get("relax.max"); - static const Op& sum_op = Op::Get("relax.sum"); - static const Op& subtract_op = Op::Get("relax.subtract"); - static const Op& exp_op = Op::Get("relax.exp"); - - // prepare q,k,v - auto permute_attrs = ffi::make_object(); - ffi::Array axes{Integer(0), Integer(2), Integer(1), Integer(3)}; - permute_attrs->axes = axes; - const auto& q_trans = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_trans"), permute_dims_op, - {call->args[0]}, Attrs(permute_attrs)); - const auto& k_trans = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_trans"), permute_dims_op, - {call->args[1]}, Attrs(permute_attrs)); - const auto& v_trans = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_trans"), permute_dims_op, - {call->args[2]}, Attrs(permute_attrs)); - ffi::Array q_shape({batch_size * num_head, seq_len, head_dim}); - const auto& q_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_reshape"), - reshape_op, {q_trans, ShapeExpr(q_shape)}); - ffi::Array k_shape({batch_size * num_head, seq_len_kv, head_dim}); - const auto& k_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape"), - reshape_op, {k_trans, ShapeExpr(k_shape)}); - ffi::Array v_shape({batch_size * num_head, seq_len_kv, head_dim_v}); - const auto& v_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_reshape"), - reshape_op, {v_trans, ShapeExpr(v_shape)}); - auto reduce_permute_attrs = ffi::make_object(); - ffi::Array v_axes{Integer(0), Integer(2), Integer(1)}; - reduce_permute_attrs->axes = v_axes; - // transpose for batch_matmul - const auto& k_reshape_trans = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape_trans"), - permute_dims_op, {k_reshape}, Attrs(reduce_permute_attrs)); - - // calculate product - auto matmul_attrs = ffi::make_object(); - matmul_attrs->out_dtype = in_dtype; - const auto& qk_prod = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "qk_prod"), matmul_op, - {q_reshape, k_reshape_trans}, Attrs(matmul_attrs)); - Expr p_scale; - if (src_attrs->scale.defined()) { - double value = static_cast(src_attrs->scale.value()->value); - const auto& scale = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "scale"), - value, in_dtype, 3); - p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_scale"), multiply_op, - {qk_prod, scale}); - } else { - double value = static_cast(Downcast(head_dim)->value); - const auto& scale = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "scale"), - value, in_dtype, 3); - const auto& sqrt_scale = RewriteUtils::MakeCall( - builder, ExprUtils::GetSpanName(call, "sqrt_scale"), sqrt_op, {scale}); - p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_scale"), divide_op, - {qk_prod, sqrt_scale}); - } - - // bias - Expr prod = p_scale; - if (call->args.size() == 4) { - ffi::Array exp_shape{batch_size, num_head, seq_len, seq_len_kv}; - ffi::Array reduce_shape{batch_size * num_head, seq_len, seq_len_kv}; - const auto& prod_exp = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_exp"), - reshape_op, {prod, ShapeExpr(exp_shape)}); - const auto& prod_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_add"), - add_op, {prod_exp, call->args[3]}); - prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_reduce"), reshape_op, - {prod_add, ShapeExpr(reduce_shape)}); - } - - // causal_mask - Expr s_value; - if (!src_attrs->causal_mask.has_value()) { - auto softmax_attrs = ffi::make_object(); - softmax_attrs->axis = 2; - s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), softmax_op, - {prod}, Attrs(softmax_attrs)); - } else { - const auto& causal_mask = src_attrs->causal_mask.value(); - PrimValue tril_k; - if (causal_mask == "TopLeft") { - tril_k = PrimValue(Integer(0)); - } else if (causal_mask == "BottomRight") { - tril_k = PrimValue(seq_len - seq_len_kv); - } else { - LOG_FATAL << "Unexpected causal_mask " << causal_mask; - } - const auto& p_masked = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked"), - tril_op, {prod, tril_k}); - auto reduce_attrs = ffi::make_object(); - ffi::Array axis{Integer(2)}; - reduce_attrs->axis = axis; - reduce_attrs->keepdims = true; - const auto& p_max = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_max"), - max_op, {prod}, Attrs(reduce_attrs)); - const auto& p_diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_diff"), - subtract_op, {p_masked, p_max}); - const auto& p_exp = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_exp"), exp_op, {p_diff}); - const auto& p_masked_exp = RewriteUtils::MakeCall( - builder, ExprUtils::GetSpanName(call, "p_masked_exp"), tril_op, {p_exp, tril_k}); - const auto& p_masked_sum = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked_sum"), sum_op, - {p_masked_exp}, Attrs(reduce_attrs)); - s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), divide_op, - {p_masked_exp, p_masked_sum}); - } - - // final calculation - const auto& o_prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "o_prod"), - matmul_op, {s_value, v_reshape}, Attrs(matmul_attrs)); - ffi::Array o_shape{batch_size, num_head, seq_len, head_dim_v}; - return Call(reshape_op, {o_prod, ShapeExpr(o_shape)}, Attrs(), call->sinfo_args, call->span); -} - -Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); - const auto* src_attrs = src_call->attrs.as(); - // define expand shape - ffi::Array exp_shape(input_shape.size(), Integer(1)); - exp_shape.Set(src_attrs->axis, input_shape[src_attrs->axis]); - - // create eps constant - const auto& eps = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), - src_attrs->epsilon, in_dtype); - - // create ops - static const Op& add_op = Op::Get("relax.add"); - static const Op& divide_op = Op::Get("relax.divide"); - static const Op& multiply_op = Op::Get("relax.multiply"); - static const Op& reshape_op = Op::Get("relax.reshape"); - static const Op& sqrt_op = Op::Get("relax.sqrt"); - static const Op& subtract_op = Op::Get("relax.subtract"); - - // scale factor: gamma/sqrt(var + epsilon) - const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), - add_op, {call->args[4], eps}); - const auto& sqrt = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, {eps_add}); - const auto& scale_factor = RewriteUtils::MakeCall( - builder, ExprUtils::GetSpanName(call, "scale_factor"), divide_op, {call->args[1], sqrt}); - Expr res = call->args[0]; - // scale - if (src_attrs->scale) { - const auto& exp_scale = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_scale"), reshape_op, - {scale_factor, ShapeExpr(exp_shape)}); - res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "scale"), multiply_op, - {res, exp_scale}); - } - // offset - if (src_attrs->center) { - // offset factor: beta-mean*scale_factor - const auto& average = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "average"), - multiply_op, {call->args[3], scale_factor}); - const auto& offset_factor = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset_factor"), subtract_op, - {call->args[2], average}); - const auto& exp_offset = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_offset"), reshape_op, - {offset_factor, ShapeExpr(exp_shape)}); - res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset"), add_op, - {res, exp_offset}); - } - return Tuple(ffi::Array{res}, call->span); -} - -Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& output_shape = ExprUtils::GetShape(var); - Expr concat_input = call->args[0]; - static const Op& concat_op = Op::Get("relax.concat"); - for (size_t i = 0; i < input_shape.size(); i++) { - int64_t in_dim = Downcast(input_shape[i])->value; - int64_t out_dim = Downcast(output_shape[i])->value; - if (in_dim != out_dim) { - ffi::Array concat_inputs(out_dim / in_dim, concat_input); - auto concat_attrs = ffi::make_object(); - concat_attrs->axis = i; - concat_input = RewriteUtils::MakeCall( - builder, ExprUtils::GetSpanName(call, "concat_" + std::to_string(i)), concat_op, - {Tuple(concat_inputs)}, Attrs(concat_attrs)); - } - } - return concat_input; -} - -Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto* src_attrs = src_call->attrs.as(); - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& weight_shape = ExprUtils::GetShape(call->args[1]); - const auto& output_shape = ExprUtils::GetShape(var); - if (src_attrs->data_layout == "NCW") { - ffi::Array new_args; - // expand inputs - ffi::Array exp_input_shape{input_shape[0], input_shape[1], Integer(1), - input_shape[2]}; - ffi::Array exp_weight_shape{weight_shape[0], weight_shape[1], Integer(1), - weight_shape[2]}; - static const Op& reshape_op = Op::Get("relax.reshape"); - new_args.push_back(RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_input"), - reshape_op, - {call->args[0], ShapeExpr(exp_input_shape)})); - new_args.push_back(RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_weight"), - reshape_op, - {call->args[1], ShapeExpr(exp_weight_shape)})); - // change to conv2d - static const Op& conv2d_op = Op::Get("relax.nn.conv2d"); - auto conv_attrs = ffi::make_object(); - conv_attrs->strides = ffi::Array{src_attrs->strides[0], 1}; - conv_attrs->padding = ffi::Array{0, src_attrs->padding[0], 0, src_attrs->padding[1]}; - conv_attrs->dilation = ffi::Array{src_attrs->dilation[0], 1}; - conv_attrs->groups = src_attrs->groups; - conv_attrs->data_layout = "NCHW"; - conv_attrs->kernel_layout = "OIHW"; - conv_attrs->out_layout = "NCHW"; - conv_attrs->out_dtype = src_attrs->out_dtype; - const auto& conv2d = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp"), - conv2d_op, new_args, Attrs(conv_attrs)); - // reduce output - return Call(reshape_op, {conv2d, ShapeExpr(output_shape)}, Attrs(), call->sinfo_args, - call->span); - } else { - LOG_FATAL << "Unexpected data layout " << src_attrs->data_layout; - } - return call; -} - -Expr RewriteGelu(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - // 0.5 * x * (1 + erf(sqrt(0.5) * x)) - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - size_t in_dim = ExprUtils::GetShape(call->args[0]).size(); - const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); - // create ops - static const Op& add_op = Op::Get("relax.add"); - static const Op& multiply_op = Op::Get("relax.multiply"); - static const Op& erf_op = Op::Get("relax.erf"); - - const auto& factor = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "factor"), - std::sqrt(0.5), in_dtype, in_dim); - const auto& mul = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul"), - multiply_op, {factor, call->args[0]}); - const auto& erf = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "erf"), erf_op, {mul}); - const auto& one = - RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 1, in_dtype, in_dim); - const auto& add = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"), add_op, {one, erf}); - const auto& mul2 = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul2"), - multiply_op, {call->args[0], add}); - const auto& half = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 0.5, - in_dtype, in_dim); - return Call(multiply_op, {half, mul2}, Attrs(), call->sinfo_args, call->span); -} - -Expr RewriteGeluTanh(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - // 0.5 * x * (1 + tanh(sqrt(2/pi) * (0.044715F * pow(x, 3) + x))) - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - size_t in_dim = ExprUtils::GetShape(call->args[0]).size(); - const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); - - // create ops - static const Op& add_op = Op::Get("relax.add"); - static const Op& multiply_op = Op::Get("relax.multiply"); - static const Op& pow_op = Op::Get("relax.power"); - static const Op& tanh_op = Op::Get("relax.tanh"); - - const auto& pow_factor = RewriteUtils::MakeConstant( - builder, ExprUtils::GetSpanName(call, "pow_factor"), 3, in_dtype, in_dim); - const auto& mul_factor = RewriteUtils::MakeConstant( - builder, ExprUtils::GetSpanName(call, "mul_factor"), 0.044715, in_dtype, in_dim); - const auto& pi_factor = RewriteUtils::MakeConstant( - builder, ExprUtils::GetSpanName(call, "pi_factor"), std::sqrt(2 / M_PI), in_dtype, in_dim); - - const auto& pow = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "pow"), pow_op, - {call->args[0], pow_factor}); - const auto& mul = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul"), - multiply_op, {mul_factor, pow}); - const auto& add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"), add_op, - {mul, call->args[0]}); - const auto& mul2 = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul2"), - multiply_op, {pi_factor, add}); - const auto& tanh = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "tanh"), tanh_op, {mul2}); - const auto& one = - RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 1, in_dtype, in_dim); - const auto& add2 = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"), add_op, {one, tanh}); - const auto& mul3 = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul3"), - multiply_op, {call->args[0], add2}); - const auto& half = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 0.5, - in_dtype, in_dim); - return Call(multiply_op, {half, mul3}, Attrs(), call->sinfo_args, call->span); -} - -Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); - const auto* src_attrs = src_call->attrs.as(); - ffi::Array group_shape = input_shape; - ffi::Array exp_shape(input_shape.size(), Integer(1)); - size_t axis = CommonUtils::GetIndex(src_attrs->channel_axis, input_shape.size()); - int64_t channel_dim = Downcast(input_shape[axis])->value * - Downcast(input_shape[axis + 1])->value / src_attrs->num_groups; - group_shape.Set(axis, Integer(src_attrs->num_groups)); - group_shape.Set(axis + 1, Integer(channel_dim)); - exp_shape.Set(axis, Integer(src_attrs->num_groups)); - - // create eps constant - const auto& eps = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), - src_attrs->epsilon, in_dtype); - - // create ops - static const Op& add_op = Op::Get("relax.add"); - static const Op& divide_op = Op::Get("relax.divide"); - static const Op& mean_op = Op::Get("relax.mean"); - static const Op& multiply_op = Op::Get("relax.multiply"); - static const Op& square_op = Op::Get("relax.square"); - static const Op& reshape_op = Op::Get("relax.reshape"); - static const Op& sqrt_op = Op::Get("relax.sqrt"); - static const Op& subtract_op = Op::Get("relax.subtract"); - - // reshape input - const auto& reshape_in = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "reshape_in"), reshape_op, - {call->args[0], ShapeExpr(group_shape)}); - - // mean(input) - auto mean_attrs = ffi::make_object(); - mean_attrs->axis = src_attrs->axes; - mean_attrs->keepdims = true; - const auto& mean = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mean"), mean_op, - {reshape_in}, Attrs(mean_attrs)); - - // variance: mean((input-mean)*(input-mean)) - const auto& diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "diff"), - subtract_op, {reshape_in, mean}); - const auto& square = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "square"), square_op, {diff}); - const auto& variance = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "variance"), - mean_op, {square}, Attrs(mean_attrs)); - - // sqrt(var + epsilon) - ffi::Array exp_eps_shape(input_shape.size(), Integer(1)); - const auto& exp_eps = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_eps"), - reshape_op, {eps, ShapeExpr(exp_eps_shape)}); - const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), - add_op, {variance, exp_eps}); - const auto& sqrt = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, {eps_add}); - - // diff/sqrt - Expr res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "divide"), divide_op, - {diff, sqrt}); - - // scale - if (src_attrs->scale) { - const auto& exp_gamma = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_gamma"), reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); - res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "scale"), multiply_op, - {res, exp_gamma}); - } - // offset - if (src_attrs->center) { - const auto& exp_beta = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_beta"), reshape_op, - {call->args[2], ShapeExpr(exp_shape)}); - res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset"), add_op, - {res, exp_beta}); - } - // reshape output - return Call(reshape_op, {res, ShapeExpr(input_shape)}, Attrs(), call->sinfo_args, call->span); -} - -Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); - const auto* src_attrs = src_call->attrs.as(); - ffi::Array exp_shape(input_shape.size(), Integer(1)); - for (const auto& a : src_attrs->axes) { - size_t index = CommonUtils::GetIndex(static_cast(a->value), input_shape.size()); - exp_shape.Set(index, input_shape[index]); - } - // create eps constant - const auto& eps = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), - src_attrs->epsilon, in_dtype); - - // create ops - static const Op& add_op = Op::Get("relax.add"); - static const Op& divide_op = Op::Get("relax.divide"); - static const Op& mean_op = Op::Get("relax.mean"); - static const Op& multiply_op = Op::Get("relax.multiply"); - static const Op& square_op = Op::Get("relax.square"); - static const Op& reshape_op = Op::Get("relax.reshape"); - static const Op& sqrt_op = Op::Get("relax.sqrt"); - static const Op& subtract_op = Op::Get("relax.subtract"); - - // mean(input) - auto mean_attrs = ffi::make_object(); - mean_attrs->axis = src_attrs->axes; - mean_attrs->keepdims = true; - const auto& mean = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mean"), mean_op, - {call->args[0]}, Attrs(mean_attrs)); - - // variance: mean((input-mean)*(input-mean)) - const auto& diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "diff"), - subtract_op, {call->args[0], mean}); - const auto& square = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "square"), square_op, {diff}); - const auto& variance = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "variance"), - mean_op, {square}, Attrs(mean_attrs)); - - // sqrt(var + epsilon) - ffi::Array exp_eps_shape(input_shape.size(), Integer(1)); - const auto& exp_eps = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_eps"), - reshape_op, {eps, ShapeExpr(exp_eps_shape)}); - const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), - add_op, {variance, exp_eps}); - const auto& sqrt = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, {eps_add}); - - // diff/sqrt - Call res = Call(divide_op, {diff, sqrt}, Attrs(), call->sinfo_args, call->span); - - // scale - if (src_attrs->scale) { - const auto& exp_gamma = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_gamma"), reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); - const auto& res_var = - RewriteUtils::ReEmit(builder, ExprUtils::GetSpanName(call, "pre_scale"), res); - if (src_attrs->center) { - res = Call(multiply_op, {res_var, exp_gamma}); - } else { - res = Call(multiply_op, {res_var, exp_gamma}, Attrs(), call->sinfo_args, call->span); - } - } - // offset - if (src_attrs->center) { - const auto& exp_beta = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_beta"), reshape_op, - {call->args[2], ShapeExpr(exp_shape)}); - const auto& res_var = - RewriteUtils::ReEmit(builder, ExprUtils::GetSpanName(call, "pre_offset"), res); - res = Call(add_op, {res_var, exp_beta}, Attrs(), call->sinfo_args, call->span); - } - return res; -} - -Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& trt_config = ParseConfig(config); - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& shape_a = ExprUtils::GetShape(call->args[0]); - const auto& shape_b = ExprUtils::GetShape(call->args[1]); - static const Op& reshape_op = Op::Get("relax.reshape"); - if (call->args[1]->IsInstance() && shape_b.size() == 2 && - trt_config.linear_to_conv) { - const auto& out_shape = ExprUtils::GetShape(var); - PrimExpr accumulate = ArrayUtils::Accumulate(shape_a, shape_a.size() - 1); - ffi::Array exp_shape{accumulate, shape_a[shape_a.size() - 1], Integer(1), Integer(1)}; - const auto& exp_in = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_in"), - reshape_op, {call->args[0], ShapeExpr(exp_shape)}); - // transpose and expand weight to OIHW - static const Op& permute_dims_op = Op::Get("relax.permute_dims"); - auto permute_attrs = ffi::make_object(); - ffi::Array axes{Integer(1), Integer(0)}; - permute_attrs->axes = axes; - const auto& trans_weight = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "trans_weight"), - permute_dims_op, {call->args[1]}, Attrs(permute_attrs)); - ffi::Array weight_shape{shape_b[1], shape_b[0], Integer(1), Integer(1)}; - const auto& exp_weight = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_weight"), reshape_op, - {trans_weight, ShapeExpr(weight_shape)}); - // to conv2d - static const Op& conv2d_op = Op::Get("relax.nn.conv2d"); - auto conv_attrs = ffi::make_object(); - conv_attrs->strides = ffi::Array{1, 1}; - conv_attrs->padding = ffi::Array{0, 0, 0, 0}; - conv_attrs->dilation = ffi::Array{1, 1}; - conv_attrs->groups = 1; - conv_attrs->data_layout = "NCHW"; - conv_attrs->kernel_layout = "OIHW"; - conv_attrs->out_layout = "NCHW"; - conv_attrs->out_dtype = ExprUtils::GetDataType(var); - const auto& conv2d = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "conv2d"), - conv2d_op, {exp_in, exp_weight}, Attrs(conv_attrs)); - return Call(reshape_op, {conv2d, ShapeExpr(out_shape)}, Attrs(), call->sinfo_args, call->span); - } - if (shape_a.size() > shape_b.size()) { - ffi::Array exp_shape(shape_a.size(), Integer(1)); - size_t diff = shape_a.size() - shape_b.size(); - for (size_t i = diff; i < shape_a.size(); i++) { - exp_shape.Set(i, shape_b[i - diff]); - } - const auto& expand_b = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_b"), reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); - return Call(call->op, {call->args[0], expand_b}, call->attrs, call->sinfo_args, call->span); - } - if (shape_a.size() < shape_b.size()) { - ffi::Array exp_shape(shape_b.size(), Integer(1)); - size_t diff = shape_b.size() - shape_a.size(); - for (size_t i = diff; i < shape_b.size(); i++) { - exp_shape.Set(i, shape_a[i - diff]); - } - const auto& expand_a = - RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_a"), reshape_op, - {call->args[0], ShapeExpr(exp_shape)}); - return Call(call->op, {expand_a, call->args[1]}, call->attrs, call->sinfo_args, call->span); - } - return call; -} - -Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); - // create 1 constant - const auto& one = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), 1, - in_dtype, input_shape.size()); - - // create ops - static const Op& divide_op = Op::Get("relax.divide"); - static const Op& sqrt_op = Op::Get("relax.sqrt"); - - // expand and divide - const auto& sqrt = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, - {call->args[0]}); - return Call(divide_op, {one, sqrt}, Attrs(), call->sinfo_args, call->span); -} - -Expr RewriteSilu(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - // create ops - static const Op& multiply_op = Op::Get("relax.multiply"); - static const Op& sigmoid_op = Op::Get("relax.sigmoid"); - // silu=input*sigmoid(input) - const auto& sigmoid = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sigmoid"), - sigmoid_op, {call->args[0]}); - return Call(multiply_op, {call->args[0], sigmoid}, Attrs(), call->sinfo_args, call->span); -} - -Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& output_shape = ExprUtils::GetShape(var); - static const Op& reshape_op = Op::Get("relax.reshape"); - return Call(reshape_op, {call->args[0], ShapeExpr(output_shape)}, Attrs(), call->sinfo_args, - call->span); -} - -Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, - const ffi::Map& new_calls, const ffi::String& config) { - const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = ExprUtils::GetShape(call->args[0]); - const auto* src_attrs = src_call->attrs.as(); - size_t axis = CommonUtils::GetIndex(src_attrs->axis, input_shape.size()); - std::vector split_begins, split_ends; - // get split begins and ends - if (src_attrs->indices_or_sections->IsInstance()) { - int64_t sections = Downcast(src_attrs->indices_or_sections)->value; - int64_t size = Downcast(input_shape[axis])->value / sections; - for (int64_t i = 0; i < sections; i++) { - split_begins.push_back(i * size); - split_ends.push_back(i * size + size); - } - } else if (src_attrs->indices_or_sections->IsInstance()) { - const auto& indices = Downcast>(src_attrs->indices_or_sections); - int64_t last_index = 0; - for (size_t i = 0; i < indices.size(); ++i) { - split_begins.push_back(last_index); - last_index = indices[i]->value; - split_ends.push_back(last_index); - } - split_begins.push_back(last_index); - split_ends.push_back(Downcast(input_shape[axis])->value); - } else { - LOG_FATAL << "Unexpected indices_or_sections " << src_attrs->indices_or_sections << "(" - << src_attrs->indices_or_sections->GetTypeKey() << ")"; - } - // create strided_slices - ffi::Array outputs; - for (size_t i = 0; i < split_begins.size(); i++) { - static const Op& strided_slice_op = Op::Get("relax.strided_slice"); - const auto& axes = Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), axis))}); - const auto& begin = - Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), split_begins[i]))}); - const auto& end = Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), split_ends[i]))}); - const auto& strides = Tuple(ffi::Array{PrimValue(IntImm(DataType::Int(64), 1))}); - auto attrs = ffi::make_object(); - attrs->assume_inbound = true; - const auto& slice = RewriteUtils::MakeCall( - builder, ExprUtils::GetSpanName(call, "slice_" + std::to_string(i)), strided_slice_op, - {call->args[0], axes, begin, end, strides}, Attrs(attrs)); - outputs.push_back(slice); - } - return Tuple(outputs, call->span); -} - -// nn ops -TVM_REGISTER_OP("relax.nn.attention") - .set_attr("FRewriteTensorRT", RewriteAttention); -TVM_REGISTER_OP("relax.nn.attention_bias") - .set_attr("FRewriteTensorRT", RewriteAttention); -TVM_REGISTER_OP("relax.nn.batch_norm") - .set_attr("FRewriteTensorRT", RewriteBatchNorm); -TVM_REGISTER_OP("relax.nn.conv1d").set_attr("FRewriteTensorRT", RewriteConv1d); -TVM_REGISTER_OP("relax.nn.group_norm") - .set_attr("FRewriteTensorRT", RewriteGroupNorm); -TVM_REGISTER_OP("relax.nn.gelu").set_attr("FRewriteTensorRT", RewriteGelu); -TVM_REGISTER_OP("relax.nn.gelu_tanh") - .set_attr("FRewriteTensorRT", RewriteGeluTanh); -TVM_REGISTER_OP("relax.nn.layer_norm") - .set_attr("FRewriteTensorRT", RewriteLayerNorm); -TVM_REGISTER_OP("relax.nn.silu").set_attr("FRewriteTensorRT", RewriteSilu); - -// elemwise ops -TVM_REGISTER_OP("relax.add").set_attr("FRewriteTensorRT", RewriteAdd); -TVM_REGISTER_OP("relax.divide").set_attr("FRewriteTensorRT", RewriteElemwise); -TVM_REGISTER_OP("relax.floor_divide") - .set_attr("FRewriteTensorRT", RewriteElemwise); -TVM_REGISTER_OP("relax.greater").set_attr("FRewriteTensorRT", RewriteElemwise); -TVM_REGISTER_OP("relax.less").set_attr("FRewriteTensorRT", RewriteElemwise); -TVM_REGISTER_OP("relax.maximum").set_attr("FRewriteTensorRT", RewriteElemwise); -TVM_REGISTER_OP("relax.minimum").set_attr("FRewriteTensorRT", RewriteElemwise); -TVM_REGISTER_OP("relax.multiply").set_attr("FRewriteTensorRT", RewriteElemwise); -TVM_REGISTER_OP("relax.power").set_attr("FRewriteTensorRT", RewriteElemwise); -TVM_REGISTER_OP("relax.subtract").set_attr("FRewriteTensorRT", RewriteElemwise); - -// math ops -TVM_REGISTER_OP("relax.argmax").set_attr("FRewriteTensorRT", RewriteArgmaxmin); -TVM_REGISTER_OP("relax.argmin").set_attr("FRewriteTensorRT", RewriteArgmaxmin); -TVM_REGISTER_OP("relax.broadcast_to") - .set_attr("FRewriteTensorRT", RewriteBroadcastTo); -TVM_REGISTER_OP("relax.expand_dims") - .set_attr("FRewriteTensorRT", RewriteShapeLike); -TVM_REGISTER_OP("relax.matmul").set_attr("FRewriteTensorRT", RewriteMatmul); -TVM_REGISTER_OP("relax.rsqrt").set_attr("FRewriteTensorRT", RewriteRsqrt); -TVM_REGISTER_OP("relax.squeeze").set_attr("FRewriteTensorRT", RewriteShapeLike); -TVM_REGISTER_OP("relax.split").set_attr("FRewriteTensorRT", RewriteSplit); - -class TensorRTTransformer : public ExprMutator { - public: - explicit TensorRTTransformer(IRModule ctx_module, const ffi::String& config) - : ExprMutator(ctx_module) { - config_ = config; - } - - void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - if (const auto* op_node = call_node->op.as()) { - const auto& op = Downcast(ffi::GetRef(op_node)); - const auto& rewrite_map = Op::GetAttrMap("FRewriteTensorRT"); - if (rewrite_map.count(op)) { - const auto& call = ffi::GetRef(call_node); - FRewriteTensorRT f = rewrite_map[op]; - const auto& new_call = f(builder_, binding->var, call, new_calls_, config_); - if (new_call != call) { - ReEmitBinding(binding, builder_->Normalize(new_call)); - new_calls_.Set(binding->var, call); - } - } - } - if (!new_calls_.count(binding->var)) { - ExprMutator::VisitBinding_(binding, call_node); - } - } - - private: - ffi::Map new_calls_; - ffi::String config_; -}; - -Function TransformTensorRT(const Function& func, const IRModule& module, - const ffi::String& config) { - return Downcast(TensorRTTransformer(module, config).VisitExpr(func)); -} - -namespace transform { - -Pass TransformTensorRT(const ffi::String& config) { - auto pass_func = [=](Function f, IRModule m, PassContext pc) { - return relax::TransformTensorRT(f, m, config); - }; - return CreateFunctionPass(pass_func, 0, "TransformTensorRT", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.transform.TransformTensorRT", TransformTensorRT); -} - -} // namespace transform -} // namespace relax -} // namespace tvm diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc deleted file mode 100644 index 4607b0d94bef..000000000000 --- a/src/contrib/msc/framework/torch/codegen.cc +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/torch/codegen.cc - */ -#include "codegen.h" - -#include - -namespace tvm { -namespace contrib { -namespace msc { - -void TorchCodeGen::CodeGenHeader() { - PyCodeGen::CodeGenHeader(); - stack_.line("import torch"); - stack_.line("from torch import nn"); - stack_.line("from torch.nn import functional"); -} - -void TorchCodeGen::CodeGenGraph() { - stack_.class_def(graph()->name + "(torch.nn.Module)"); - stack_.class_start(); - - // Write init - is_init_ = true; - stack_.func_def("__init__", "torch.nn.Module"); - if (config()->use_tools) { - stack_.func_decorator("msc_tools.wrap_step(\"build\",\"" + config()->tools_tag + "\")"); - } - stack_.func_arg("self", "torch.nn.Module"); - if (config()->use_plugin) { - stack_.func_arg("plugin", "Any"); - } - stack_.func_start() - .func_call("super") - .call_arg(graph()->name) - .call_arg("self") - .method_call("__init__"); - for (const auto& n : graph()->node_names) { - const auto& node = graph()->FindNode(n); - if (node->optype == "input") { - continue; - } - CodeGenNode(node, false); - } - stack_.func_end(); - - // Write forward - is_init_ = false; - stack_.func_def("forward", "List[torch.Tensor]"); - if (config()->use_tools) { - stack_.func_decorator("msc_tools.wrap_step(\"forward\",\"" + config()->tools_tag + "\")"); - } - stack_.func_arg("self", "torch.nn.Module"); - for (const auto& i : graph()->GetInputs()) { - const auto& pair = graph()->FindProducerAndIdx(i); - stack_.func_arg(IdxOutputBase(pair.first, pair.second), "torch.Tensor"); - } - stack_.func_start(); - if (config()->use_tools) { - stack_.comment("Define all weights"); - for (const auto& n : graph()->node_names) { - const auto& node = graph()->FindNode(n); - for (const auto& pair : node->weights) { - stack_.assign(IdxWeightBase(node, pair.first, false), "self." + pair.second->alias); - } - } - stack_.comment("End of define all weights").line(); - } - for (const auto& n : graph()->node_names) { - const auto& node = graph()->FindNode(n); - if (node->optype == "input") { - continue; - } - CodeGenNode(node, config()->use_tools); - } - ffi::Array idx_outputs; - for (const auto& o : graph()->GetOutputs()) { - const auto& pair = graph()->FindProducerAndIdx(o); - idx_outputs.push_back(IdxOutputBase(pair.first, pair.second, true)); - } - if (idx_outputs.size() == 1) { - stack_.assign("outputs", idx_outputs[0]); - } else { - stack_.assign("outputs", DocUtils::ToList(idx_outputs)); - } - stack_.func_end("outputs"); - stack_.class_end(); -} - -void TorchCodeGen::CodeGenInference() { - if (config()->use_plugin) { - stack_.comment("Import Plugin") - .line("from msc_plugin.torch import PluginManager") - .line() - .func_call("PluginManager", "plugin"); - } - stack_.comment("Build Model").func_call(graph()->name, "model"); - if (config()->use_plugin) { - stack_.call_arg("plugin"); - } - stack_.comment("Load weights") - .func_call("torch.load", "weights") - .call_arg(DocUtils::ToStr(graph()->name + ".pth")) - .func_call("load_state_dict", "", "model") - .call_arg("weights"); - if (config()->test_device == "gpu") { - stack_.func_call("to", "", "model").func_call("torch.device").call_arg("cuda").pop_nest(); - } - for (const auto& i : graph()->GetInputs()) { - const auto& producer = graph()->FindProducer(i); - stack_.func_call("torch.from_numpy", IdxNodeBase(producer)) - .call_arg(DocUtils::ToIndex("inputs", DocUtils::ToStr(i->alias))); - } - stack_.func_call("model", "outputs"); - for (const auto& i : graph()->GetInputs()) { - const auto& producer = graph()->FindProducer(i); - stack_.call_arg(IdxNodeBase(producer)); - if (config()->test_device == "gpu") { - stack_.method_call("to").func_call("torch.device").call_arg("cuda"); - } - } -} - -const ffi::Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { - const auto& ops_map = GetTorchOpCodes(); - auto it = ops_map->find(GetOpType(node)); - TVM_FFI_ICHECK(it != ops_map->end()) << "Unsupported torch op(" << node->optype << "): " << node; - it->second->Config(node, config(), is_init_, prims()); - try { - return it->second->GetDocs(); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what(); - throw err; - } -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("msc.framework.torch.GetTorchSources", - [](const MSCGraph& graph, const ffi::String& codegen_config, - const ffi::String& print_config) -> ffi::Map { - TorchCodeGen codegen = TorchCodeGen(graph, codegen_config); - codegen.Init(); - return codegen.GetSources(print_config); - }); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/framework/torch/codegen.h b/src/contrib/msc/framework/torch/codegen.h deleted file mode 100644 index 1e5032309cb6..000000000000 --- a/src/contrib/msc/framework/torch/codegen.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/torch/codegen.h - * \brief Torch codegen for MSCGraph. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TORCH_CODEGEN_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TORCH_CODEGEN_H_ - -#include - -#include "../../core/codegen/base_codegen.h" -#include "../../core/codegen/py_codegen.h" -#include "codegen_utils.h" -#include "torch_opcode.h" - -namespace tvm { -namespace contrib { -namespace msc { - -class TorchCodeGen : public PyCodeGen { - public: - /*! - * \brief The constructor of TorchCodeGen - * \param graph the graph to be generated. - * \param config the options for codegen. - */ - explicit TorchCodeGen(const MSCGraph& graph, const std::string& config = "") - : PyCodeGen(graph, config) {} - - protected: - /*! \brief Stack the docs for the header*/ - void CodeGenHeader() final; - - /*! \brief Stack the docs for the graph*/ - void CodeGenGraph() final; - - /*! \brief Stack the docs for the graph inference*/ - void CodeGenInference() final; - - /*! \brief Get the docs for the op*/ - const ffi::Array GetOpCodes(const MSCJoint& node) final; - - /*! \brief Get tensor type of the framework*/ - const ffi::String TensorType() const final { return "torch.Tensor"; } - - private: - bool is_init_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TORCH_CODEGEN_H_ diff --git a/src/contrib/msc/framework/torch/codegen_utils.h b/src/contrib/msc/framework/torch/codegen_utils.h deleted file mode 100644 index 5ddff5fc2164..000000000000 --- a/src/contrib/msc/framework/torch/codegen_utils.h +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/torch/codegen_utils.h - * \brief Utils for torch codegen. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TORCH_CODEGEN_UTILS_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TORCH_CODEGEN_UTILS_H_ - -#include - -#include "../../core/codegen/base_codegen.h" -#include "../../core/codegen/codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief CodeGen helper for torch codegen - */ -class TorchCodeGenHelper : public BaseCodeGenHelper { - public: - /*! \brief Get describe for default node input*/ - const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0, - const ffi::String& suffix = "", bool mark_exit = false) final { - if ((node->optype == "max" || node->optype == "min") && node->OutputAt(0)->Ndim() > 0) { - TVM_FFI_ICHECK(idx == 0) << "max and min op only support 1 outputs, get " << node; - return IdxNodeBase(node, prefix, suffix) + ".values"; - } - return BaseCodeGenHelper::IdxOutputBase(node, prefix, idx, suffix, mark_exit); - } -}; - -/*! - * \brief CodeGen config for torch codegen - */ -struct TorchCodeGenConfig { - CODEGEN_CONFIG_MEMBERS - void Load(ffi::json::Object obj) { CODEGEN_CONFIG_PARSE } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TORCH_CODEGEN_UTILS_H_ diff --git a/src/contrib/msc/framework/torch/config.h b/src/contrib/msc/framework/torch/config.h deleted file mode 100644 index b5a18e123ad4..000000000000 --- a/src/contrib/msc/framework/torch/config.h +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/torch/config.h - * \brief Torch config for codegen. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TORCH_CONFIG_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TORCH_CONFIG_H_ - -#include - -#include "../../core/codegen/base_codegen.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief CodeGen config for torch codegen - */ -struct TorchCodeGenConfig { - bool is_training{false}; - CODEGEN_CONFIG_MEMBERS - void Load(ffi::json::Object obj) { - if (auto it = obj.find(ffi::String("is_training")); it != obj.end()) { - is_training = (*it).second.cast(); - } - CODEGEN_CONFIG_PARSE - } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TORCH_CONFIG_H_ diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc deleted file mode 100644 index 7641f4c443f5..000000000000 --- a/src/contrib/msc/framework/torch/torch_opcode.cc +++ /dev/null @@ -1,866 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/torch/torch_opcode.cc - */ -#include "torch_opcode.h" - -#include -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -const ffi::Array TorchOpCode::GetDocs() { - stack_.Config(this); - if (is_init()) { - CodeGenInit(); - } else { - CodeGenForward(); - } - return stack_.GetDocs(); -} - -void TorchOpCode::CodeGenInit() { - if (module_name().size() > 0) { - stack_.op_call(); - } else { - stack_.comment("passby: implement by " + func_name()); - } -} - -void TorchOpCode::CodeGenForward() { stack_.op_call().op_inputs_arg(false); } - -const StrictListDoc TorchOpCode::GetPadding(const ffi::String& key) { - std::vector padding, src_padding; - TVM_FFI_ICHECK(node()->GetAttr(key, &src_padding)); - if (node()->optype == "nn.conv1d" || node()->optype == "msc.conv1d_bias") { - if (src_padding.size() == 2) { - TVM_FFI_ICHECK(src_padding[0] == src_padding[1]) - << "Only accept symmetric padding, get " << node(); - padding.push_back(src_padding[0]); - } else { - LOG_FATAL << "nn.conv1d with unexpected padding " << node(); - } - } else if (node()->optype == "nn.conv2d" || node()->optype == "msc.conv2d_bias" || - node()->optype == "nn.avg_pool2d" || node()->optype == "nn.max_pool2d") { - if (src_padding.size() == 4) { - TVM_FFI_ICHECK(src_padding[0] == src_padding[2] && src_padding[1] == src_padding[3]) - << "Only accept symmetric padding, get " << node(); - padding.push_back(src_padding[0]); - padding.push_back(src_padding[1]); - } else { - LOG_FATAL << "nn.conv2d/pool2d with unexpected padding " << node(); - } - } else { - LOG_FATAL << "Unexpected padding node" << node(); - } - return DocUtils::ToList(padding); -} - -#define TORCH_OP_CODEGEN_METHODS(TypeName) \ - public: \ - TypeName(const ffi::String& module_name, const ffi::String& func_name) \ - : TorchOpCode(module_name, func_name) {} - -class TorchAdaptivePoolCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchAdaptivePoolCodeGen); - - protected: - void CodeGenInit() final { stack_.op_call().op_list_arg("output_size"); } -}; - -class TorchAstypeCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchAstypeCodeGen); - - protected: - void CodeGenForward() final { - stack_.assign(IdxNode(), IdxInput()).method_call("to").op_dtype_arg(node()->OutputAt(0)->dtype); - } -}; - -class TorchAttentionCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchAttentionCodeGen); - - protected: - void CodeGenForward() final { - std::string causal_mask; - stack_.op_call().op_inputs_arg(false); - if (node()->GetAttr("causal_mask", &causal_mask)) { - if (causal_mask.size() > 0) { - stack_.call_arg(true, "is_causal"); - } - } - } -}; - -class TorchAxesCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchAxesCodeGen); - - protected: - void CodeGenInit() final { - if (module_name().size() > 0) { - const ffi::String& key = node()->HasAttr("axes") ? "axes" : "axis"; - stack_.op_call().op_list_arg(key, ""); - } else { - TorchOpCode::CodeGenInit(); - } - } - - void CodeGenForward() final { - if (module_name().size() > 0) { - TorchOpCode::CodeGenForward(); - } else { - const ffi::String& key = node()->HasAttr("axes") ? "axes" : "axis"; - stack_.op_call().op_input_arg().op_list_arg(key, ""); - } - } -}; - -class TorchAxisCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchAxisCodeGen); - - protected: - void CodeGenInit() final { - if (module_name().size() > 0) { - stack_.op_call().op_arg("axis", "dim"); - } else { - TorchOpCode::CodeGenInit(); - } - } - - void CodeGenForward() final { - if (module_name().size() > 0) { - TorchOpCode::CodeGenForward(); - } else { - stack_.op_call().op_input_arg().op_arg("axis", "dim"); - } - } -}; - -class TorchBatchNormCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchBatchNormCodeGen); - - protected: - void CodeGenInit() final { - TVM_FFI_ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) - << "Only support center and scale batchnorm, get " << node(); - const auto& gamma = node()->WeightAt("gamma"); - stack_.op_call().call_arg(gamma->DimAt(0), "num_features").op_arg("epsilon", "eps"); - } - - void CodeGenForward() final { - if (config()->use_tools) { - stack_.op_call(func_name()) - .op_input_arg() - .op_weight_arg("mean") - .op_weight_arg("var") - .op_weight_arg("gamma") - .op_weight_arg("beta") - .call_arg(DocUtils::ToAttrAccess(module_ref(), "training")) - .call_arg(DocUtils::ToAttrAccess(module_ref(), "momentum")) - .call_arg(DocUtils::ToAttrAccess(module_ref(), "eps")); - } else { - TorchOpCode::CodeGenForward(); - } - } -}; - -class TorchBroadcastToCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchBroadcastToCodeGen); - - protected: - void CodeGenForward() final { - stack_.assign(IdxNode(), IdxInput()).method_call("expand").op_list_arg("shape", ""); - } -}; - -class TorchClipCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchClipCodeGen); - - protected: - void CodeGenForward() final { - stack_.op_call().op_input_arg().op_arg("min").op_arg("max"); - } -}; - -class TorchConcatCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchConcatCodeGen); - - protected: - void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg("axis", "dim"); } -}; - -class TorchStackCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchStackCodeGen); - - protected: - void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg("axis", "dim"); } -}; - -class TorchConstantCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchConstantCodeGen); - - protected: - void CodeGenInit() final { - const auto& dtype = node()->OutputAt(0)->DTypeName(); - const auto& ref_name = StringUtils::Replace(node()->name, ".", "_"); - if (node()->HasAttr("scalar")) { - if (dtype == "int32") { - stack_.assign(module_ref(), node()->GetTypeAttr("scalar")); - } else if (dtype == "int64") { - stack_.assign(module_ref(), node()->GetTypeAttr("scalar")); - } else if (dtype == "float32") { - stack_.assign(module_ref(), node()->GetTypeAttr("scalar")); - } - } else if (dtype == "bool") { - stack_.func_call("register_buffer", "", "self") - .call_arg(DocUtils::ToStr(ref_name)) - .inplace_start("torch.BoolTensor") - .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape)) - .inplace_end(); - } else if (dtype == "int32") { - stack_.func_call("register_buffer", "", "self") - .call_arg(DocUtils::ToStr(ref_name)) - .inplace_start("torch.IntTensor") - .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape)) - .inplace_end(); - } else if (dtype == "int64") { - stack_.func_call("register_buffer", "", "self") - .call_arg(DocUtils::ToStr(ref_name)) - .inplace_start("torch.LongTensor") - .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape)) - .inplace_end(); - } else { - stack_.func_call("torch.Tensor", "data") - .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape)) - .op_call() - .call_arg("data"); - } - } - - void CodeGenForward() final { - if (config()->use_tools) { - stack_.assign(IdxNode(), IdxWeight("const", true)); - } else { - stack_.assign(IdxNode(), module_ref()); - } - } -}; - -class TorchConvCodeGen : public TorchOpCode { - public: - TorchConvCodeGen(const ffi::String& module_name, const ffi::String& func_name, bool use_bias) - : TorchOpCode(module_name, func_name), use_bias_(use_bias) {} - - protected: - void CodeGenInit() final { - const auto& weight = node()->WeightAt("weight"); - std::vector kernel_size; - for (size_t i = 0; i < weight->Ndim(); i++) { - if (weight->layout[i].name() == "I" || weight->layout[i].name() == "O") { - continue; - } - kernel_size.push_back(weight->DimAt(i)->value); - } - stack_.op_call() - .call_arg(weight->DimAt("I"), "in_channels") - .call_arg(weight->DimAt("O"), "out_channels") - .call_arg(DocUtils::ToList(kernel_size), "kernel_size") - .op_list_arg("strides", "stride") - .call_arg(GetPadding(), "padding") - .op_list_arg("dilation") - .op_arg("groups") - .call_arg(use_bias_, "bias"); - } - - void CodeGenForward() final { - if (config()->use_tools) { - stack_.op_call(func_name()).op_input_arg().op_weight_arg("weight"); - if (use_bias_) { - stack_.op_weight_arg("bias"); - } else { - stack_.call_arg("None"); - } - stack_.call_arg(DocUtils::ToAttrAccess(module_ref(), "stride")) - .call_arg(DocUtils::ToAttrAccess(module_ref(), "padding")) - .call_arg(DocUtils::ToAttrAccess(module_ref(), "dilation")) - .call_arg(DocUtils::ToAttrAccess(module_ref(), "groups")); - } else { - TorchOpCode::CodeGenForward(); - } - } - - private: - bool use_bias_; -}; - -class TorchCumsumCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchCumsumCodeGen); - - protected: - void CodeGenForward() final { - stack_.op_call() - .op_input_arg() - .op_arg("axis", "dim") - .op_dtype_arg(node()->OutputAt(0)->dtype, "dtype"); - } -}; - -class TorchEmbeddingCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchEmbeddingCodeGen); - - protected: - void CodeGenInit() final { - const auto& weight = node()->WeightAt("weight"); - stack_.op_call() - .call_arg(weight->DimAt(0), "num_embeddings") - .call_arg(weight->DimAt(1), "embedding_dim"); - } -}; - -class TorchExpandDimsCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchExpandDimsCodeGen); - - protected: - void CodeGenForward() final { - const auto& axes = node()->GetTypeArrayAttr("axis"); - ffi::String idx_input = IdxInput(); - for (size_t i = 0; i < axes.size(); i++) { - ffi::String idx_out = IdxNode(); - if (i < axes.size() - 1) { - idx_out = idx_out + "_" + std::to_string(i); - } - stack_.op_call().call_arg(idx_input).call_arg(axes[i], "dim"); - idx_input = idx_out; - } - } -}; - -class TorchFullCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchFullCodeGen); - - protected: - void CodeGenForward() final { - stack_.op_call() - .op_list_arg("shape", "size") - .op_input_arg(0, "fill_value") - .op_dtype_arg(node()->OutputAt(0)->dtype, "dtype"); - } -}; - -class TorchGetItemCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchGetItemCodeGen); - - protected: - void CodeGenForward() final { - stack_.assign(IdxNode(), IdxInput(node()->GetTypeAttr("index"))); - } -}; - -class TorchGroupNormCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchGroupNormCodeGen); - - protected: - void CodeGenInit() final { - TVM_FFI_ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) - << "Only support center and scale batchnorm, get " << node(); - int channel_axis = node()->GetTypeAttr("channel_axis"); - stack_.op_call() - .op_arg("num_groups") - .call_arg(node()->InputAt(0)->DimAt(channel_axis), "num_channels") - .op_arg("epsilon", "eps"); - } -}; - -class TorchLayerNormCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchLayerNormCodeGen); - - protected: - void CodeGenInit() final { - TVM_FFI_ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) - << "Only support center and scale batchnorm, get " << node(); - const auto& axes = - CommonUtils::GetIndices(node()->GetTypeArrayAttr("axes"), node()->InputAt(0)->Ndim()); - ffi::Array normalized_shape; - for (const auto& a : axes) { - normalized_shape.push_back(node()->InputAt(0)->DimAt(a)); - } - stack_.op_call() - .call_arg(DocUtils::ToList(normalized_shape), "normalized_shape") - .op_arg("epsilon", "eps"); - } -}; - -class TorchLinearCodeGen : public TorchOpCode { - public: - TorchLinearCodeGen(const ffi::String& module_name, const ffi::String& func_name, bool use_bias) - : TorchOpCode(module_name, func_name), use_bias_(use_bias) {} - - protected: - void CodeGenInit() final { - const auto& weight = node()->WeightAt("weight"); - stack_.op_call() - .call_arg(weight->DimAt("I"), "in_features") - .call_arg(weight->DimAt("O"), "out_features") - .call_arg(use_bias_, "bias"); - } - - void CodeGenForward() final { - if (config()->use_tools) { - stack_.op_call(func_name()).op_input_arg().op_weight_arg("weight"); - if (use_bias_) { - stack_.op_weight_arg("bias"); - } else { - stack_.call_arg("None"); - } - } else { - TorchOpCode::CodeGenForward(); - } - } - - private: - bool use_bias_; -}; - -class TorchNllLossCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchNllLossCodeGen); - - protected: - void CodeGenForward() final { - stack_.op_call().op_inputs_arg(false).op_str_arg("reduction").op_arg("ignore_index"); - } -}; - -class TorchPoolCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchPoolCodeGen); - - protected: - void CodeGenInit() final { - stack_.op_call() - .op_list_arg("pool_size", "kernel_size") - .op_list_arg("strides", "stride") - .call_arg(GetPadding(), "padding") - .op_arg("ceil_mode"); - if (node()->optype == "nn.max_pool2d") { - stack_.op_list_arg("dilation"); - } - } -}; - -class TorchPermuteDimsCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchPermuteDimsCodeGen) - - protected: - void CodeGenForward() final { - std::vector axes; - if (!node()->GetAttr("axes", &axes)) { - for (size_t i = node()->InputAt(0)->Ndim(); i > 0; i--) { - axes.push_back(i - 1); - } - } - stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(axes)); - } -}; - -class TorchReduceAxisCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchReduceAxisCodeGen); - - protected: - void CodeGenForward() final { - stack_.op_call().op_input_arg(); - int axis; - std::vector axes; - bool has_axis = false; - if (node()->GetAttr("axis", &axis)) { - has_axis = true; - } else if (node()->GetAttr("axis", &axes)) { - axis = axes[0]; - has_axis = true; - } - if (has_axis) { - stack_.call_arg(axis, "dim"); - } - stack_.op_arg("keepdims", "keepdim"); - } -}; - -class TorchReduceAxesCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchReduceAxesCodeGen); - - protected: - void CodeGenForward() final { - stack_.op_call().op_input_arg(); - std::vector axes; - bool has_axes = false; - if (node()->GetAttr("axis", &axes) || node()->GetAttr("axes", &axes)) { - has_axes = true; - } - if (has_axes) { - stack_.call_arg(DocUtils::ToList(axes), "dim").op_arg("keepdims", "keepdim"); - } - } -}; - -class TorchRepeatCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchRepeatCodeGen); - - protected: - void CodeGenForward() final { - int repeat = node()->GetTypeAttr("repeats"); - int axis = node()->GetTypeAttr("axis"); - std::vector repeats; - for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { - if (i == static_cast(axis)) { - repeats.push_back(repeat); - } else { - repeats.push_back(1); - } - } - stack_.assign(IdxNode(), IdxInput()) - .method_call("repeat") - .call_arg(DocUtils::ToList(repeats), ""); - } -}; - -class TorchReshapeCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchReshapeCodeGen); - - protected: - void CodeGenForward() final { - ffi::Array shape = node()->OutputAt(0)->shape; - const auto& out_layout = node()->OutputAt(0)->layout; - if (out_layout.defined()) { - int32_t batch_dim = out_layout.IndexOf(tvm::tir::LayoutAxis::Get("N")); - if (batch_dim >= 0) { - shape.Set(batch_dim, Integer(-1)); - } - } - stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(shape)); - } -}; - -class TorchResize2dCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchResize2dCodeGen); - - protected: - void CodeGenForward() final { - const auto& method = node()->GetTypeAttr("method"); - ffi::String v_method; - if (method == "nearest_neighbor") { - v_method = "nearest"; - } else { - TVM_FFI_THROW(InternalError) << "Unexpected resize2d method " << method; - } - stack_.op_call().op_input_arg().op_list_arg("size").call_arg(DocUtils::ToStr(v_method), - "mode"); - } -}; - -class TorchShapeCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchShapeCodeGen); - - protected: - void CodeGenForward() final { - if (node()->inputs.size() == 0) { - stack_.op_call().op_list_arg("shape", ""); - } else { - stack_.assign(IdxNode(), IdxInput()).method_call("size"); - } - } -}; - -class TorchSimpleCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchSimpleCodeGen); -}; - -class TorchScatterElementsCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchScatterElementsCodeGen) - - protected: - void CodeGenForward() final { - if (node()->InputAt(1)->DTypeName() == "int32") { - stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64"); - } - stack_.op_call() - .op_input_arg() - .op_arg("axis", "dim") - .op_input_arg(1, "index") - .op_input_arg(2, "src"); - } -}; - -class TorchScatterNDCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchScatterNDCodeGen) - - protected: - void CodeGenForward() final { - if (node()->InputAt(1)->DTypeName() == "int32") { - stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64"); - } - // relax add extra dim for indices - if (node()->InputAt(1)->Ndim() == node()->OutputAt(0)->Ndim()) { - stack_.func_call("squeeze", IdxInput(1), IdxInput(1)).call_arg(-1); - } - stack_.assign(DocUtils::ToIndex(IdxInput(0), IdxInput(1)), IdxInput(2)) - .assign(IdxNode(), IdxInput(0)); - } -}; - -class TorchSplitCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchSplitCodeGen) - - protected: - void CodeGenForward() final { - stack_.op_call().op_input_arg(); - std::vector indices; - int axis = node()->GetTypeAttr("axis"); - for (size_t i = 0; i < node()->outputs.size(); i++) { - indices.push_back(node()->OutputAt(i)->DimAt(axis)->value); - } - stack_.call_arg(DocUtils::ToList(indices), "split_size_or_sections").op_arg("axis", "dim"); - } -}; - -class TorchStridedSliceCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchStridedSliceCodeGen); - - protected: - void CodeGenForward() final { - const auto& begin = node()->GetTypeArrayAttr("begin"); - const auto& end = node()->GetTypeArrayAttr("end"); - std::vector strides; - if (!node()->GetAttr("strides", &strides)) { - strides = std::vector(begin.size(), 1); - } - const auto& axes = - CommonUtils::GetIndices(node()->GetTypeArrayAttr("axes"), node()->InputAt(0)->Ndim()); - std::unordered_map axes_map; - for (size_t i = 0; i < axes.size(); i++) { - axes_map[axes[i]] = i; - } - ffi::Array slice; - for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { - if (axes_map.count(i)) { - size_t idx = axes_map[i]; - slice.push_back(std::to_string(begin[idx]) + ":" + std::to_string(end[idx]) + ":" + - std::to_string(strides[idx])); - } else { - slice.push_back(":"); - } - } - stack_.assign(IdxNode(), DocUtils::ToIndices(IdxInput(), slice)); - } -}; - -class TorchTakeCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchTakeCodeGen) - - protected: - void CodeGenForward() final { - if (node()->InputAt(1)->DTypeName() == "int32") { - stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64"); - } - stack_.assign(IdxNode(), DocUtils::ToIndex(IdxInput(0), IdxInput(1))); - } -}; - -class TorchTriCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchTriCodeGen) - - protected: - void CodeGenForward() final { stack_.op_call().op_input_arg().op_arg("k", "diagonal"); } -}; - -class TorchTupleCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchTupleCodeGen) - - protected: - void CodeGenForward() final { stack_.op_call().op_inputs_arg(); } -}; - -class TorchPluginOpCodeGen : public TorchOpCode { - TORCH_OP_CODEGEN_METHODS(TorchPluginOpCodeGen) - - protected: - void CodeGenInit() final { - const auto& plugin = GetPlugin(node()->optype); - stack_.op_call("plugin." + node()->optype); - for (const auto& a : plugin->attrs) { - stack_.call_arg(GetAttrDoc(a->name, a->type), a->name); - } - } - - void CodeGenForward() final { stack_.op_call().op_inputs_arg(false); } -}; - -const std::shared_ptr>> -GetTorchOpCodes() { - static auto map = - std::make_shared>>(); - if (!map->empty()) return map; - - // simple ops - map->emplace("abs", std::make_shared("", "torch.abs")); - map->emplace("acos", std::make_shared("", "torch.acos")); - map->emplace("acosh", std::make_shared("", "torch.acosh")); - map->emplace("add", std::make_shared("", "torch.add")); - map->emplace("asin", std::make_shared("", "torch.asin")); - map->emplace("asinh", std::make_shared("", "torch.asinh")); - map->emplace("atan", std::make_shared("", "torch.atan")); - map->emplace("atanh", std::make_shared("", "torch.atanh")); - map->emplace("bitwise_and", std::make_shared("", "torch.bitwise_and")); - map->emplace("bitwise_not", std::make_shared("", "torch.bitwise_not")); - map->emplace("bitwise_or", std::make_shared("", "torch.bitwise_or")); - map->emplace("bitwise_xor", std::make_shared("", "torch.bitwise_xor")); - map->emplace("ceil", std::make_shared("", "torch.ceil")); - map->emplace("cos", std::make_shared("", "torch.cos")); - map->emplace("cosh", std::make_shared("", "torch.cosh")); - map->emplace("divide", std::make_shared("", "torch.divide")); - map->emplace("exp", std::make_shared("", "torch.exp")); - map->emplace("equal", std::make_shared("", "torch.equal")); - map->emplace("floor", std::make_shared("", "torch.floor")); - map->emplace("floor_divide", std::make_shared("", "torch.floor_divide")); - map->emplace("greater", std::make_shared("", "torch.greater")); - map->emplace("greater_equal", std::make_shared("", "torch.greater_equal")); - map->emplace("less", std::make_shared("", "torch.less")); - map->emplace("less_equal", std::make_shared("", "torch.less_equal")); - map->emplace("log", std::make_shared("", "torch.log")); - map->emplace("logical_and", std::make_shared("", "torch.logical_and")); - map->emplace("logical_or", std::make_shared("", "torch.logical_or")); - map->emplace("logical_xor", std::make_shared("", "torch.logical_xor")); - map->emplace("matmul", std::make_shared("", "torch.matmul")); - map->emplace("maximum", std::make_shared("", "torch.maximum")); - map->emplace("minimum", std::make_shared("", "torch.minimum")); - map->emplace("multiply", std::make_shared("", "torch.multiply")); - map->emplace("negative", std::make_shared("", "torch.negative")); - map->emplace("not_equal", std::make_shared("", "torch.not_equal")); - map->emplace("power", std::make_shared("", "torch.pow")); - map->emplace("round", std::make_shared("", "torch.round")); - map->emplace("rsqrt", std::make_shared("", "torch.rsqrt")); - map->emplace("sigmoid", std::make_shared("", "torch.sigmoid")); - map->emplace("sign", std::make_shared("", "torch.sign")); - map->emplace("sin", std::make_shared("", "torch.sin")); - map->emplace("sinh", std::make_shared("", "torch.sinh")); - map->emplace("square", std::make_shared("", "torch.square")); - map->emplace("sqrt", std::make_shared("", "torch.sqrt")); - map->emplace("subtract", std::make_shared("", "torch.subtract")); - map->emplace("tan", std::make_shared("", "torch.tan")); - map->emplace("tanh", std::make_shared("", "torch.tanh")); - map->emplace("where", std::make_shared("", "torch.where")); - - // reduce ops - map->emplace("max", std::make_shared("", "torch.max")); - map->emplace("min", std::make_shared("", "torch.min")); - map->emplace("mean", std::make_shared("", "torch.mean")); - map->emplace("sum", std::make_shared("", "torch.sum")); - map->emplace("argmax", std::make_shared("", "torch.argmax")); - map->emplace("argmin", std::make_shared("", "torch.argmin")); - map->emplace("prod", std::make_shared("", "torch.prod")); - map->emplace("std", std::make_shared("", "torch.std")); - - // axis && axes ops - map->emplace("nn.log_softmax", - std::make_shared("nn.LogSoftmax", "functional.log_softmax")); - map->emplace("nn.softmax", - std::make_shared("nn.Softmax", "functional.softmax")); - map->emplace("squeeze", std::make_shared("", "torch.squeeze")); - - // math ops - map->emplace("astype", std::make_shared("", "to")); - map->emplace("broadcast_to", std::make_shared("", "expand")); - map->emplace("clip", std::make_shared("", "torch.clamp")); - map->emplace("concat", std::make_shared("", "torch.cat")); - map->emplace("cumsum", std::make_shared("", "torch.cumsum")); - map->emplace("expand_dims", std::make_shared("", "torch.unsqueeze")); - map->emplace("permute_dims", std::make_shared("", "torch.permute")); - map->emplace("repeat", std::make_shared("", "repeat")); - map->emplace("reshape", std::make_shared("", "torch.reshape")); - map->emplace("scatter_elements", - std::make_shared("", "torch.scatter")); - map->emplace("scatter_nd", std::make_shared("", "")); - map->emplace("split", std::make_shared("", "torch.split")); - map->emplace("stack", std::make_shared("", "torch.stack")); - map->emplace("strided_slice", std::make_shared("", "")); - map->emplace("take", std::make_shared("", "")); - - // create ops - map->emplace("constant", std::make_shared("nn.Parameter", "")); - map->emplace("full", std::make_shared("", "torch.full")); - map->emplace("tril", std::make_shared("", "torch.tril")); - map->emplace("triu", std::make_shared("", "torch.triu")); - - // nn ops - map->emplace("nn.adaptive_avg_pool2d", - std::make_shared("nn.AdaptiveAvgPool2d", - "functional.adaptive_avg_pool2d")); - map->emplace("nn.avg_pool2d", - std::make_shared("nn.AvgPool2d", "functional.avg_pool2d")); - map->emplace("nn.batch_norm", - std::make_shared("nn.BatchNorm2d", "functional.batch_norm")); - map->emplace("nn.conv1d", - std::make_shared("nn.Conv1d", "functional.conv1d", false)); - map->emplace("nn.conv2d", - std::make_shared("nn.Conv2d", "functional.conv2d", false)); - map->emplace("nn.gelu", std::make_shared("nn.GELU", "functional.gelu")); - map->emplace("nn.group_norm", - std::make_shared("nn.GroupNorm", "functional.group_norm")); - map->emplace("nn.layer_norm", - std::make_shared("nn.LayerNorm", "functional.layer_norm")); - map->emplace("nn.linear", - std::make_shared("nn.Linear", "functional.linear", false)); - map->emplace("nn.max_pool2d", - std::make_shared("nn.MaxPool2d", "functional.max_pool2d")); - map->emplace("nn.nll_loss", std::make_shared("", "functional.nll_loss")); - map->emplace("nn.relu", std::make_shared("nn.ReLU", "functional.relu")); - map->emplace("nn.silu", std::make_shared("nn.SiLU", "functional.silu")); - - // image ops - map->emplace("image.resize2d", - std::make_shared("", "torch.nn.functional.interpolate")); - - // special op - map->emplace("get_item", std::make_shared("", "")); - map->emplace("shape", std::make_shared("", "torch.Size")); - map->emplace("tuple", std::make_shared("", "tuple")); - map->emplace("plugin", std::make_shared("Plugin", "")); - - // msc ops - map->emplace("msc.attention", std::make_shared( - "", "functional.scaled_dot_product_attention")); - map->emplace("msc.conv1d_bias", - std::make_shared("nn.Conv1d", "functional.conv1d", true)); - map->emplace("msc.conv2d_bias", - std::make_shared("nn.Conv2d", "functional.conv2d", true)); - map->emplace("msc.embedding", - std::make_shared("nn.Embedding", "functional.embedding")); - map->emplace("msc.gelu", std::make_shared("nn.GELU", "functional.gelu")); - map->emplace("msc.linear", - std::make_shared("nn.Linear", "functional.linear", false)); - map->emplace("msc.linear_bias", - std::make_shared("nn.Linear", "functional.linear", true)); - return map; -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/framework/torch/torch_opcode.h b/src/contrib/msc/framework/torch/torch_opcode.h deleted file mode 100644 index e732e502ce31..000000000000 --- a/src/contrib/msc/framework/torch/torch_opcode.h +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/torch/torch_opcode.h - * \brief Torch codegen for MSCJoint. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TORCH_TORCH_OPCODE_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TORCH_TORCH_OPCODE_H_ - -#include -#include -#include -#include - -#include "../../core/codegen/base_codegen.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -class TorchOpCode; -typedef OpCodeStack TorchOpCodeStack; - -/*! - * \brief CodeGen for torch op - */ -class TorchOpCode : public BaseOpCode { - public: - /*! - * \brief The constructor of BaseOpDocsifier - * \param func_name the function name for the node. - * \param config the config json for the node. - */ - explicit TorchOpCode(const ffi::String& module_name, const ffi::String& func_name) - : BaseOpCode(func_name) { - module_name_ = module_name; - } - - /*! \brief Config the TorchOpCode*/ - void Config(const MSCJoint& node, const std::shared_ptr config, bool is_init, - const ffi::Map& prims) { - BaseOpCode::Config(node, config, prims); - is_init_ = is_init; - module_ref_ = "self." + StringUtils::Replace(node->name, ".", "_"); - } - - /*! \brief Get return describe for default node*/ - const ffi::String IdxNode() final { - return is_init_ ? module_ref_ : BaseOpCode::IdxNode(); - }; - - /*! \brief Get dtype string*/ - const ffi::String DType(const DataType& dtype) final { - return "torch." + BaseOpCode::DType(dtype); - } - - /*! \brief Get func_name for the default node*/ - const ffi::String callee_name() final { - if (is_init_) { - return module_name_; - } - if (module_name_.size() > 0) { - return module_ref_; - } - return BaseOpCode::callee_name(); - } - - /*! \brief Convert node to docs*/ - const ffi::Array GetDocs() final; - - protected: - TorchOpCodeStack stack_; - - /*! \brief Convert op build*/ - virtual void CodeGenInit(); - - /*! \brief Convert op build*/ - virtual void CodeGenForward(); - - /*! \brief Get the padding from op*/ - const StrictListDoc GetPadding(const ffi::String& key = "padding"); - - /*! \brief Get the is_init_ of codegen*/ - bool is_init() { return is_init_; } - - /*! \brief Get the module_name of codegen*/ - const ffi::String module_name() { return module_name_; } - - /*! \brief Get the module_ref of codegen*/ - const ffi::String module_ref() { return module_ref_; } - - private: - bool is_init_; - ffi::String module_name_; - ffi::String module_ref_; -}; - -/*! - * \brief Get the map of available TorchOpCode, use optype as key - * \return Map of - */ -const std::shared_ptr>> -GetTorchOpCodes(); - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TORCH_TORCH_OPCODE_H_ diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc deleted file mode 100644 index ab571e51cc93..000000000000 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ /dev/null @@ -1,228 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tvm/codegen.cc - */ -#include "codegen.h" - -#include - -namespace tvm { -namespace contrib { -namespace msc { - -void RelaxCodeGen::CodeGenHeader() { - PyCodeGen::CodeGenHeader(); - stack_.line("from tvm import relax"); -} - -void RelaxCodeGen::CodeGenGraph() { - stack_.func_def(graph()->name, "tvm.IRModule"); - ffi::Array idx_inputs; - for (const auto& i : graph()->GetInputs()) { - const auto& pair = graph()->FindProducerAndIdx(i); - const auto& idx_input = IdxOutputBase(pair.first, pair.second); - stack_.func_arg(idx_input, "relax.Var"); - idx_inputs.push_back(idx_input); - } - if (config()->use_plugin) { - stack_.func_arg("plugin", "Any"); - } - stack_.func_start().assign("inputs", DocUtils::ToList(idx_inputs, true)); - // define weights - stack_.comment("Define the weights"); - for (const auto& n : graph()->node_names) { - const auto& node = graph()->FindNode(n); - for (const auto& pair : node->weights) { - const auto& idx_weight = IdxWeightBase(node, pair.first, false); - stack_.func_call("relax.Var", idx_weight) - .call_arg(DocUtils::ToStr(pair.second->name)) - .func_call("relax.TensorStructInfo") - .call_arg(DocUtils::ToList(pair.second->shape, true), "") - .call_arg(DocUtils::ToStr(pair.second->DTypeName())) - .pop_nest() - .func_call("append", "", "inputs") - .call_arg(idx_weight); - } - } - stack_.comment("Define the module"); - stack_.func_call("relax.BlockBuilder", "block_builder") - .scope_start("block_builder.function(name=\"" + graph()->name + "\", params=inputs.copy())"); - if (config()->use_tools) { - stack_.func_call("msc_tools.execute_step") - .call_arg(DocUtils::ToStr("before_build")) - .call_arg("block_builder"); - } - for (const auto& n : graph()->node_names) { - const auto& node = graph()->FindNode(n); - if (node->optype == "input") { - continue; - } - int scope_level = CompareScope(node); - if (scope_level == -1) { - stack_.scope_end(); - } - CodeGenNode(node, config()->use_tools); - } - if (scopes().size() > 1) { - // end left scopes - for (size_t i = 0; i < scopes().size() - 1; i++) { - stack_.scope_end(); - } - } - // mark outputs - stack_.comment("Emit the outputs"); - ffi::Array idx_exits; - - for (const auto& e : graph()->GetExits()) { - const auto& idx_exit = IdxNodeBase(e) + (config()->use_tools ? "_exit" : ""); - if (config()->use_tools) { - if (e->outputs.size() > 1) { - ffi::Array tuple_outputs; - for (size_t o_idx = 0; o_idx < e->outputs.size(); o_idx++) { - const auto& t_output = IdxOutputBase(e, o_idx, true); - tuple_outputs.push_back(t_output); - } - stack_.func_call("relax.Tuple", idx_exit).call_arg(DocUtils::ToList(tuple_outputs)); - stack_.func_call("emit", idx_exit, "block_builder").call_arg(idx_exit); - stack_.call_arg(DocUtils::ToStr(e->name + "_exit"), "name_hint"); - } - } - idx_exits.push_back(idx_exit); - } - - if (config()->use_tools) { - stack_.func_call("msc_tools.execute_step", "output").call_arg(DocUtils::ToStr("after_build")); - if (idx_exits.size() == 1) { - stack_.call_arg(idx_exits[0]); - } else { - stack_.call_arg(DocUtils::ToList(idx_exits)); - } - } - stack_.func_call("emit_func_output", "", "block_builder"); - if (config()->use_tools) { - stack_.call_arg("output"); - } else if (idx_exits.size() == 1) { - stack_.call_arg(idx_exits[0]); - } else { - stack_.call_arg(DocUtils::ToList(idx_exits)); - } - stack_.scope_end().func_call("finalize", "mod", "block_builder").func_end("mod"); -} - -void RelaxCodeGen::CodeGenInference() { - if (config()->use_plugin) { - stack_.comment("Import Plugin") - .line("from msc_plugin.tvm import PluginManager") - .line() - .func_call("PluginManager", "plugin"); - } - for (const auto& i : graph()->GetInputs()) { - const auto& producer = graph()->FindProducer(i); - stack_.func_call("relax.Var", IdxNodeBase(producer)) - .call_arg(DocUtils::ToStr(i->alias)) - .func_call("relax.TensorStructInfo") - .call_arg(DocUtils::ToList(i->shape)) - .call_arg(DocUtils::ToStr(i->DTypeName())) - .pop_nest(); - } - stack_.comment("Build Module").func_call(graph()->name, "mod"); - if (config()->use_plugin) { - stack_.call_arg("plugin"); - } - for (const auto& i : graph()->GetInputs()) { - const auto& producer = graph()->FindProducer(i); - stack_.call_arg(IdxNodeBase(producer)); - } - ffi::String target, device; - if (config()->test_device == "cpu") { - target = "llvm"; - device = "tvm.cpu()"; - } else if (config()->test_device == "gpu") { - target = "cuda"; - device = "tvm.cuda()"; - } - stack_.comment("Load weights") - .scope_start("open(\"" + graph()->name + "_params.bin\", \"rb\")", "f") - .func_call("tvm.runtime.load_param_dict", "params") - .inplace_start("read", "", "f") - .inplace_end() - .scope_end() - .func_call("tvm.relax.transform.BindParams", "bind_params") - .call_arg(DocUtils::ToStr("main")) - .call_arg("params") - .func_call("bind_params", "mod") - .call_arg("mod") - .func_call("tvm.target.Target", "target") - .call_arg(DocUtils::ToStr(target)) - .func_call("tvm.relax.transform.LegalizeOps()", "mod") - .call_arg("mod") - .scope_start("tvm.transform.PassContext(opt_level=3)") - .func_call("relax.build", "ex") - .call_arg("mod") - .call_arg("target") - .func_call("relax.VirtualMachine", "vm") - .call_arg("ex") - .call_arg(device) - .scope_end() - .assign("f_main", DocUtils::ToIndex("vm", DocUtils::ToStr("main"))) - .func_call("f_main", "outputs"); - for (const auto& i : graph()->GetInputs()) { - stack_.call_arg(DocUtils::ToIndex("inputs", DocUtils::ToStr(i->alias))); - } -} - -const ffi::String RelaxCodeGen::DescribePrim(const MSCPrim& prim) { - if (prim->optype == "shape") { - const auto& producer = graph()->FindNode(prim->GetTypeAttr("producer")); - int out_idx = prim->GetTypeAttr("out_idx"); - const auto& dim = prim->GetTypeAttr("dim"); - return IdxOutputBase(producer, out_idx) + ".struct_info.shape[" + dim + "]"; - } - return PyCodeGen::DescribePrim(prim); -} - -const ffi::Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { - const auto& ops_map = GetRelaxOpCodes(); - auto it = ops_map->find(GetOpType(node)); - TVM_FFI_ICHECK(it != ops_map->end()) << "Unsupported relax op(" << node->optype << "): " << node; - it->second->Config(node, config(), prims()); - try { - return it->second->GetDocs(); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to get docs for " << node << " : " << err.what(); - throw err; - } -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("msc.framework.tvm.GetRelaxSources", - [](const MSCGraph& graph, const ffi::String& codegen_config, - const ffi::String& print_config) -> ffi::Map { - RelaxCodeGen codegen = RelaxCodeGen(graph, codegen_config); - codegen.Init(); - return codegen.GetSources(print_config); - }); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/framework/tvm/codegen.h b/src/contrib/msc/framework/tvm/codegen.h deleted file mode 100644 index 0874e21acd4d..000000000000 --- a/src/contrib/msc/framework/tvm/codegen.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tvm/codegen.h - * \brief Relax codegen for MSCGraph. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TVM_CODEGEN_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TVM_CODEGEN_H_ - -#include - -#include "../../core/codegen/base_codegen.h" -#include "../../core/codegen/py_codegen.h" -#include "codegen_utils.h" -#include "relax_opcode.h" - -namespace tvm { -namespace contrib { -namespace msc { - -class RelaxCodeGen : public PyCodeGen { - public: - /*! - * \brief The constructor of RelaxCodeGen - * \param graph the graph to be generated. - * \param config the options for codegen. - */ - explicit RelaxCodeGen(const MSCGraph& graph, const std::string& config = "") - : PyCodeGen(graph, config) {} - - protected: - /*! \brief Stack the docs for the header*/ - void CodeGenHeader() final; - - /*! \brief Stack the docs for the graph*/ - void CodeGenGraph() final; - - /*! \brief Stack the docs for the graph inference*/ - void CodeGenInference() final; - - /*! \brief Describe the prim*/ - const ffi::String DescribePrim(const MSCPrim& prim) final; - - /*! \brief Get the docs for the op*/ - const ffi::Array GetOpCodes(const MSCJoint& node) final; - - /*! \brief Get tensor type of the framework*/ - const ffi::String TensorType() const final { return "relax.Expr"; } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm - -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TVM_CODEGEN_H_ diff --git a/src/contrib/msc/framework/tvm/codegen_utils.h b/src/contrib/msc/framework/tvm/codegen_utils.h deleted file mode 100644 index 9e822cc97c3f..000000000000 --- a/src/contrib/msc/framework/tvm/codegen_utils.h +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tvm/codegen_utils.h - * \brief Utils for TVM codegen. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TVM_CODEGEN_UTILS_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TVM_CODEGEN_UTILS_H_ - -#include - -#include "../../core/codegen/base_codegen.h" -#include "../../core/codegen/codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief CodeGen helper for relax codegen - */ -class RelaxCodeGenHelper : public BaseCodeGenHelper {}; - -/*! - * \brief CodeGen config for tvm codegen - */ -struct RelaxCodeGenConfig { - bool explicit_name{true}; - bool from_relay{false}; - CODEGEN_CONFIG_MEMBERS - void Load(ffi::json::Object obj) { - if (auto it = obj.find(ffi::String("explicit_name")); it != obj.end()) { - explicit_name = (*it).second.cast(); - } - if (auto it = obj.find(ffi::String("from_relay")); it != obj.end()) { - from_relay = (*it).second.cast(); - } - CODEGEN_CONFIG_PARSE - } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TVM_CODEGEN_UTILS_H_ diff --git a/src/contrib/msc/framework/tvm/config.h b/src/contrib/msc/framework/tvm/config.h deleted file mode 100644 index 3426730f68dd..000000000000 --- a/src/contrib/msc/framework/tvm/config.h +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tvm/config.h - * \brief Relax config for codegen. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TVM_CONFIG_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TVM_CONFIG_H_ - -#include - -#include "../../core/codegen/base_codegen.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief CodeGen config for tvm codegen - */ -struct RelaxCodeGenConfig { - bool explicit_name{true}; - bool from_relay{false}; - CODEGEN_CONFIG_MEMBERS - void Load(ffi::json::Object obj) { - if (auto it = obj.find(ffi::String("explicit_name")); it != obj.end()) { - explicit_name = (*it).second.cast(); - } - if (auto it = obj.find(ffi::String("from_relay")); it != obj.end()) { - from_relay = (*it).second.cast(); - } - CODEGEN_CONFIG_PARSE - } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TVM_CONFIG_H_ diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc deleted file mode 100644 index 846b09da8329..000000000000 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ /dev/null @@ -1,881 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tvm/relax_opcode.cc - */ -#include "relax_opcode.h" - -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -const ffi::Array RelaxOpCode::GetDocs() { - stack_.Config(this); - CodeGenBuild(); - bool emit_var = true; - if (node()->optype == "input" || node()->optype == "constant" || node()->optype == "shape") { - emit_var = false; - } - if (emit_var) { - const auto& name = config()->explicit_name ? node()->name : ""; - BuilderEmit(IdxNode(), name); - } - return stack_.GetDocs(); -} - -void RelaxOpCode::BuilderEmit(const ffi::String& ret, const ffi::String& name) { - stack_.func_call("block_builder.emit", ret).call_arg(ret); - if (name.size() > 0) { - stack_.call_arg(DocUtils::ToStr(name), "name_hint"); - } -} - -const ExprDoc RelaxOpCode::GetOutDtype(const ffi::String& key, int input_idx) { - if (config()->use_tools && input_idx >= 0 && - node()->inputs.size() > static_cast(input_idx)) { - return DocUtils::ToDoc(IdxInput(input_idx) + ".struct_info.dtype"); - } - std::string out_dtype; - if (!node()->GetAttr(key, &out_dtype) && config()->from_relay) { - return DocUtils::ToStr(node()->OutputAt(0)->DTypeName()); - } - return DocUtils::ToStr(out_dtype); -} - -const std::vector RelaxOpCode::GetAxes(const ffi::String& key) { - std::vector axes; - int axis; - if (!node()->GetAttr(key, &axes) && node()->GetAttr(key, &axis)) { - axes.push_back(axis); - } - return axes; -} - -#define RELAX_OP_CODEGEN_METHODS(TypeName) \ - public: \ - TypeName(const ffi::String& func_name) : RelaxOpCode(func_name) {} - -class RelaxAdaptivePool2dCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxAdaptivePool2dCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .op_list_arg("output_size") - .op_str_arg("layout") - .op_str_arg("out_layout"); - } -}; - -class RelaxAstypeCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxAstypeCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_input_arg().op_str_arg("dtype"); } -}; - -class RelaxAttentionCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxAttentionCodeGen) - - protected: - void CodeGenBuild() final { - for (size_t i = 0; i < 3; i++) { - const ffi::String& axes_key = i == 0 ? "axes" : "axes_" + std::to_string(i); - stack_.op_call("relax.op.permute_dims", IdxInput(i)) - .op_input_arg(i) - .op_list_arg(axes_key, "axes"); - } - stack_.op_call().op_inputs_arg(false).op_arg("scale").op_str_arg("causal_mask"); - stack_.op_call("relax.op.permute_dims").op_output_arg().op_list_arg("axes_3", "axes"); - } -}; - -class RelaxAxisCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxAxisCodeGen) - - protected: - void CodeGenBuild() final { - std::vector axes = GetAxes("axis"); - stack_.op_call().op_input_arg(); - if (axes.size() > 0) { - stack_.call_arg(axes[0], "axis"); - } - } -}; - -class RelaxAxesCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxAxesCodeGen) - - protected: - void CodeGenBuild() final { - const ffi::String& key = node()->HasAttr("axes") ? "axes" : "axis"; - stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(GetAxes(key)), key); - } -}; - -class RelaxBatchMatmulCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxBatchMatmulCodeGen) - - protected: - void CodeGenBuild() final { - bool transpose_a = node()->GetTypeAttr("transpose_a"); - bool transpose_b = node()->GetTypeAttr("transpose_b"); - if (!transpose_a && !transpose_b) { - stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); - } else if (transpose_a && !transpose_b) { - std::vector axes; - for (size_t i = 0; i < node()->InputAt(0)->Ndim() - 2; i++) { - axes.push_back(i); - } - axes.push_back(node()->InputAt(0)->Ndim() - 1); - axes.push_back(node()->InputAt(0)->Ndim() - 2); - stack_.op_call("relax.op.permute_dims", IdxInput(0)) - .op_input_arg() - .call_arg(DocUtils::ToList(axes)); - BuilderEmit(IdxInput(0)); - stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); - } else if (!transpose_a && transpose_b) { - std::vector axes; - for (size_t i = 0; i < node()->InputAt(1)->Ndim() - 2; i++) { - axes.push_back(i); - } - axes.push_back(node()->InputAt(1)->Ndim() - 1); - axes.push_back(node()->InputAt(1)->Ndim() - 2); - stack_.op_call("relax.op.permute_dims", IdxInput(1)) - .op_input_arg(1) - .call_arg(DocUtils::ToList(axes)); - BuilderEmit(IdxInput(1)); - stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); - } else { - for (size_t idx = 0; idx < 2; idx++) { - std::vector axes; - for (size_t i = 0; i < node()->InputAt(idx)->Ndim() - 2; i++) { - axes.push_back(i); - } - axes.push_back(node()->InputAt(idx)->Ndim() - 1); - axes.push_back(node()->InputAt(idx)->Ndim() - 2); - stack_.op_call("relax.op.permute_dims", IdxInput(idx)) - .op_input_arg(idx) - .call_arg(DocUtils::ToList(axes)); - BuilderEmit(IdxInput(idx)); - } - stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); - } - } -}; - -class RelaxBatchNormCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxBatchNormCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .op_weight_arg("gamma") - .op_weight_arg("beta") - .op_weight_arg("mean") - .op_weight_arg("var") - .op_arg("axis") - .op_arg("epsilon") - .op_arg("center") - .op_arg("scale") - .op_arg("momentum"); - } -}; - -class RelaxBiasAddCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxBiasAddCodeGen) - - protected: - void CodeGenBuild() final { - int axis = CommonUtils::GetIndex(node()->GetTypeAttr("axis"), node()->OutputAt(0)->Ndim()); - ffi::Array expand_shape; - for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { - if (i == static_cast(axis)) { - expand_shape.push_back(node()->InputAt(0)->DimAt(i)); - } else { - expand_shape.push_back(Integer(1)); - } - } - stack_.op_call("relax.op.reshape", IdxInput(1)) - .op_input_arg(1) - .call_arg(DocUtils::ToList(expand_shape), "shape"); - BuilderEmit(IdxInput(1)); - stack_.op_call().op_inputs_arg(false); - } -}; - -class RelaxBroadcastToCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxBroadcastToCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_input_arg().op_list_arg("shape"); } -}; - -class RelaxClipCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxClipCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg(); - if (config()->from_relay) { - stack_.op_arg("a_min", "min").op_arg("a_max", "max"); - } else { - stack_.op_arg("min").op_arg("max"); - } - } -}; - -class RelaxConcatCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxConcatCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_inputs_arg().op_arg("axis"); } -}; - -class RelaxConstantCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxConstantCodeGen) - - protected: - void CodeGenBuild() final { stack_.assign(IdxNode(), IdxWeight("const")); } -}; - -class RelaxConvCodeGen : public RelaxOpCode { - public: - RelaxConvCodeGen(const ffi::String& func_name, bool use_bias) - : RelaxOpCode(func_name), use_bias_(use_bias) {} - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .op_weight_arg("weight") - .op_list_arg("strides") - .op_list_arg("padding") - .op_list_arg("dilation") - .op_arg("groups") - .op_str_arg("data_layout") - .op_str_arg("kernel_layout") - .op_str_arg("out_layout") - .call_arg(GetOutDtype(), "out_dtype"); - if (use_bias_) { - std::string out_layout_str; - if (!node()->GetAttr("out_layout", &out_layout_str)) { - TVM_FFI_ICHECK(node()->GetAttr("data_layout", &out_layout_str)) - << "out_layout or data_layout should be given, get " << node(); - } - const auto& out_layout = tir::Layout(out_layout_str); - ffi::Array expand_shape; - for (size_t i = 0; i < node()->OutputAt(0)->Ndim(); i++) { - if (out_layout[i].name() == "C") { - expand_shape.push_back(node()->OutputAt(0)->DimAt(i)); - } else { - expand_shape.push_back(Integer(1)); - } - } - BuilderEmit(IdxNode()); - stack_.func_call("relax.op.reshape", "expand_bias") - .op_weight_arg("bias") - .call_arg(DocUtils::ToList(expand_shape), "shape"); - BuilderEmit("expand_bias"); - stack_.func_call("relax.op.add", IdxNode()).call_arg(IdxNode()).call_arg("expand_bias"); - } - } - - private: - bool use_bias_; -}; - -class RelaxCreateCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxCreateCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_list_arg("shape").op_str_arg("dtype"); } -}; - -class RelaxCreateLikeCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxCreateLikeCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_input_arg().op_str_arg("dtype"); } -}; - -class RelaxCumsumCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxCumsumCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().op_arg("axis").op_str_arg("dtype"); - } -}; - -class RelaxEinsumCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxEinsumCodeGen) - - protected: - void CodeGenBuild() final { - const ffi::String& key = config()->from_relay ? "equation" : "subscripts"; - stack_.op_call().op_inputs_arg().op_str_arg(key, "subscripts"); - } -}; - -class RelaxStridedSliceCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxStridedSliceCodeGen) - - protected: - void CodeGenBuild() final { - std::vector axes; - if (!node()->GetAttr("axes", &axes)) { - for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { - axes.push_back(i); - } - } - stack_.op_call() - .op_input_arg() - .call_arg(DocUtils::ToList(axes), "axes") - .op_list_arg("begin") - .op_list_arg("end") - .op_list_arg("strides"); - } -}; - -class RelaxEmbeddingCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxEmbeddingCodeGen) - - protected: - void CodeGenBuild() final { - const auto& input = node()->InputAt(0); - if (input->DTypeName() != "int32") { - stack_.op_call("relax.op.astype", IdxInput()) - .op_input_arg() - .call_arg(DocUtils::ToStr("int32")); - BuilderEmit(IdxInput()); - } - if (input->Ndim() > 1) { - stack_.op_call("relax.op.reshape", IdxInput()) - .op_input_arg() - .call_arg(DocUtils::ToList(std::vector{-1}), "shape"); - BuilderEmit(IdxInput()); - } - stack_.op_call().op_weight_arg("weight").op_input_arg().op_arg("axis"); - if (input->Ndim() > 1) { - BuilderEmit(IdxNode()); - stack_.op_call("relax.op.reshape") - .op_output_arg() - .call_arg(DocUtils::ToList(node()->OutputAt(0)->shape)); - } - } -}; - -class RelaxFullCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxFullCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_list_arg("shape").op_input_arg(0, "fill_value").op_str_arg("dtype"); - } -}; - -class RelaxGetItemCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxGetItemCodeGen) - - protected: - void CodeGenBuild() final { - const auto& producer = node()->ProducerOf(0); - stack_.op_call("msc::auto", IdxNode()).call_arg(IdxNodeBase(producer)).op_arg("index"); - } -}; - -class RelaxGroupNormCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxGroupNormCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().op_weight_arg("gamma").op_weight_arg("beta").op_arg( - "num_groups"); - if (config()->from_relay) { - std::vector axes; - for (size_t i = 2; i < node()->InputAt(0)->Ndim(); i++) { - axes.push_back(i); - } - stack_.op_arg("axis", "channel_axis").call_arg(DocUtils::ToList(axes), "axes"); - } else { - stack_.op_arg("channel_axis").op_list_arg("axes"); - } - stack_.op_arg("epsilon").op_arg("center").op_arg("scale"); - } -}; - -class RelaxLayerNormCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxLayerNormCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().op_weight_arg("gamma").op_weight_arg("beta"); - if (config()->from_relay) { - stack_.op_arg("axis", "axes"); - } else { - stack_.op_list_arg("axes"); - } - stack_.op_arg("epsilon").op_arg("center").op_arg("scale"); - } -}; - -class RelaxLinearCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxLinearCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call(); - if (node()->inputs.size() == 1) { - stack_.op_input_arg().op_weight_arg("weight").op_weight_arg("bias"); - } else { - stack_.op_inputs_arg(false); - } - stack_.call_arg(GetOutDtype(), "out_dtype"); - } -}; - -class RelaxMatmulCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxMatmulCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_inputs_arg(false).call_arg(GetOutDtype(), "out_dtype"); - } -}; - -class RelaxNllLossCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxNllLossCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_inputs_arg(false).op_str_arg("reduction").op_arg("ignore_index"); - } -}; - -class RelaxPadCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxPadCodeGen) - - protected: - void CodeGenBuild() final { - ffi::Array pad_width; - const auto& attr_pad_width = node()->GetTypeArrayAttr("pad_width"); - TVM_FFI_ICHECK(attr_pad_width.size() % 2 == 0) - << "pad_width should be multiple of 2, get " << node(); - for (size_t i = 0; i < attr_pad_width.size(); i += 2) { - const ffi::String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + - std::to_string(attr_pad_width[i + 1]) + "]"; - pad_width.push_back(cur_pad); - } - stack_.op_call() - .op_input_arg() - .op_list_arg("pad_width") - .op_input_arg(1, "pad_value") - .op_str_arg("pad_mode"); - } -}; - -class RelaxPool2dCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxPool2dCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call() - .op_input_arg() - .op_list_arg("pool_size") - .op_list_arg("strides") - .op_list_arg("padding") - .op_list_arg("dilation") - .op_arg("count_include_pad") - .op_arg("ceil_mode") - .op_str_arg("layout") - .op_str_arg("out_layout"); - } -}; - -class RelaxPermuteDimsCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxPermuteDimsCodeGen) - - protected: - void CodeGenBuild() final { - std::vector axes; - if (!node()->GetAttr("axes", &axes)) { - for (size_t i = node()->InputAt(0)->Ndim(); i > 0; i--) { - axes.push_back(i - 1); - } - } - stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(axes), "axes"); - } -}; - -class RelaxReduceAxisCodeGen : public RelaxOpCode { - public: - RelaxReduceAxisCodeGen(const ffi::String& func_name, bool as_list) - : RelaxOpCode(func_name), as_list_(as_list) {} - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg(); - std::vector axes = GetAxes("axis"); - if (as_list_) { - stack_.call_arg(DocUtils::ToList(axes), "axis"); - } else if (axes.size() > 0) { - stack_.call_arg(axes[0], "axis"); - } - stack_.op_arg("keepdims"); - } - - private: - bool as_list_; -}; - -class RelaxRepeatCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxRepeatCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg().op_arg("repeats").op_arg("axis"); - } -}; - -class RelaxReshapeCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxReshapeCodeGen) - - protected: - void CodeGenBuild() final { - const auto& out_shape = GetPrims(node()->OutputAt(0)); - stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(out_shape), "shape"); - } -}; - -class RelaxScatterElementsCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxScatterElementsCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false).op_arg("axis"); } -}; - -class RelaxScatterNDCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxScatterNDCodeGen) - - protected: - void CodeGenBuild() final { - if (config()->from_relay) { - size_t ndim = node()->InputAt(1)->Ndim(); - std::vector axes; - axes.push_back(ndim - 1); - for (size_t i = 0; i < ndim - 1; i++) { - axes.push_back(i); - } - stack_.func_call("relax.op.permute_dims", IdxInput(1)) - .call_arg(IdxInput(1)) - .call_arg(DocUtils::ToList(axes)); - BuilderEmit(IdxInput(1), "permute_" + std::to_string(node()->index)); - } - stack_.op_call().op_inputs_arg(false).op_str_arg("mode", "reduction"); - } -}; - -class RelaxResize2dCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxResize2dCodeGen) - - protected: - void CodeGenBuild() final { - // roi has forced to be float list - ffi::Array roi_list; - std::vector roi = node()->GetTypeArrayAttr("roi"); - for (const auto& r : roi) { - roi_list.push_back("float(" + std::to_string(r) + ")"); - } - stack_.op_call() - .op_input_arg() - .func_call("relax.ShapeExpr") - .op_list_arg("size", "values") - .pop_nest() - .call_arg(DocUtils::ToList(roi_list)) - .op_str_arg("layout") - .op_str_arg("method") - .op_str_arg("coordinate_transformation_mode") - .op_str_arg("rounding_method") - .op_arg("cubic_alpha") - .op_arg("cubic_exclude") - .op_arg("extrapolation_value") - .op_str_arg("out_dtype"); - } -}; - -class RelaxShapeCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxShapeCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_list_arg("shape", "values"); } -}; - -class RelaxSimpleCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxSimpleCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false); } -}; - -class RelaxSplitCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxSplitCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_input_arg(); - int sections; - if (node()->GetAttr("indices_or_sections", §ions)) { - stack_.op_arg("indices_or_sections"); - } else { - stack_.op_list_arg("indices_or_sections"); - } - stack_.op_arg("axis"); - } -}; - -class RelaxStackCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxStackCodeGen) - - protected: - void CodeGenBuild() final { - stack_.op_call().op_inputs_arg().op_arg("axis"); - BuilderEmit(IdxNode(), "cat_" + std::to_string(node()->index)); - const auto& out_shape = GetPrims(node()->OutputAt(0)); - stack_.func_call("relax.op.reshape", IdxNode()) - .call_arg(IdxNode()) - .call_arg(DocUtils::ToList(out_shape), "shape"); - } -}; - -class RelaxTakeCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxTakeCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false).op_arg("axis"); } -}; - -class RelaxTileCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxTileCodeGen) - - protected: - void CodeGenBuild() final { - const ffi::String& key = config()->from_relay ? "reps" : "repeats"; - stack_.op_call().op_input_arg().op_list_arg(key, "repeats"); - } -}; - -class RelaxTupleCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxTupleCodeGen) - - protected: - void CodeGenBuild() final { stack_.op_call().op_inputs_arg(); } -}; - -class RelaxTriCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxTriCodeGen) - - protected: - void CodeGenBuild() final { - if (node()->optype == "trilu") { - const ffi::String& func_name = - node()->GetTypeAttr("upper") ? "relax.op.triu" : "relax.op.tril"; - stack_.op_call(func_name).op_input_arg().op_arg("k"); - } else { - stack_.op_call().op_input_arg().op_arg("k"); - } - } -}; - -class RelaxPluginOpCodeGen : public RelaxOpCode { - RELAX_OP_CODEGEN_METHODS(RelaxPluginOpCodeGen) - - protected: - void CodeGenBuild() final { - const auto& plugin = GetPlugin(node()->optype); - stack_.op_call("plugin." + node()->optype).op_inputs_arg(false); - for (const auto& a : plugin->attrs) { - stack_.call_arg(GetAttrDoc(a->name, a->type), a->name); - } - } -}; - -const std::shared_ptr>> -GetRelaxOpCodes() { - static auto map = - std::make_shared>>(); - if (!map->empty()) return map; - // binary && unary ops - map->emplace("abs", std::make_shared("relax.op.abs")); - map->emplace("acos", std::make_shared("relax.op.acos")); - map->emplace("acosh", std::make_shared("relax.op.acosh")); - map->emplace("add", std::make_shared("relax.op.add")); - map->emplace("asin", std::make_shared("relax.op.asin")); - map->emplace("asinh", std::make_shared("relax.op.asinh")); - map->emplace("atan", std::make_shared("relax.op.atan")); - map->emplace("atanh", std::make_shared("relax.op.atanh")); - map->emplace("bitwise_and", std::make_shared("relax.op.bitwise_and")); - map->emplace("bitwise_not", std::make_shared("relax.op.bitwise_not")); - map->emplace("bitwise_or", std::make_shared("relax.op.bitwise_or")); - map->emplace("bitwise_xor", std::make_shared("relax.op.bitwise_xor")); - map->emplace("ceil", std::make_shared("relax.op.ceil")); - map->emplace("cos", std::make_shared("relax.op.cos")); - map->emplace("cosh", std::make_shared("relax.op.cosh")); - map->emplace("divide", std::make_shared("relax.op.divide")); - map->emplace("equal", std::make_shared("relax.op.equal")); - map->emplace("erf", std::make_shared("relax.op.erf")); - map->emplace("exp", std::make_shared("relax.op.exp")); - map->emplace("floor", std::make_shared("relax.op.floor")); - map->emplace("floor_divide", std::make_shared("relax.op.floor_divide")); - map->emplace("greater", std::make_shared("relax.op.greater")); - map->emplace("greater_equal", std::make_shared("relax.op.greater_equal")); - map->emplace("isfinite", std::make_shared("relax.op.isfinite")); - map->emplace("isinf", std::make_shared("relax.op.isinf")); - map->emplace("isnan", std::make_shared("relax.op.isnan")); - map->emplace("less", std::make_shared("relax.op.less")); - map->emplace("less_equal", std::make_shared("relax.op.less_equal")); - map->emplace("log", std::make_shared("relax.op.log")); - map->emplace("logical_and", std::make_shared("relax.op.logical_and")); - map->emplace("logical_or", std::make_shared("relax.op.logical_or")); - map->emplace("logical_xor", std::make_shared("relax.op.logical_xor")); - map->emplace("logical_not", std::make_shared("relax.op.logical_not")); - map->emplace("maximum", std::make_shared("relax.op.maximum")); - map->emplace("minimum", std::make_shared("relax.op.minimum")); - map->emplace("multiply", std::make_shared("relax.op.multiply")); - map->emplace("negative", std::make_shared("relax.op.negative")); - map->emplace("not_equal", std::make_shared("relax.op.not_equal")); - map->emplace("power", std::make_shared("relax.op.power")); - map->emplace("round", std::make_shared("relax.op.round")); - map->emplace("rsqrt", std::make_shared("relax.op.rsqrt")); - map->emplace("sigmoid", std::make_shared("relax.op.sigmoid")); - map->emplace("sign", std::make_shared("relax.op.sign")); - map->emplace("sin", std::make_shared("relax.op.sin")); - map->emplace("sinh", std::make_shared("relax.op.sinh")); - map->emplace("square", std::make_shared("relax.op.square")); - map->emplace("sqrt", std::make_shared("relax.op.sqrt")); - map->emplace("subtract", std::make_shared("relax.op.subtract")); - map->emplace("tan", std::make_shared("relax.op.tan")); - map->emplace("tanh", std::make_shared("relax.op.tanh")); - map->emplace("where", std::make_shared("relax.op.where")); - - // reduce axis ops - map->emplace("argmax", std::make_shared("relax.op.argmax", false)); - map->emplace("argmin", std::make_shared("relax.op.argmin", false)); - map->emplace("max", std::make_shared("relax.op.max", true)); - map->emplace("min", std::make_shared("relax.op.min", true)); - map->emplace("mean", std::make_shared("relax.op.mean", true)); - map->emplace("sum", std::make_shared("relax.op.sum", true)); - map->emplace("prod", std::make_shared("relax.op.prod", true)); - map->emplace("std", std::make_shared("relax.op.std", true)); - - // axis && axes ops - map->emplace("nn.log_softmax", std::make_shared("relax.op.nn.log_softmax")); - map->emplace("nn.softmax", std::make_shared("relax.op.nn.softmax")); - map->emplace("expand_dims", std::make_shared("relax.op.expand_dims")); - map->emplace("squeeze", std::make_shared("relax.op.squeeze")); - - // math ops - map->emplace("astype", std::make_shared("relax.op.astype")); - map->emplace("broadcast_to", std::make_shared("relax.op.broadcast_to")); - map->emplace("cast", std::make_shared("relax.op.astype")); - map->emplace("clip", std::make_shared("relax.op.clip")); - map->emplace("concat", std::make_shared("relax.op.concat")); - map->emplace("concatenate", std::make_shared("relax.op.concat")); - map->emplace("cumsum", std::make_shared("relax.op.cumsum")); - map->emplace("einsum", std::make_shared("relax.op.einsum")); - map->emplace("matmul", std::make_shared("relax.op.linear_algebra.matmul")); - map->emplace("permute_dims", std::make_shared("relax.op.permute_dims")); - map->emplace("repeat", std::make_shared("relax.op.repeat")); - map->emplace("reshape", std::make_shared("relax.op.reshape")); - map->emplace("scatter_elements", - std::make_shared("relax.op.scatter_elements")); - map->emplace("scatter_nd", std::make_shared("relax.op.scatter_nd")); - map->emplace("split", std::make_shared("relax.op.split")); - map->emplace("stack", std::make_shared("relax.op.concat")); - map->emplace("strided_slice", - std::make_shared("relax.op.strided_slice")); - map->emplace("take", std::make_shared("relax.op.take")); - map->emplace("tile", std::make_shared("relax.op.tile")); - map->emplace("transpose", std::make_shared("relax.op.permute_dims")); - - // create ops - map->emplace("constant", std::make_shared("relax.Var")); - map->emplace("full", std::make_shared("relax.op.full")); - map->emplace("ones", std::make_shared("relax.op.ones")); - map->emplace("ones_like", std::make_shared("relax.op.ones_like")); - map->emplace("tril", std::make_shared("relax.op.tril")); - map->emplace("triu", std::make_shared("relax.op.triu")); - map->emplace("trilu", std::make_shared("")); - map->emplace("zeros", std::make_shared("relax.op.zeros")); - map->emplace("zeros_like", std::make_shared("relax.op.zeros_like")); - - // nn ops - map->emplace("nn.adaptive_avg_pool2d", - std::make_shared("relax.op.nn.adaptive_avg_pool2d")); - map->emplace("nn.avg_pool2d", std::make_shared("relax.op.nn.avg_pool2d")); - map->emplace("nn.batch_matmul", - std::make_shared("relax.op.linear_algebra.matmul")); - map->emplace("nn.batch_norm", std::make_shared("relax.op.nn.batch_norm")); - map->emplace("nn.bias_add", std::make_shared("relax.op.add")); - map->emplace("nn.conv1d", std::make_shared("relax.op.nn.conv1d", false)); - map->emplace("nn.conv2d", std::make_shared("relax.op.nn.conv2d", false)); - map->emplace("nn.dense", std::make_shared("relax.op.linear_algebra.linear")); - map->emplace("nn.gelu", std::make_shared("relax.op.nn.gelu")); - map->emplace("nn.group_norm", std::make_shared("relax.op.nn.group_norm")); - map->emplace("nn.layer_norm", std::make_shared("relax.op.nn.layer_norm")); - map->emplace("nn.max_pool2d", std::make_shared("relax.op.nn.max_pool2d")); - map->emplace("nn.nll_loss", std::make_shared("relax.op.nn.nll_loss")); - map->emplace("nn.pad", std::make_shared("relax.op.nn.pad")); - map->emplace("nn.relu", std::make_shared("relax.op.nn.relu")); - map->emplace("nn.silu", std::make_shared("relax.op.nn.silu")); - - // image ops - map->emplace("image.resize2d", std::make_shared("relax.op.image.resize2d")); - - // special op - map->emplace("get_item", std::make_shared("relax.TupleGetItem")); - map->emplace("shape", std::make_shared("relax.ShapeExpr")); - map->emplace("tuple", std::make_shared("relax.Tuple")); - map->emplace("plugin", std::make_shared("Plugin")); - - // msc ops - map->emplace("msc.attention", std::make_shared("relax.op.nn.attention")); - map->emplace("msc.conv1d_bias", std::make_shared("relax.op.nn.conv1d", true)); - map->emplace("msc.conv2d_bias", std::make_shared("relax.op.nn.conv2d", true)); - map->emplace("msc.embedding", std::make_shared("relax.op.take")); - map->emplace("msc.gelu", std::make_shared("relax.op.nn.gelu")); - map->emplace("msc.linear", - std::make_shared("relax.op.linear_algebra.linear")); - map->emplace("msc.linear_bias", - std::make_shared("relax.op.linear_algebra.linear")); - map->emplace("msc.matmul", - std::make_shared("relax.op.linear_algebra.matmul")); - - return map; -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/framework/tvm/relax_opcode.h b/src/contrib/msc/framework/tvm/relax_opcode.h deleted file mode 100644 index bbbee44d822d..000000000000 --- a/src/contrib/msc/framework/tvm/relax_opcode.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/framework/tvm/relax_opcode.h - * \brief Relax codegen for MSCJoint. - */ -#ifndef TVM_CONTRIB_MSC_FRAMEWORK_TVM_RELAX_OPCODE_H_ -#define TVM_CONTRIB_MSC_FRAMEWORK_TVM_RELAX_OPCODE_H_ - -#include -#include -#include -#include - -#include "../../core/codegen/base_codegen.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -class RelaxOpCode; -typedef OpCodeStack RelaxOpCodeStack; - -/*! - * \brief CodeGen for relax op - */ -class RelaxOpCode : public BaseOpCode { - public: - /*! - * \brief The constructor of BaseOpDocsifier - * \param func_name the function name for the node. - * \param config the config json for the node. - */ - explicit RelaxOpCode(const ffi::String& func_name) - : BaseOpCode(func_name) {} - - /*! \brief Convert node to docs*/ - const ffi::Array GetDocs() final; - - protected: - RelaxOpCodeStack stack_; - - /*! \brief Convert op build*/ - virtual void CodeGenBuild() = 0; - - /*! \brief coda stack emit docs*/ - void BuilderEmit(const ffi::String& ret, const ffi::String& name = ""); - - /*! \brief Get the out_dtype attribute*/ - const ExprDoc GetOutDtype(const ffi::String& key = "out_dtype", int input_idx = 0); - - /*! \brief Get the axes attribute*/ - const std::vector GetAxes(const ffi::String& key = "axes"); -}; - -/*! - * \brief Get the map of available RelaxOpCode, use optype as key - * \return Map of - */ -const std::shared_ptr>> -GetRelaxOpCodes(); - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_FRAMEWORK_TVM_RELAX_OPCODE_H_ diff --git a/src/contrib/msc/plugin/base_codegen.h b/src/contrib/msc/plugin/base_codegen.h deleted file mode 100644 index f8d63360c76a..000000000000 --- a/src/contrib/msc/plugin/base_codegen.h +++ /dev/null @@ -1,679 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/plugin/base_codegen.h - * \brief The codegen for Plugin. - */ -#ifndef TVM_CONTRIB_MSC_PLUGIN_BASE_CODEGEN_H_ -#define TVM_CONTRIB_MSC_PLUGIN_BASE_CODEGEN_H_ - -#include -#include - -#include -#include -#include -#include -#include - -#include "../core/codegen/code_stack.h" -#include "../core/ir/plugin.h" -#include "../core/printer/cpp_printer.h" -#include "../core/printer/python_printer.h" - -namespace tvm { -namespace contrib { -namespace msc { - -using namespace tvm::script::printer; - -/*! - * \brief CodeGen for Plugin - */ -template -class BasePluginCodeGen { - public: - /*! - * \brief The constructor of BasePluginCodeGen - * \param config the options for codegen. - */ - explicit BasePluginCodeGen(const std::string& config = "") { - config_.reset(new ConfigType()); - if (config.size() > 0) { - namespace json = ::tvm::ffi::json; - config_->Load(json::Parse(config).cast()); - } - } - - virtual ~BasePluginCodeGen() = default; - - /*! \brief Get plugin sources*/ - virtual const ffi::Map GetBuildSources( - const std::string& print_options = "") { - ffi::Map sources; - // plugin sources - for (const auto& name : ListPluginNames()) { - const auto& plugin = GetPlugin(name); - // attr declare - const ffi::String& attr_macro = - "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_ATTR_H_"; - this->stack_.line("#ifndef " + attr_macro) - .line("#define " + attr_macro) - .line() - .line("#include \"plugin_utils.h\"") - .line(); - StartNamespace(); - CodeGenAttrDeclare(plugin); - EndNamespace(); - this->stack_.line("#endif // " + attr_macro); - sources.Set(plugin->name + "_attr.h", ToCppSource(print_options)); - // attr define - this->stack_.line("#include \"" + plugin->name + "_attr.h\"").line(); - StartNamespace(); - CodeGenAttrDefine(plugin); - EndNamespace(); - sources.Set(plugin->name + "_attr.cc", ToCppSource(print_options)); - // op decalre - const ffi::String& op_macro = - "TVM_CONTRIB_MSC_" + StringUtils::Upper(plugin->name) + "_OP_H_"; - this->stack_.line("#ifndef " + op_macro).line("#define " + op_macro).line(); - CodeGenOpHeader(plugin); - StartNamespace(); - CodeGenOpDeclare(plugin); - EndNamespace(); - this->stack_.line("#endif // " + op_macro); - sources.Set(plugin->name + "_op.h", ToCppSource(print_options)); - // op define - this->stack_.line("#include \"" + plugin->name + "_op.h\"").line(); - StartNamespace(); - CodeGenOpDefine(plugin); - EndNamespace(); - sources.Set(plugin->name + "_op.cc", ToCppSource(print_options)); - // op runtime - if (this->config()->with_runtime) { - CodeGenOpHeader(plugin); - StartNamespace(); - CodeGenOpRuntime(plugin); - EndNamespace(); - sources.Set(plugin->name + "_runtime.cc", ToCppSource(print_options)); - } - } - // cmakelists - std::set devices; - for (const auto& name : ListPluginNames()) { - const auto& plugin = GetPlugin(name); - for (const auto& pair : plugin->externs) { - if (StringUtils::EndsWith(pair.first, "_compute")) { - devices.insert(StringUtils::Replace(pair.first, "_compute", "")); - } - } - } - CodeGenCmake(devices); - sources.Set("CMakeLists.txt", ToCppSource(print_options)); - return sources; - } - - /*! \brief Get manager sources*/ - virtual const ffi::Map GetManagerSources( - const std::string& print_options = "") { - ffi::Map sources; - CodeGenManagerDepends(); - this->stack_.class_def("PluginManager(object)").class_start(); - CodeGenManagerMethods(); - for (const auto& name : ListPluginNames()) { - CodeGenOpBuilder(GetPlugin(name)); - } - if (this->config()->need_convert) { - ffi::Map symbols; - this->stack_.func_def("get_convert_map") - .func_decorator("classmethod") - .func_arg("cls", "object") - .func_start(); - CodeGenConvertDepends(); - for (const auto& name : ListPluginNames()) { - const auto& plugin = GetPlugin(name); - const auto& symbol = CodeGenOpConvert(plugin); - symbols.Set(plugin, symbol); - } - this->stack_.assign("converters", "{}"); - for (const auto& pair : symbols) { - this->stack_.assign(DocUtils::ToIndex("converters", DocUtils::ToStr(pair.second)), - ConverterName(pair.first)); - } - this->stack_.func_end("converters"); - } - this->stack_.class_end(); - sources.Set("manager.py", ToPySource(print_options)); - return sources; - } - - protected: - /*! \brief Header of plugin files*/ - virtual void CodeGenOpHeader(const Plugin& plugin) { - this->stack_.line("#include \"" + plugin->name + "_attr.h\""); - std::set include_headers; - for (const auto& pair : plugin->externs) { - if (pair.second->header.size() > 0 && !include_headers.count(pair.second->header)) { - this->stack_.line("#include \"" + pair.second->header + "\""); - include_headers.insert(pair.second->header); - } - } - this->stack_.line(); - } - - /*! \brief Start the namespace*/ - void StartNamespace() { - this->stack_.line("namespace tvm {") - .line("namespace contrib {") - .line("namespace msc {") - .line("namespace plugin {") - .line(); - } - - /*! \brief End the namespace*/ - void EndNamespace() { - this->stack_.line("} // namespace plugin") - .line("} // namespace msc") - .line("} // namespace contrib") - .line("} // namespace tvm"); - } - - /*! \brief Codegen safe call extern*/ - void CodeGenSafeCall(const PluginExtern& extern_func, - const ffi::Array& call_args = ffi::Array(), - const ffi::String& ret = "") { - this->stack_.scope_start("try {").func_call(extern_func->name, ret); - for (const auto& arg : call_args) { - this->stack_.call_arg(arg); - } - this->stack_.scope_end() - .scope_start("} catch (const std::exception& exc) {") - .line("std::cerr << \"Failed to run extern " + extern_func->name + - " : \" << exc.what() << std::endl;") - .line("throw std::runtime_error(\"Failed to run extern " + extern_func->name + "\");") - .scope_end() - .line("}"); - } - - /*! \brief Codegen plugin attr declare*/ - virtual void CodeGenAttrDeclare(const Plugin& plugin) { - this->stack_.struct_start(MetaAttrCls(plugin)).comment("define attributes"); - for (const auto& attr : plugin->attrs) { - this->stack_.declare(ToCppType(attr->type), attr->name); - if (attr->default_value.size() > 0) { - this->stack_.declare_arg(attr->default_value); - } - } - this->stack_.line() - .comment("print method") - .func_def("operator<<", "friend std::ostream&") - .func_arg("out", "std::ostream&") - .func_arg("attrs", "const " + MetaAttrCls(plugin) + "&") - .func_start() - .line("out << \"[" + MetaAttrCls(plugin) + "] : \";"); - for (const auto& attr : plugin->attrs) { - this->stack_.line("out << \"| " + attr->name + "(" + attr->type + ")=\" << attrs." + - attr->name + ";"); - } - this->stack_.func_end("out").struct_end(); - } - - /*! \brief Codegen plugin attr define*/ - virtual void CodeGenAttrDefine(const Plugin& plugin) {} - - /*! \brief Codegen plugin op declare*/ - virtual void CodeGenOpDeclare(const Plugin& plugin) = 0; - - /*! \brief Codegen plugin op define*/ - virtual void CodeGenOpDefine(const Plugin& plugin) = 0; - - /*! \brief Codegen plugin runtime*/ - virtual void CodeGenOpRuntime(const Plugin& plugin) {} - - /*! \brief Codegen CMake file*/ - virtual void CodeGenCmake(const std::set& devices) { - CodeGenPreCmake(devices); - CodeGenPostCmake(devices); - } - - /*! \brief Codegen CMake start*/ - void CodeGenPreCmake(const std::set& devices, - const ffi::Map& extra_flags = - ffi::Map()) { - const auto& p_name = this->config()->project_name; - stack_.line("cmake_minimum_required(VERSION " + this->config()->cmake_version + " FATAL_ERROR)") - .line("project(" + p_name + ")"); - if (devices.count("cuda")) { - stack_.line("find_package(CUDA)").line("add_definitions(-DPLUGIN_ENABLE_CUDA)"); - } - stack_.line(); - for (const auto& pair : extra_flags) { - if (pair.second.size() == 0) { - stack_.line("add_definitions(-D" + pair.first + ")"); - } else { - stack_.line("add_definitions(-D" + pair.first + "=" + pair.second + ")"); - } - } - for (const auto& pair : this->config()->flags) { - if (pair.second.size() == 0) { - stack_.line("add_definitions(-D" + pair.first + ")"); - } else { - stack_.line("add_definitions(-D" + pair.first + "=" + pair.second + ")"); - } - } - stack_.line(); - } - - /*! \brief Codegen CMake end*/ - void CodeGenPostCmake(const std::set& devices, - const ffi::Array& extra_includes = ffi::Array(), - const ffi::Array& extra_libs = ffi::Array()) { - const auto& p_name = this->config()->project_name; - stack_.line() - .line("file(GLOB_RECURSE PLUGIN_HEADERS src/*.h)") - .line("file(GLOB_RECURSE PLUGIN_CC_SRCS src/*.cc)"); - if (devices.count("cuda")) { - stack_.line("file(GLOB_RECURSE PLUGIN_CU_SRCS src/*.cu)"); - } - if (devices.count("cuda")) { - stack_.line("cuda_add_library(" + p_name + " SHARED ${PLUGIN_CC_SRCS} ${PLUGIN_CU_SRCS})"); - } else { - stack_.line("add_library(" + p_name + " SHARED ${PLUGIN_CC_SRCS})"); - } - // define includes - ffi::String includes = StringUtils::Join(extra_includes, " "); - if (this->config()->includes.size() > 0) { - includes = includes + " " + StringUtils::Join(this->config()->includes, " "); - } - if (includes.size() > 0) { - stack_.line("target_include_directories(" + p_name + " PUBLIC " + includes + ")"); - } - // define libs - ffi::String link_libs = StringUtils::Join(extra_libs, " "); - const auto& libs = StringUtils::Join(this->config()->libs, " "); - if (libs.size() > 0) { - link_libs = link_libs + " " + libs; - } - if (link_libs.size() > 0) { - stack_.line("target_link_libraries(" + p_name + " " + link_libs + ")"); - } - const auto& install_dir = this->config()->install_dir; - if (install_dir.size() > 0) { - stack_.line() - .line("SET(LIBRARY_OUTPUT_PATH " + install_dir + "/lib)") - .line("file(COPY ${PLUGIN_HEADERS} DESTINATION " + install_dir + "/include)"); - if (this->config()->libs.size() > 0) { - stack_.line("file(COPY " + libs + " DESTINATION " + install_dir + "/lib)"); - } - } - } - - /*! \brief Codegen manager depends*/ - virtual void CodeGenManagerDepends() { - this->stack_.line("import os") - .line("import shutil") - .line("import ctypes") - .line("from typing import Any, List, Dict") - .line(); - } - - /*! \brief Codegen manager methods*/ - virtual void CodeGenManagerMethods() { - // init method - stack_.func_def("__init__") - .func_arg("self", "object") - .func_arg("root", "str", "None") - .func_start() - .cond_if("root is None") - .assign("root", "os.path.dirname(__name__)") - .cond_end() - .assign(DocUtils::ToAttrAccess("self", "_lib_folder"), "os.path.join(root, \"lib\")") - .func_call("assert") - .inplace_start("os.path.isdir") - .call_arg(DocUtils::ToAttrAccess("self", "_lib_folder")) - .inplace_end() - .assign(DocUtils::ToAttrAccess("self", "_include_folder"), - "os.path.join(root, \"include\")") - .func_call("assert") - .inplace_start("os.path.isdir") - .call_arg(DocUtils::ToAttrAccess("self", "_include_folder")) - .inplace_end() - .assign(DocUtils::ToAttrAccess("self", "_manager_file"), - "os.path.join(root, \"manager.py\")") - .func_call("assert") - .inplace_start("os.path.isfile") - .call_arg(DocUtils::ToAttrAccess("self", "_manager_file")) - .inplace_end() - .func_call("setup", "", "self") - .func_end(); - // list headers - this->stack_.func_def("list_includes") - .func_arg("self", "object") - .func_arg("as_abs", "bool", "False") - .func_start() - .assign("includes", "[]") - .for_start("f", "os.listdir(self._include_folder)") - .cond_if("as_abs") - .func_call("append", "", "includes") - .inplace_start("os.path.join") - .call_arg(DocUtils::ToAttrAccess("self", "_include_folder")) - .call_arg("f") - .inplace_end() - .cond_else() - .func_call("append", "", "includes") - .call_arg("f") - .cond_end() - .for_end() - .func_end("includes"); - // copy the headers - this->stack_.func_def("copy_includes") - .func_arg("self", "object") - .func_arg("dst", "str") - .func_start() - .cond_if("not os.path.isdir(dst)") - .func_call("makedirs", "", "os") - .call_arg("dst") - .cond_end() - .for_start("header", "os.listdir(self._include_folder)") - .func_call("shutil.copyfile") - .inplace_start("os.path.join") - .call_arg(DocUtils::ToAttrAccess("self", "_include_folder")) - .call_arg("header") - .inplace_end() - .inplace_start("os.path.join") - .call_arg("dst") - .call_arg("header") - .inplace_end() - .for_end() - .func_end(); - // list libs - this->stack_.func_def("list_libs") - .func_arg("self", "object") - .func_arg("as_abs", "bool", "False") - .func_start() - .assign("libs", "[]") - .for_start("f", "os.listdir(self._lib_folder)") - .cond_if("as_abs") - .func_call("append", "", "libs") - .inplace_start("os.path.join") - .call_arg(DocUtils::ToAttrAccess("self", "_lib_folder")) - .call_arg("f") - .inplace_end() - .cond_else() - .func_call("append", "", "libs") - .call_arg("f") - .cond_end() - .for_end() - .func_end("libs"); - // copy the libs - this->stack_.func_def("copy_libs") - .func_arg("self", "object") - .func_arg("dst", "str") - .func_start() - .cond_if("not os.path.isdir(dst)") - .func_call("makedirs", "", "os") - .call_arg("dst") - .cond_end() - .for_start("lib", "os.listdir(self._lib_folder)") - .func_call("shutil.copyfile") - .inplace_start("os.path.join") - .call_arg(DocUtils::ToAttrAccess("self", "_lib_folder")) - .call_arg("lib") - .inplace_end() - .inplace_start("os.path.join") - .call_arg("dst") - .call_arg("lib") - .inplace_end() - .for_end() - .func_end(); - // export method - this->stack_.func_def("export") - .func_arg("self", "object") - .func_arg("dst", "str") - .func_start() - .func_call("copy_includes", "", "self") - .inplace_start("os.path.join") - .call_arg("dst") - .call_arg(DocUtils::ToStr("include")) - .inplace_end() - .func_call("copy_libs", "", "self") - .inplace_start("os.path.join") - .call_arg("dst") - .call_arg(DocUtils::ToStr("lib")) - .inplace_end() - .func_call("shutil.copyfile") - .call_arg(DocUtils::ToAttrAccess("self", "_manager_file")) - .inplace_start("os.path.join") - .call_arg("dst") - .call_arg(DocUtils::ToStr("manager.py")) - .inplace_end() - .func_end(); - // get op names - this->stack_.func_def("get_op_names", "List[str]") - .func_arg("self", "object") - .func_start() - .assign("names", "[]"); - for (const auto& name : ListPluginNames()) { - this->stack_.func_call("append", "", "names").call_arg(DocUtils::ToStr(name)); - } - this->stack_.func_end("names"); - // get ops info - this->stack_.func_def("get_ops_info", "dict") - .func_arg("self", "object") - .func_start() - .assign("info", "{}"); - for (const auto& name : ListPluginNames()) { - TVM_FFI_ICHECK(this->config()->ops_info.count(name)) << "Can not find op info for " << name; - const auto& info = this->config()->ops_info[name]; - this->stack_.assign(DocUtils::ToIndex("info", DocUtils::ToStr(name)), info); - } - this->stack_.func_end("info"); - } - - /*! \brief Codegen manager for plugin*/ - virtual void CodeGenOpBuilder(const Plugin& plugin) {} - - /*! \brief Codegen convert depends*/ - virtual void CodeGenConvertDepends() { - this->stack_.line("from tvm import relax") - .line("from tvm.relax import call_dps_packed") - .line("from tvm.contrib.msc.plugin import utils as plugin_utils") - .line("from tvm.contrib.msc.plugin.op import _ffi_api as _plugin_api") - .line("from tvm.contrib.msc.core import utils as msc_utils") - .line(); - } - - /*! \brief Codegen convert function for plugin*/ - virtual const ffi::String CodeGenOpConvert(const Plugin& plugin) { return plugin->name; } - - /*! \brief Change code stack to cpp source*/ - const ffi::String ToCppSource(const std::string& print_options = "") { - CppPrinter printer(print_options); - for (const auto& d : this->stack_.GetDocs()) { - printer.Append(d); - } - this->stack_.Reset(); - return printer.GetString(); - } - - /*! \brief Change code stack to python source*/ - const ffi::String ToPySource(const std::string& print_options = "") { - PythonPrinter printer(print_options); - for (const auto& d : this->stack_.GetDocs()) { - printer.Append(d); - } - this->stack_.Reset(); - return printer.GetString(); - } - - std::vector> GetDtypeMatrix(const Plugin& plugin) { - std::vector> matrix; - if (plugin->support_dtypes.size() == 0) { - std::unordered_map dtypes; - for (size_t i = 0; i < plugin->inputs.size(); i++) { - dtypes[i] = plugin->inputs[i]->dtype; - } - matrix.push_back(dtypes); - } else { - ffi::Array templates; - ffi::Array> condidates; - for (const auto& pair : plugin->support_dtypes) { - templates.push_back(pair.first); - condidates.push_back(pair.second); - } - for (const auto& t_dtypes : ArrayUtils::Product(condidates)) { - std::unordered_map dtypes; - for (size_t i = 0; i < templates.size(); i++) { - for (size_t in_idx = 0; in_idx < plugin->inputs.size(); in_idx++) { - if (plugin->inputs[in_idx]->dtype == templates[i]) { - dtypes[in_idx] = t_dtypes[i]; - } - } - } - for (size_t i = 0; i < plugin->inputs.size(); i++) { - if (dtypes.count(i)) { - continue; - } - dtypes[i] = plugin->inputs[i]->dtype; - } - matrix.push_back(dtypes); - } - } - return matrix; - } - - const ffi::Map GetTensorDtypes( - const Plugin& plugin, const std::unordered_map& dtypes) { - ffi::Map tensor_dtypes; - for (const auto& pair : dtypes) { - const ffi::String& ref_dtype = plugin->inputs[pair.first]->dtype; - for (const auto& t : plugin->inputs) { - if (t->dtype == ref_dtype) { - tensor_dtypes.Set(t->name, pair.second); - } - } - for (const auto& t : plugin->outputs) { - if (t->dtype == ref_dtype) { - tensor_dtypes.Set(t->name, pair.second); - } - } - for (const auto& t : plugin->buffers) { - if (t->dtype == ref_dtype) { - tensor_dtypes.Set(t->name, pair.second); - } - } - } - return tensor_dtypes; - } - - /*! \brief Change plugin comment in python*/ - const ffi::String GetPyComment(const Plugin& plugin) { - ffi::String comment = "Python wrapper for " + plugin->name + "\nInputs\n------"; - for (const auto& t : plugin->inputs) { - comment = comment + "\n" + t->name + ": " + t->dtype + "\n " + t->describe; - } - comment = comment + "\nOutputs\n-------"; - for (const auto& t : plugin->outputs) { - comment = comment + "\n" + t->name + ": " + t->dtype + "\n " + t->describe; - } - if (plugin->attrs.size() > 0) { - comment = comment + "\nAttributes\n-----------"; - for (const auto& a : plugin->attrs) { - comment = comment + "\n" + a->name + ": " + ToPyType(a->type) + "\n " + a->describe; - } - } - return comment; - } - - /*! \brief Get class name for meta attrs*/ - const ffi::String MetaAttrCls(const Plugin& plugin) const { return plugin->name + "MetaAttr"; } - - /*! \brief Get converter name for plugin*/ - const ffi::String ConverterName(const Plugin& plugin) const { return plugin->name + "Converter"; } - - /*! \brief Check if the type is list type. */ - bool IsListType(const ffi::String& type) { return StringUtils::StartsWith(type, "list"); } - - /*! \brief Get type of element. */ - const ffi::String GetEleType(const ffi::String& type) { - if (!IsListType(type)) { - return ""; - } - return StringUtils::Replace(StringUtils::Replace(type, "list(", ""), ")", ""); - } - - /*! \brief Type name in cpp*/ - virtual const ffi::String ToCppType(const ffi::String& type) { - if (IsListType(type)) { - const auto& ele_type = GetEleType(type); - return "std::vector<" + ToCppType(ele_type) + ">"; - } - if (type == "int64") { - return "int64_t"; - } - if (type == "int32" || type == "int") { - return "int32_t"; - } - if (type == "int8") { - return "int8_t"; - } - if (type == "string") { - return "std::string"; - } - return type; - } - - /*! \brief Type name in python*/ - virtual const ffi::String ToPyType(const ffi::String& type) { - if (IsListType(type)) { - const auto& ele_type = GetEleType(type); - return "List[" + ToPyType(ele_type) + "]"; - } - if (type == "int64" || type == "int32" || type == "int" || type == "int8") { - return "int"; - } - if (type == "string") { - return "str"; - } - return type; - } - - /*! - * \brief Compare version with version in config - * 0 for same version, 1 for greater version, -1 for less version - */ - int CompareVersion(size_t major, size_t minor, size_t patch) { - return CommonUtils::CompareVersion(this->config()->version, {major, minor, patch}); - } - - /*! \brief The config of plugin codegen*/ - const std::shared_ptr config() { return config_; } - - /*! \brief The stack of codes*/ - CodeStack stack_; - - private: - std::shared_ptr config_; -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_PLUGIN_BASE_CODEGEN_H_ diff --git a/src/contrib/msc/plugin/codegen_utils.h b/src/contrib/msc/plugin/codegen_utils.h deleted file mode 100644 index 9fd9a0e941ee..000000000000 --- a/src/contrib/msc/plugin/codegen_utils.h +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/plugin/codegen_utils.h - * \brief Common utilities for print. - */ -#ifndef TVM_CONTRIB_MSC_PLUGIN_CODEGEN_UTILS_H_ -#define TVM_CONTRIB_MSC_PLUGIN_CODEGEN_UTILS_H_ - -#include - -#include -#include -#include - -namespace tvm { -namespace contrib { -namespace msc { - -#define PLUGIN_CODEGEN_CONFIG_MEMBERS \ - bool need_convert{false}; \ - bool with_runtime{false}; \ - std::string project_name{"msc_plugin"}; \ - std::string cmake_version{"3.5"}; \ - std::string install_dir; \ - std::vector version{0, 0, 0}; \ - std::vector includes; \ - std::vector libs; \ - std::unordered_map flags; \ - std::unordered_map ops_info; - -#define PLUGIN_CODEGEN_CONFIG_PARSE \ - namespace json = ::tvm::ffi::json; \ - if (auto it = obj.find(ffi::String("need_convert")); it != obj.end()) { \ - need_convert = (*it).second.cast(); \ - } \ - if (auto it = obj.find(ffi::String("with_runtime")); it != obj.end()) { \ - with_runtime = (*it).second.cast(); \ - } \ - if (auto it = obj.find(ffi::String("cmake_version")); it != obj.end()) { \ - cmake_version = std::string((*it).second.cast()); \ - } \ - if (auto it = obj.find(ffi::String("project_name")); it != obj.end()) { \ - project_name = std::string((*it).second.cast()); \ - } \ - if (auto it = obj.find(ffi::String("install_dir")); it != obj.end()) { \ - install_dir = std::string((*it).second.cast()); \ - } \ - if (auto it = obj.find(ffi::String("version")); it != obj.end()) { \ - auto arr = (*it).second.cast(); \ - version.clear(); \ - version.reserve(arr.size()); \ - for (const auto& elem : arr) { \ - version.push_back(static_cast(elem.cast())); \ - } \ - } \ - if (auto it = obj.find(ffi::String("includes")); it != obj.end()) { \ - auto arr = (*it).second.cast(); \ - includes.clear(); \ - includes.reserve(arr.size()); \ - for (const auto& elem : arr) { \ - includes.push_back(std::string(elem.cast())); \ - } \ - } \ - if (auto it = obj.find(ffi::String("libs")); it != obj.end()) { \ - auto arr = (*it).second.cast(); \ - libs.clear(); \ - libs.reserve(arr.size()); \ - for (const auto& elem : arr) { \ - libs.push_back(std::string(elem.cast())); \ - } \ - } \ - if (auto it = obj.find(ffi::String("flags")); it != obj.end()) { \ - auto inner = (*it).second.cast(); \ - flags.clear(); \ - for (const auto& kv : inner) { \ - flags[std::string(kv.first.cast())] = \ - std::string(kv.second.cast()); \ - } \ - } \ - if (auto it = obj.find(ffi::String("ops_info")); it != obj.end()) { \ - auto inner = (*it).second.cast(); \ - ops_info.clear(); \ - for (const auto& kv : inner) { \ - ops_info[std::string(kv.first.cast())] = \ - std::string(kv.second.cast()); \ - } \ - } - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_PLUGIN_CODEGEN_UTILS_H_ diff --git a/src/contrib/msc/plugin/tensorrt_codegen.cc b/src/contrib/msc/plugin/tensorrt_codegen.cc deleted file mode 100644 index c7fc27ea5bd1..000000000000 --- a/src/contrib/msc/plugin/tensorrt_codegen.cc +++ /dev/null @@ -1,907 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/plugin/tensorrt_codegen.cc - */ -#include "tensorrt_codegen.h" - -#include - -#include -namespace tvm { -namespace contrib { -namespace msc { - -void TensorRTPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) { - BasePluginCodeGen::CodeGenAttrDeclare(plugin); - const auto& attr_name = MetaAttrCls(plugin); - // serialize size for attr - stack_.comment("serialize size").func_def(attr_name + "_serialize_size", "size_t"); - // serialize method for attr - stack_.comment("serialize method") - .func_def(attr_name + "_serialize", "char*") - .func_arg("meta_attr", "const " + attr_name + "&") - .func_arg("buffer", "char*"); - // deserialize method for attr - stack_.comment("deserialize method") - .func_def(attr_name + "_deserialize", "const char*") - .func_arg("meta_attr", attr_name + "&") - .func_arg("buffer", "const char*"); - // attr to field - stack_.comment("meta attr to field") - .func_def(attr_name + "_to_fields") - .func_arg("fields", "std::vector&"); - // attr from field - stack_.comment("meta attr from field") - .func_def(attr_name + "_from_fields", "const " + attr_name) - .func_arg("fields", "const PluginField*"); -} - -void TensorRTPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { - const auto& attr_name = MetaAttrCls(plugin); - // serialize size for attr - stack_.func_def(attr_name + "_serialize_size", "size_t").func_start().assign("size", 0, "size_t"); - for (const auto& a : plugin->attrs) { - stack_.comment("attr " + a->name + "(" + a->type + ")"); - if (IsListType(a->type)) { - LOG_FATAL << "attribute type " << a->type << " is not supported"; - const auto& ele_type = GetEleType(a->type); - stack_.assign("size", "size + sizeof(size_t)") - .for_start("a", DocUtils::ToAttrAccess("meta_attr", a->name)) - .assign("size", "size + sizeof(" + ToCppType(ele_type) + ")") - .for_end(); - } else { - stack_.assign("size", "size + sizeof(" + ToCppType(a->type) + ")"); - } - } - stack_.func_end("size"); - // serialize method for attr - stack_.func_def(attr_name + "_serialize", "char*") - .func_arg("meta_attr", "const " + attr_name + "&") - .func_arg("buffer", "char*") - .func_start() - .assign("start", "buffer", "const char*"); - for (const auto& a : plugin->attrs) { - stack_.func_call("TRTUtils::ValToBuffer") - .call_arg("buffer") - .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); - } - stack_.func_call(attr_name + "_serialize_size", DocUtils::ToDeclare("size_t", "expected")) - .line("assert(buffer == start + expected);") - .func_end("buffer"); - // deserialize method for attr - stack_.func_def(attr_name + "_deserialize", "const char*") - .func_arg("meta_attr", attr_name + "&") - .func_arg("buffer", "const char*") - .func_start() - .assign("start", "buffer", "const char*"); - for (const auto& a : plugin->attrs) { - stack_.func_call("TRTUtils::ValFromBuffer") - .call_arg("buffer") - .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); - } - stack_.func_call(attr_name + "_serialize_size", DocUtils::ToDeclare("size_t", "expected")) - .line("assert(buffer == start + expected);") - .func_end("buffer"); - // attr to field - stack_.func_def(attr_name + "_to_fields") - .func_arg("fields", "std::vector&") - .func_start(); - for (const auto& a : plugin->attrs) { - stack_.func_call("emplace_back", "", "fields") - .inplace_start("TRTUtils::ToField") - .call_arg(DocUtils::ToStr(a->name)) - .call_arg(DocUtils::ToStr(a->type)) - .inplace_end(); - } - stack_.func_end(); - // attr from field - stack_.func_def(attr_name + "_from_fields", "const " + attr_name) - .func_arg("fields", "const PluginField*") - .func_start() - .declare(attr_name, "meta_attr") - .for_start("i", 0, plugin->attrs.size()); - for (size_t i = 0; i < plugin->attrs.size(); i++) { - const auto& attr = plugin->attrs[i]; - const ffi::String& cond = "strcmp(fields[i].name, \"" + attr->name + "\") == 0"; - if (i == 0) { - stack_.switch_start(cond); - } else { - stack_.switch_case(cond); - } - stack_.func_call("TRTUtils::FromField") - .call_arg(DocUtils::ToIndex("fields", "i")) - .call_arg(DocUtils::ToAttrAccess("meta_attr", attr->name)); - } - stack_.switch_end().for_end().func_end("meta_attr"); -} - -void TensorRTPluginCodeGen::CodeGenOpHeader(const Plugin& plugin) { - BasePluginCodeGen::CodeGenOpHeader(plugin); - stack_.line("using namespace nvinfer1;").line(); -} - -void TensorRTPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) { - if (!IsMixPrecision(plugin)) { - // static plugin op - const auto& op_static = OpCls(plugin, false); - stack_.class_def(op_static + " : public IPluginV2").class_start().scope_start("public:"); - CodegenOpCommonMethods(plugin, false, true); - stack_.comment("special methods for " + op_static) - .func_def("getOutputDimensions", "Dims") - .func_decorator("noexcept override") - .func_arg("index", "int") - .func_arg("in_dims", "const Dims*") - .func_arg("n_inputs", "int") - .func_def("configureWithFormat") - .func_decorator("noexcept override") - .func_arg("in_dims", "const Dims*") - .func_arg("n_inputs", "int") - .func_arg("out_dims", "const Dims*") - .func_arg("n_outputs", "int") - .func_arg("dtype", "DataType") - .func_arg("format", "PluginFormat") - .func_arg("max_batch", "int") - .func_def("supportsFormat", "bool") - .func_decorator("const noexcept override") - .func_arg("dtype", "DataType") - .func_arg("format", "PluginFormat") - .func_def("getWorkspaceSize", "size_t") - .func_decorator("const noexcept override") - .func_arg("max_batch", "int") - .func_def("enqueue", "int") - .func_decorator("noexcept override") - .func_arg("batch_size", "int") - .func_arg("inputs", "const void* const*") - .func_arg("outputs", "void* const*") - .func_arg("workspace", "void*") - .func_arg("stream", "cudaStream_t") - .scope_end(); - CodegenOpMembers(plugin, false); - stack_.class_end(); - - // static plugin creator - CodegenCreator(plugin, false, true); - } - // dynamic plugin op - const auto& op_dynamic = OpCls(plugin, true); - stack_.class_def(op_dynamic + " : public IPluginV2DynamicExt") - .class_start() - .scope_start("public:"); - CodegenOpCommonMethods(plugin, true, true); - stack_.comment("special methods for " + op_dynamic) - .func_def("getOutputDataType", "DataType") - .func_decorator("const noexcept override") - .func_arg("index", "int") - .func_arg("in_types", "const DataType*") - .func_arg("n_inputs", "int") - .func_def("getOutputDimensions", "DimsExprs") - .func_decorator("noexcept override") - .func_arg("index", "int") - .func_arg("in_dims", "const DimsExprs*") - .func_arg("n_inputs", "int") - .func_arg("builder", "IExprBuilder&") - .func_def("configurePlugin") - .func_decorator("noexcept override") - .func_arg("in_descs", "const DynamicPluginTensorDesc*") - .func_arg("n_inputs", "int") - .func_arg("out_descs", "const DynamicPluginTensorDesc*") - .func_arg("n_outputs", "int") - .func_def("supportsFormatCombination", "bool") - .func_decorator("noexcept override") - .func_arg("pos", "int") - .func_arg("io_desc", "const PluginTensorDesc*") - .func_arg("n_inputs", "int") - .func_arg("n_outputs", "int") - .func_def("getWorkspaceSize", "size_t") - .func_decorator("const noexcept override") - .func_arg("in_descs", "const PluginTensorDesc*") - .func_arg("n_inputs", "int") - .func_arg("out_descs", "const PluginTensorDesc*") - .func_arg("n_outputs", "int") - .func_def("enqueue", "int") - .func_decorator("noexcept override") - .func_arg("input_descs", "const PluginTensorDesc*") - .func_arg("output_descs", "const PluginTensorDesc*") - .func_arg("inputs", "const void* const*") - .func_arg("outputs", "void* const*") - .func_arg("workspace", "void*") - .func_arg("stream", "cudaStream_t") - .scope_end(); - CodegenOpMembers(plugin, true); - stack_.class_end(); - - // dynamic plugin creator - CodegenCreator(plugin, true, true); -} - -void TensorRTPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { - if (!IsMixPrecision(plugin)) { - // static op - const auto& op_static = OpCls(plugin, false); - CodegenOpCommonMethods(plugin, false, false); - // getOutputDimensions - stack_.func_def(op_static + "::getOutputDimensions", "Dims") - .func_decorator("noexcept") - .func_arg("index", "int") - .func_arg("in_dims", "const Dims*") - .func_arg("n_inputs", "int") - .func_start(); - CodegenOutputInfer(plugin, false); - stack_ - .func_call("shape", DocUtils::ToDeclare("MetaShape", "out_shape"), - DocUtils::ToIndex("output_metas_", "index")) - .func_call("TRTUtils::ToDims", DocUtils::ToDeclare("Dims", "out_dims")) - .call_arg("out_shape") - .func_end("out_dims"); - // configureWithFormat - stack_.func_def(op_static + "::configureWithFormat") - .func_decorator("noexcept") - .func_arg("in_dims", "const Dims*") - .func_arg("n_inputs", "int") - .func_arg("out_dims", "const Dims*") - .func_arg("n_outputs", "int") - .func_arg("dtype", "DataType") - .func_arg("format", "PluginFormat") - .func_arg("max_batch", "int") - .func_start() - .assign("dtype_", "dtype") - .line("assert(n_outputs == " + std::to_string(plugin->outputs.size()) + ");"); - CodegenOutputInfer(plugin, false); - stack_.func_end(); - // supportsFormat - stack_.func_def(op_static + "::supportsFormat", "bool") - .func_decorator("const noexcept") - .func_arg("dtype", "DataType") - .func_arg("format", "PluginFormat") - .func_start() - .declare("bool", "support"); - size_t cnt = 0; - for (const auto& dtypes : GetDtypeMatrix(plugin)) { - const ffi::String& cond = "dtype_ == TRTUtils::ToDataType(\"" + dtypes.at(0) + "\")"; - if (cnt == 0) { - stack_.switch_start(cond); - } else { - stack_.switch_case(cond); - } - stack_.assign("support", true); - cnt++; - } - stack_.switch_case().assign("support", false).switch_end().func_end("support"); - // getWorkspaceSize - stack_.func_def(op_static + "::getWorkspaceSize", "size_t") - .func_decorator("const noexcept") - .func_arg("max_batch", "int") - .func_start() - .assign("size", 0, "size_t"); - if (plugin->externs.count("infer_buffer")) { - CodegenBufferInfer(plugin); - } - stack_.func_end("size"); - // enqueue - stack_.func_def(op_static + "::enqueue", "int") - .func_decorator("noexcept") - .func_arg("batch_size", "int") - .func_arg("inputs", "const void* const*") - .func_arg("outputs", "void* const*") - .func_arg("workspace", "void*") - .func_arg("stream", "cudaStream_t") - .func_start(); - CodegenEnqueue(plugin, false); - stack_.func_end(0); - - // static creator - CodegenCreator(plugin, false, false); - } - // dynamic op - const auto& op_dynamic = OpCls(plugin, true); - CodegenOpCommonMethods(plugin, true, false); - // getOutputDataType - stack_.func_def(op_dynamic + "::getOutputDataType", "DataType") - .func_decorator("const noexcept") - .func_arg("index", "int") - .func_arg("in_types", "const DataType*") - .func_arg("n_inputs", "int") - .func_start() - .declare("DataType", "dtype"); - for (size_t i = 0; i < plugin->outputs.size(); i++) { - if (i == 0) { - stack_.switch_start("index == " + std::to_string(i)); - } else { - stack_.switch_case("index == " + std::to_string(i)); - } - int ref = plugin->FindDtypeRefIdx(plugin->outputs[i]); - if (ref >= 0) { - stack_.assign("dtype", DocUtils::ToIndex("in_types", ref)); - } else { - stack_.func_call("TRTUtils::ToDataType", "dtype") - .call_arg(DocUtils::ToStr(plugin->outputs[i]->dtype)); - } - } - stack_.switch_end().func_end("dtype"); - // getOutputDimensions - stack_.func_def(op_dynamic + "::getOutputDimensions", "DimsExprs") - .func_decorator("noexcept") - .func_arg("index", "int") - .func_arg("in_dims", "const DimsExprs*") - .func_arg("n_inputs", "int") - .func_arg("builder", "IExprBuilder&") - .func_start(); - CodegenOutputInfer(plugin, false); - stack_ - .func_call("shape", DocUtils::ToDeclare("MetaShape", "out_shape"), - DocUtils::ToIndex("output_metas_", "index")) - .func_call("TRTUtils::ToDimsExprs", DocUtils::ToDeclare("DimsExprs", "out_dims")) - .call_arg("out_shape") - .call_arg("builder") - .func_end("out_dims"); - // configurePlugin - stack_.func_def(op_dynamic + "::configurePlugin") - .func_decorator("noexcept") - .func_arg("in_descs", "const DynamicPluginTensorDesc*") - .func_arg("n_inputs", "int") - .func_arg("out_descs", "const DynamicPluginTensorDesc*") - .func_arg("n_outputs", "int") - .func_start() - .line("assert(n_outputs == " + std::to_string(plugin->outputs.size()) + ");"); - CodegenOutputInfer(plugin, true); - stack_.func_end(); - // supportsFormatCombination - stack_.func_def(op_dynamic + "::supportsFormatCombination", "bool") - .func_decorator("noexcept") - .func_arg("pos", "int") - .func_arg("io_desc", "const PluginTensorDesc*") - .func_arg("n_inputs", "int") - .func_arg("n_outputs", "int") - .func_start() - .declare("bool", "support"); - size_t cnt = 0; - for (const auto& dtypes : GetDtypeMatrix(plugin)) { - ffi::String cond; - for (size_t i = 0; i < plugin->inputs.size(); i++) { - cond = cond + "io_desc[" + std::to_string(i) + "].type == TRTUtils::ToDataType(\"" + - dtypes.at(i) + "\")"; - cond = cond + (i == plugin->inputs.size() - 1 ? "" : " && "); - } - if (cnt == 0) { - stack_.switch_start(cond); - } else { - stack_.switch_case(cond); - } - stack_.assign("support", true); - cnt++; - } - stack_.switch_case().assign("support", false).switch_end().func_end("support"); - // getWorkspaceSize - stack_.func_def(op_dynamic + "::getWorkspaceSize", "size_t") - .func_decorator("const noexcept") - .func_arg("in_descs", "const PluginTensorDesc*") - .func_arg("n_inputs", "int") - .func_arg("out_descs", "const PluginTensorDesc*") - .func_arg("n_outputs", "int") - .func_start() - .assign("size", 0, "size_t"); - if (plugin->externs.count("infer_buffer")) { - CodegenBufferInfer(plugin); - } - stack_.func_end("size"); - // enqueue - stack_.func_def(op_dynamic + "::enqueue", "int") - .func_decorator("noexcept") - .func_arg("input_descs", "const PluginTensorDesc*") - .func_arg("output_descs", "const PluginTensorDesc*") - .func_arg("inputs", "const void* const*") - .func_arg("outputs", "void* const*") - .func_arg("workspace", "void*") - .func_arg("stream", "cudaStream_t") - .func_start(); - CodegenEnqueue(plugin, true); - stack_.func_end(0); - - // dynamic creator - CodegenCreator(plugin, true, false); -} - -void TensorRTPluginCodeGen::CodeGenCmake(const std::set& devices) { - ffi::Map flags; - flags.Set("PLUGIN_SUPPORT_TENSORRT", ""); - flags.Set("TRT_MAJOR", std::to_string(config()->version[0])); - flags.Set("TRT_MINOR", std::to_string(config()->version[1])); - flags.Set("TRT_PATCH", std::to_string(config()->version[2])); - CodeGenPreCmake(devices, flags); - stack_ - .line("find_path(TRT_INCLUDE_DIR NvInfer.h HINTS " + config()->tensorrt_root + - " PATH_SUFFIXES include)") - .line("find_library(TRT_LIBS nvinfer HINTS " + config()->tensorrt_root + - " PATH_SUFFIXES lib)") - .line("set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wno-terminate\")"); - ffi::Array includes, libs; - includes.push_back("${TRT_INCLUDE_DIR}"); - libs.push_back("${TRT_LIBS}"); - CodeGenPostCmake(devices, includes, libs); -} - -void TensorRTPluginCodeGen::CodeGenManagerMethods() { - BasePluginCodeGen::CodeGenManagerMethods(); - stack_.func_def("setup") - .func_arg("self", "object") - .func_start() - .for_start("lib", "os.listdir(self._lib_folder)") - .assign("lib_file", "os.path.join(self._lib_folder, lib)") - .func_call("CDLL", "", "ctypes") - .call_arg("lib_file") - .for_end() - .func_end(); -} - -void TensorRTPluginCodeGen::CodegenOpCommonMethods(const Plugin& plugin, bool dynamic, - bool in_declare) { - const auto& op_cls = OpCls(plugin, dynamic); - const ffi::String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2"; - if (in_declare) { - stack_.comment("common methods for " + op_cls); - stack_.constructor_def(op_cls).constructor_arg("name", "const std::string&"); - for (const auto& a : plugin->attrs) { - stack_.constructor_arg(a->name, "const " + ToCppType(a->type) + "&"); - } - stack_.constructor_arg("layouts", "const std::vector&") - .constructor_def(op_cls) - .constructor_arg("name", "const std::string&") - .constructor_arg("buffer", "const void*") - .constructor_arg("length", "size_t") - .assign(op_cls + "()", "delete") - .line() - .constructor_def("~" + op_cls) - .func_def("getSerializationSize", "size_t") - .func_decorator("const noexcept override") - .func_def("serialize") - .func_decorator("const noexcept override") - .func_arg("buffer", "void*") - .func_def("getPluginType", "const char*") - .func_decorator("const noexcept override") - .func_def("getPluginVersion", "const char*") - .func_decorator("const noexcept override") - .func_def("getPluginNamespace", "const char*") - .func_decorator("const noexcept override") - .func_def("getNbOutputs", "int") - .func_decorator("const noexcept override") - .func_def("setPluginNamespace") - .func_decorator("noexcept override") - .func_arg("name_space", "const char*") - .func_def("initialize", "int") - .func_decorator("noexcept override") - .func_def("terminate") - .func_decorator("noexcept override") - .func_def("destroy") - .func_decorator("noexcept override") - .func_def("clone", plugin_cls + "*") - .func_decorator("const noexcept override"); - } else { - const auto& attr_name = MetaAttrCls(plugin); - // constructor from attrs - stack_.constructor_def(op_cls + "::" + op_cls).constructor_arg("name", "const std::string&"); - for (const auto& a : plugin->attrs) { - stack_.constructor_arg(a->name, "const " + ToCppType(a->type) + "&"); - } - stack_.constructor_arg("layouts", "const std::vector&") - .constructor_start() - .assign("name_", "name"); - for (const auto& a : plugin->attrs) { - stack_.assign(DocUtils::ToAttrAccess("meta_attr_", a->name), a->name); - } - stack_.line("assert(layouts.size() == " + std::to_string(plugin->inputs.size()) + ");") - .assign("layouts_", "layouts"); - stack_.constructor_end(); - // constructor from data - stack_.constructor_def(op_cls + "::" + op_cls) - .constructor_arg("name", "const std::string&") - .constructor_arg("buffer", "const void*") - .constructor_arg("length", "size_t") - .constructor_start() - .assign("name_", "name") - .func_call("static_cast", DocUtils::ToDeclare("const char*", "char_buf")) - .call_arg("buffer") - .assign("start_buf", "char_buf", "const char*") - .func_call(attr_name + "_deserialize", "char_buf") - .call_arg("meta_attr_") - .call_arg("char_buf") - .func_call("TRTUtils::ValFromBuffer") - .call_arg("char_buf") - .call_arg("dtype_") - .func_call("TRTUtils::ValFromBuffer") - .call_arg("char_buf") - .call_arg("layouts_") - .line("assert(layouts_.size() == " + std::to_string(plugin->inputs.size()) + ");") - .line("assert(char_buf == (start_buf + length));") - .constructor_end(); - // deconstructor - stack_.constructor_def(op_cls + "::~" + op_cls) - .constructor_start() - .comment("ignore deconstruct of " + op_cls) - .constructor_end(); - // getSerializationSize - stack_.func_def(op_cls + "::getSerializationSize", "size_t") - .func_decorator("const noexcept") - .func_start() - .assign("size", attr_name + "_serialize_size()", "size_t") - .assign("size", "size + sizeof(dtype_)") - .assign("size", "size + sizeof(size_t)") - .for_start("layout", "layouts_") - .assign("size", "size + sizeof(size_t) + layout.size() * sizeof(char)") - .for_end() - .func_end("size"); - // serialize - stack_.func_def(op_cls + "::serialize") - .func_decorator("const noexcept") - .func_arg("buffer", "void*") - .func_start() - .func_call("static_cast", DocUtils::ToDeclare("char*", "char_buf")) - .call_arg("buffer") - .assign("start_buf", "char_buf", "const char*") - .func_call(attr_name + "_serialize", "char_buf") - .call_arg("meta_attr_") - .call_arg("char_buf") - .func_call("TRTUtils::ValToBuffer") - .call_arg("char_buf") - .call_arg("dtype_") - .func_call("TRTUtils::ValToBuffer") - .call_arg("char_buf") - .call_arg("layouts_") - .line("assert(char_buf == (start_buf + getSerializationSize()));") - .func_end(); - // getPluginType - const ffi::String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); - stack_.func_def(op_cls + "::getPluginType", "const char*") - .func_decorator("const noexcept") - .func_start() - .func_end(DocUtils::ToStr(plugin_type)); - // getPluginVersion - stack_.func_def(op_cls + "::getPluginVersion", "const char*") - .func_decorator("const noexcept") - .func_start() - .func_end(DocUtils::ToStr("1")); - // getPluginNamespace - stack_.func_def(op_cls + "::getPluginNamespace", "const char*") - .func_decorator("const noexcept") - .func_start() - .func_call("c_str", DocUtils::ToDeclare("const char*", "name"), - DocUtils::ToDoc("name_space_")) - .func_end("name"); - // getNbOutputs - stack_.func_def(op_cls + "::getNbOutputs", "int") - .func_decorator("const noexcept") - .func_start() - .func_end(plugin->outputs.size()); - // setPluginNamespace - stack_.func_def(op_cls + "::setPluginNamespace") - .func_decorator("noexcept") - .func_arg("name_space", "const char*") - .func_start() - .assign("name_space_", "name_space") - .func_end(); - // initialize - stack_.func_def(op_cls + "::initialize", "int") - .func_decorator("noexcept") - .func_start() - .func_end(0); - // terminate - stack_.func_def(op_cls + "::terminate") - .func_decorator("noexcept") - .func_start() - .comment("Ignore teminate for " + plugin->name) - .func_end(); - // destroy - stack_.func_def(op_cls + "::destroy") - .func_decorator("noexcept") - .func_start() - .line("delete this;") - .func_end(); - // clone - stack_.func_def(op_cls + "::clone", plugin_cls + "*") - .func_decorator("const noexcept") - .func_start() - .func_call("new " + op_cls, DocUtils::ToDeclare(plugin_cls + "*", "plugin")) - .call_arg("name_"); - for (const auto& a : plugin->attrs) { - stack_.call_arg(DocUtils::ToAttrAccess("meta_attr_", a->name)); - } - stack_.call_arg("layouts_").func_end("plugin"); - } -} - -void TensorRTPluginCodeGen::CodegenOpMembers(const Plugin& plugin, bool dynamic) { - stack_.scope_start("private:") - .declare("std::string", "name_") - .declare("std::string", "name_space_") - .declare("DataType", "dtype_", 0, false) - .declare_arg("DataType::kFLOAT") - .declare(MetaAttrCls(plugin), "meta_attr_") - .declare("std::vector", "layouts_") - .declare("std::vector", "input_metas_") - .declare("std::vector", "output_metas_"); - if (plugin->externs.count("infer_buffer")) { - stack_.declare("std::vector", "buffer_metas_"); - } - stack_.scope_end().line(); -} - -void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, bool in_declare) { - const auto& creator_cls = CreatorCls(plugin, dynamic); - const ffi::String& plugin_cls = dynamic ? "IPluginV2DynamicExt" : "IPluginV2"; - if (in_declare) { - stack_.class_def(creator_cls + " : public IPluginCreator") - .class_start() - .scope_start("public:") - .constructor_def(creator_cls) - .func_def("getPluginName", "const char*") - .func_decorator("const noexcept override") - .func_def("getPluginVersion", "const char*") - .func_decorator("const noexcept override") - .func_def("getPluginNamespace", "const char*") - .func_decorator("const noexcept override") - .func_def("getFieldNames", "const PluginFieldCollection*") - .func_decorator("noexcept override") - .func_def("setPluginNamespace") - .func_decorator("noexcept override") - .func_arg("name_space", "const char*") - .func_def("createPlugin", plugin_cls + "*") - .func_decorator("noexcept override") - .func_arg("name", "const char*") - .func_arg("collection", "const PluginFieldCollection*") - .func_def("deserializePlugin", plugin_cls + "*") - .func_decorator("noexcept override") - .func_arg("name", "const char*") - .func_arg("data", "const void*") - .func_arg("length", "size_t") - .scope_end() - .scope_start("private:") - .declare("static PluginFieldCollection", "collection_") - .declare("static std::vector", "fields_") - .declare("std::string", "name_space_") - .scope_end() - .line() - .class_end(); - } else { - const ffi::String& attr_name = MetaAttrCls(plugin); - // static members - stack_.comment("static members and register for " + plugin->name) - .declare("PluginFieldCollection", creator_cls + "::collection_") - .declare("std::vector", creator_cls + "::fields_") - .func_call("REGISTER_TENSORRT_PLUGIN") - .call_arg(creator_cls) - .line(); - // constructor - stack_.constructor_def(creator_cls + "::" + creator_cls) - .constructor_start() - .func_call(attr_name + "_to_fields") - .call_arg("fields_"); - for (const auto& t : plugin->inputs) { - stack_.func_call("emplace_back", "", "fields_") - .inplace_start("TRTUtils::ToField") - .call_arg(DocUtils::ToStr("layout_" + t->name)) - .call_arg(DocUtils::ToStr("string")) - .inplace_end(); - } - const auto& nb_fields_doc = DocUtils::ToAttrAccess("collection_", "nbFields"); - const auto& fields_doc = DocUtils::ToAttrAccess("collection_", "fields"); - stack_.func_call("size", nb_fields_doc, DocUtils::ToDoc("fields_")) - .func_call("data", fields_doc, DocUtils::ToDoc("fields_")) - .constructor_end(); - // getPluginName - const ffi::String& plugin_type = plugin->name + (dynamic ? "_dynamic" : ""); - stack_.func_def(creator_cls + "::getPluginName", "const char*") - .func_decorator("const noexcept") - .func_start() - .func_end(DocUtils::ToStr(plugin_type)); - // getPluginVersion - stack_.func_def(creator_cls + "::getPluginVersion", "const char*") - .func_decorator("const noexcept") - .func_start() - .func_end(DocUtils::ToStr("1")); - // getPluginNamespace - stack_.func_def(creator_cls + "::getPluginNamespace", "const char*") - .func_decorator("const noexcept") - .func_start() - .func_call("c_str", DocUtils::ToDeclare("const char*", "name"), - DocUtils::ToDoc("name_space_")) - .func_end("name"); - // getFieldNames - stack_.func_def(creator_cls + "::getFieldNames", "const PluginFieldCollection*") - .func_decorator("noexcept") - .func_start() - .func_end("&collection_"); - // setPluginNamespace - stack_.func_def(creator_cls + "::setPluginNamespace") - .func_decorator("noexcept") - .func_arg("name_space", "const char*") - .func_start() - .assign("name_space_", "name_space") - .func_end(); - // createPlugin - size_t fields_size = plugin->attrs.size() + plugin->inputs.size(); - const auto& op_cls = OpCls(plugin, dynamic); - stack_.func_def(creator_cls + "::createPlugin", plugin_cls + "*") - .func_decorator("noexcept") - .func_arg("name", "const char*") - .func_arg("collection", "const PluginFieldCollection*") - .func_start() - .line("assert(collection->nbFields == " + std::to_string(fields_size) + ");") - .assign("fields", DocUtils::ToAttrAccess(DocUtils::ToPtr("collection"), "fields"), - "const PluginField*") - .func_call(attr_name + "_from_fields", DocUtils::ToDeclare("const auto&", "meta_attr")) - .call_arg("fields") - .declare("std::vector", "layouts") - .func_call("resize", "", "layouts") - .call_arg(plugin->inputs.size()) - .for_start("i", plugin->attrs.size(), fields_size); - for (size_t i = 0; i < plugin->inputs.size(); i++) { - const auto& tensor = plugin->inputs[i]; - const ffi::String& cond = "strcmp(fields[i].name, \"layout_" + tensor->name + "\") == 0"; - if (i == 0) { - stack_.switch_start(cond); - } else { - stack_.switch_case(cond); - } - stack_.func_call("TRTUtils::FromField") - .call_arg(DocUtils::ToIndex("fields", "i")) - .call_arg(DocUtils::ToIndex("layouts", i)); - } - stack_.switch_end() - .for_end() - .func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin")) - .call_arg("name"); - for (const auto& a : plugin->attrs) { - stack_.call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); - } - stack_.call_arg("layouts") - .func_call("setPluginNamespace", std::nullopt, DocUtils::ToPtr("plugin")) - .inplace_start("c_str", std::nullopt, DocUtils::ToDoc("name_space_")) - .inplace_end() - .func_end("plugin"); - // deserializePlugin - stack_.func_def(creator_cls + "::deserializePlugin", plugin_cls + "*") - .func_decorator("noexcept") - .func_arg("name", "const char*") - .func_arg("data", "const void*") - .func_arg("length", "size_t") - .func_start() - .func_call("new " + op_cls, DocUtils::ToDeclare(op_cls + "*", "plugin")) - .call_arg("name") - .call_arg("data") - .call_arg("length") - .func_call("setPluginNamespace", std::nullopt, DocUtils::ToPtr("plugin")) - .inplace_start("c_str", std::nullopt, DocUtils::ToDoc("name_space_")) - .inplace_end() - .func_end("plugin"); - } -} - -void TensorRTPluginCodeGen::CodegenOutputInfer(const Plugin& plugin, bool as_desc) { - ffi::Array infer_args{"input_metas_", "meta_attr_", "false"}; - stack_.line("assert(n_inputs == " + std::to_string(plugin->inputs.size()) + ");") - .func_call("resize", "", "input_metas_") - .call_arg(plugin->inputs.size()) - .for_start("i", 0, plugin->inputs.size()) - .func_call("TRTUtils::ToMetaTensor", DocUtils::ToIndex("input_metas_", "i")); - if (as_desc) { - stack_.call_arg(DocUtils::ToIndex("in_descs", "i")); - } else { - stack_.call_arg(DocUtils::ToIndex("in_dims", "i")).call_arg("dtype_"); - } - stack_.call_arg(DocUtils::ToIndex("layouts_", "i")).for_end(); - CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas_"); -} - -void TensorRTPluginCodeGen::CodegenBufferInfer(const Plugin& plugin) { - ffi::Array infer_args{"input_metas_", "meta_attr_", "false"}; - CodeGenSafeCall(plugin->externs["infer_buffer"], infer_args, "buffer_metas_"); - stack_.for_start("b", "buffer_metas_") - .assign("size", "size + max_batch * b.size(false)") - .for_end(); -} - -void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { - TVM_FFI_ICHECK(plugin->externs.count("cuda_compute")) - << "cuda_compute is needed fo TensorRT plugin"; - auto prepare_tensor = [this, &dynamic](const PluginTensor& tensor, - const ffi::Map& dtypes, - size_t idx, const ffi::String& collect) { - const ffi::String& t_name = "d_" + tensor->name; - const ffi::String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; - const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">"; - const ffi::String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; - stack_.func_call("TRTUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)); - const auto& t_meta = DocUtils::ToIndex(collect + "_metas_", idx); - if (dynamic) { - stack_.call_arg(t_meta).call_arg(DocUtils::ToIndex(collect + "_descs", idx)); - } else { - stack_.call_arg(t_meta).call_arg("batch_size"); - } - if (collect == "input") { - stack_.call_arg(DocUtils::ToIndex("inputs", idx)); - } else if (collect == "output") { - stack_.call_arg(DocUtils::ToIndex("outputs", idx)); - } else { - stack_.call_arg("workspace + offset"); - } - return t_name; - }; - for (const auto& dtypes : GetDtypeMatrix(plugin)) { - const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); - ffi::Array compute_args; - ffi::String dtype_cond = ""; - if (dynamic) { - for (size_t i = 0; i < plugin->inputs.size(); i++) { - dtype_cond = dtype_cond + "input_descs[" + std::to_string(i) + - "].type == TRTUtils::ToDataType(\"" + dtypes.at(i) + "\")"; - dtype_cond = dtype_cond + (i == plugin->inputs.size() - 1 ? "" : " && "); - } - } else { - dtype_cond = "dtype_ == TRTUtils::ToDataType(\"" + dtypes.at(0) + "\")"; - } - // prepare compute datas - stack_.cond_if(dtype_cond).comment("prepare compute datas"); - for (size_t i = 0; i < plugin->inputs.size(); i++) { - const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); - compute_args.push_back(t_name); - } - for (size_t i = 0; i < plugin->outputs.size(); i++) { - const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); - compute_args.push_back(t_name); - } - if (plugin->buffers.size() > 0) { - stack_.assign("offset", 0, "size_t"); - for (size_t i = 0; i < plugin->buffers.size(); i++) { - const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "buffer"); - compute_args.push_back(t_name); - const ffi::String& size_name = "size_" + plugin->buffers[i]->name; - stack_ - .func_call("size", DocUtils::ToDeclare("size_t", size_name), - DocUtils::ToIndex("buffer_metas_", i)) - .call_arg(false) - .assign("offset", "offset + batch_size * " + size_name); - } - } - compute_args.push_back("meta_attr_"); - compute_args.push_back("stream"); - CodeGenSafeCall(plugin->externs["cuda_compute"], compute_args); - stack_.cond_end(); - } -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("msc.plugin.GetTensorRTPluginSources", - [](const ffi::String& codegen_config, const ffi::String& print_config, - const ffi::String& codegen_type) -> ffi::Map { - TensorRTPluginCodeGen codegen = TensorRTPluginCodeGen(codegen_config); - if (codegen_type == "build") { - return codegen.GetBuildSources(print_config); - } - if (codegen_type == "manager") { - return codegen.GetManagerSources(print_config); - } - return ffi::Map(); - }); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/plugin/tensorrt_codegen.h b/src/contrib/msc/plugin/tensorrt_codegen.h deleted file mode 100644 index 839a5f0927c6..000000000000 --- a/src/contrib/msc/plugin/tensorrt_codegen.h +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/plugin/tensorrt_codegen.h - * \brief Codegen for tensorrt plugin. - */ -#ifndef TVM_CONTRIB_MSC_PLUGIN_TENSORRT_CODEGEN_H_ -#define TVM_CONTRIB_MSC_PLUGIN_TENSORRT_CODEGEN_H_ - -#include -#include - -#include "base_codegen.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief CodeGen config for tensorrt plugin - */ -struct TensorRTPluginCodeGenConfig { - std::string tensorrt_root{"/usr/local/cuda"}; - PLUGIN_CODEGEN_CONFIG_MEMBERS - void Load(ffi::json::Object obj) { - if (auto it = obj.find(ffi::String("tensorrt_root")); it != obj.end()) { - tensorrt_root = std::string((*it).second.cast()); - } - PLUGIN_CODEGEN_CONFIG_PARSE - } -}; - -class TensorRTPluginCodeGen : public BasePluginCodeGen { - public: - /*! - * \brief The constructor of TensorRTPluginCodeGen - * \param config the options for codegen. - */ - explicit TensorRTPluginCodeGen(const std::string& config = "") - : BasePluginCodeGen(config) {} - - protected: - /*! \brief Codegen plugin attr declare*/ - void CodeGenAttrDeclare(const Plugin& plugin) final; - - /*! \brief Codegen plugin attr define*/ - void CodeGenAttrDefine(const Plugin& plugin) final; - - /*! \brief Header of plugin files*/ - void CodeGenOpHeader(const Plugin& plugin) final; - - /*! \brief Codegen plugin op declare*/ - void CodeGenOpDeclare(const Plugin& plugin) final; - - /*! \brief Codegen plugin op define*/ - void CodeGenOpDefine(const Plugin& plugin) final; - - /*! \brief Codegen CMake file*/ - void CodeGenCmake(const std::set& devices) final; - - /*! \brief Codegen manager methods*/ - void CodeGenManagerMethods() final; - - private: - /*! \brief Op class name of plugin*/ - const ffi::String OpCls(const Plugin& plugin, bool dynamic) const { - return plugin->name + (dynamic ? "DynamicPlugin" : "Plugin"); - } - - /*! \brief Creator class name of plugin*/ - const ffi::String CreatorCls(const Plugin& plugin, bool dynamic) const { - return plugin->name + (dynamic ? "DynamicCreator" : "Creator"); - } - - bool IsMixPrecision(const Plugin& plugin) { - for (const auto& dtypes : GetDtypeMatrix(plugin)) { - ffi::String ref_dtype = ""; - for (const auto& pair : dtypes) { - if (ref_dtype.size() == 0) { - ref_dtype = pair.second; - } else if (ref_dtype != pair.second) { - return true; - } - } - } - return false; - } - - /*! \brief codegen plugin op common methods declare*/ - void CodegenOpCommonMethods(const Plugin& plugin, bool dynamic, bool in_declare); - - /*! \brief codegen plugin op members define*/ - void CodegenOpMembers(const Plugin& plugin, bool dynamic); - - /*! \brief codegen plugin creator*/ - void CodegenCreator(const Plugin& plugin, bool dynamic, bool in_declare); - - /*! \brief codegen infer output func*/ - void CodegenOutputInfer(const Plugin& plugin, bool as_desc = false); - - /*! \brief codegen infer buffer func*/ - void CodegenBufferInfer(const Plugin& plugin); - - /*! \brief codegen enqueue func*/ - void CodegenEnqueue(const Plugin& plugin, bool dynamic); -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_PLUGIN_TENSORRT_CODEGEN_H_ diff --git a/src/contrib/msc/plugin/torch_codegen.cc b/src/contrib/msc/plugin/torch_codegen.cc deleted file mode 100644 index f43fd1c1a6b3..000000000000 --- a/src/contrib/msc/plugin/torch_codegen.cc +++ /dev/null @@ -1,517 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/plugin/torch_codegen.cc - */ -#include "torch_codegen.h" - -#include - -namespace tvm { -namespace contrib { -namespace msc { - -void TorchPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) { - BasePluginCodeGen::CodeGenAttrDeclare(plugin); - const auto& attr_name = MetaAttrCls(plugin); - // serialize method for attr - stack_.comment("serialize method") - .func_def(attr_name + "_serialize", "std::vector") - .func_arg("meta_attr", "const " + attr_name + "&"); - // deserialize method for attr - stack_.comment("deserialize method") - .func_def(attr_name + "_deserialize") - .func_arg("attrs", "const std::vector&") - .func_arg("meta_attr", attr_name + "&"); -} - -void TorchPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { - const auto& attr_name = MetaAttrCls(plugin); - // serialize method for attr - stack_.func_def(attr_name + "_serialize", "std::vector") - .func_arg("meta_attr", "const " + attr_name + "&") - .func_start() - .declare("std::vector", "attrs"); - for (const auto& a : plugin->attrs) { - stack_.func_call("push_back", "", "attrs") - .inplace_start("SerializeUtils::ToString") - .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)) - .inplace_end(); - } - stack_.func_end("attrs"); - // deserialize method for attr - stack_.func_def(attr_name + "_deserialize") - .func_arg("attrs", "const std::vector&") - .func_arg("meta_attr", attr_name + "&") - .func_start(); - for (size_t i = 0; i < plugin->attrs.size(); i++) { - stack_.func_call("SerializeUtils::FromString") - .call_arg(DocUtils::ToIndex("attrs", i)) - .call_arg(DocUtils::ToAttrAccess("meta_attr", plugin->attrs[i]->name)); - } - stack_.func_end(); -} - -void TorchPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) { - stack_.struct_start(plugin->name + " : torch::CustomClassHolder"); - // constructor - stack_.constructor_def(plugin->name).constructor_arg("attrs", "const std::vector&"); - // serialize method - stack_.comment("serialize method").func_def("serialize", "const std::vector"); - // compute method - stack_.comment("main compute") - .func_def("compute", "std::vector") - .func_arg("input_tensors", "const std::vector&"); - // members - stack_.comment("members") - .declare(MetaAttrCls(plugin), "meta_attr_") - .declare("std::vector", "layouts_") - .declare("std::string", "name_"); - stack_.struct_end(); - // entry method - stack_.comment("Entry method for plugin " + plugin->name) - .func_def(EntryName(plugin), "std::vector") - .func_arg("instance", "const c10::intrusive_ptr<" + plugin->name + ">&"); - for (const auto& input : plugin->inputs) { - stack_.func_arg(input->name, "const torch::Tensor&"); - } - for (const auto& a : plugin->attrs) { - stack_.func_arg(a->name, "const " + ToTorchType(a->type) + "&"); - } - stack_.func_arg("name", "const std::string&"); -} - -void TorchPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { - const auto& attr_name = MetaAttrCls(plugin); - // define constructor - stack_.constructor_def(plugin->name + "::" + plugin->name) - .constructor_arg("attrs", "const std::vector&") - .constructor_start() - .comment("get attributes") - .func_call(attr_name + "_deserialize") - .call_arg("attrs") - .call_arg("meta_attr_") - .comment("get extra info") - .assign("name_", DocUtils::ToIndex("attrs", plugin->attrs.size())) - .for_start("i", 1 + plugin->attrs.size(), 1 + plugin->attrs.size() + plugin->inputs.size()) - .func_call("push_back", "", "layouts_") - .inplace_start("MetaLayout") - .call_arg(DocUtils::ToIndex("attrs", "i")) - .inplace_end() - .for_end() - .constructor_end(); - // define serialize - stack_.func_def(plugin->name + "::serialize", "const std::vector") - .func_start() - .assign("attrs", attr_name + "_serialize(meta_attr_)", "std::vector") - .func_call("push_back", "", "attrs") - .call_arg("name_") - .for_start("i", 0, plugin->inputs.size()) - .func_call("push_back", "", "attrs") - .call_arg(DocUtils::ToAttrAccess(DocUtils::ToIndex("layouts_", "i"), "name()")) - .for_end() - .func_end("attrs"); - // compute method - stack_.func_def(plugin->name + "::compute", "std::vector") - .func_arg("input_tensors", "const std::vector&") - .func_start() - .declare("std::vector", "output_tensors"); - if (plugin->externs.count("infer_buffer")) { - stack_.declare("std::vector", "buffer_tensors"); - } - stack_.line() - .comment("extract meta inputs") - .declare("std::vector", "input_metas") - .for_start("i", 0, plugin->inputs.size()) - .func_call("push_back", "", "input_metas") - .inplace_start("TorchUtils::ToMetaTensor") - .call_arg(DocUtils::ToIndex("input_tensors", "i")) - .call_arg(DocUtils::ToIndex("layouts_", "i")) - .inplace_end() - .for_end(); - // malloc outputs and buffers - TVM_FFI_ICHECK(plugin->externs.count("infer_output")) << "Can not find extern shape"; - CodeGenMalloc(plugin, plugin->outputs, "output"); - if (plugin->externs.count("infer_buffer")) { - CodeGenMalloc(plugin, plugin->buffers, "buffer"); - } - // do the compute - ffi::String device_cond = ""; - for (size_t i = 0; i < plugin->inputs.size(); i++) { - if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") { - device_cond = device_cond + "input_tensors[" + std::to_string(i) + "].is_cuda()"; - } else { - device_cond = device_cond + "!input_tensors[" + std::to_string(i) + "].is_cuda()"; - } - device_cond = device_cond + (i == plugin->inputs.size() - 1 ? "" : " && "); - } - stack_.line().comment("do the compute").cond_if(device_cond); - CodeGenCompute(plugin, "cuda"); - stack_.cond_else(); - CodeGenCompute(plugin, "cpu"); - stack_.cond_end(); - stack_.func_end("output_tensors"); - - // register op - const auto& entry_name = EntryName(plugin); - stack_.func_def(entry_name, "std::vector") - .func_arg("instance", "const c10::intrusive_ptr<" + plugin->name + ">&"); - for (const auto& input : plugin->inputs) { - stack_.func_arg(input->name, "const torch::Tensor&"); - } - for (const auto& a : plugin->attrs) { - stack_.func_arg(a->name, "const " + ToTorchType(a->type) + "&"); - } - stack_.func_arg("name", "const std::string&"); - stack_.func_start().declare("std::vector", "inputs", 0, false); - for (const auto& input : plugin->inputs) { - stack_.declare_arg(input->name); - } - const auto& outputs_doc = DocUtils::ToDeclare("std::vector", "outputs"); - stack_.func_call("compute", outputs_doc, DocUtils::ToPtr("instance")).call_arg("inputs"); - stack_.func_end("outputs"); - stack_.comment("Bind plugin " + plugin->name + " to python") - .func_def("TORCH_LIBRARY", DocSymbol::Empty()) - .func_arg(plugin->name, DocSymbol::Empty()) - .func_arg("m", DocSymbol::Empty()) - .func_start() - .lambda_def("serialize") - .lambda_arg("op", "const c10::intrusive_ptr<" + plugin->name + ">&") - .lambda_start() - .lambda_end(DocUtils::ToAttrAccess(DocUtils::ToPtr("op"), "serialize()")) - .lambda_def("deserialize") - .lambda_arg("state", "std::vector") - .lambda_start() - .lambda_end("c10::make_intrusive<" + plugin->name + ">(std::move(state))") - .func_call("class_<" + plugin->name + ">", "", "m") - .call_arg(DocUtils::ToStr(plugin->name)) - .method_call("def", true) - .call_arg("torch::init>()") - .method_call("def", true) - .call_arg(DocUtils::ToStr("compute")) - .call_arg("&" + plugin->name + "::compute") - .method_call("def_pickle", true) - .call_arg("serialize") - .call_arg("deserialize") - .func_call("def", "", "m") - .call_arg(DocUtils::ToStr(entry_name)) - .call_arg(entry_name) - .func_end(); -} - -void TorchPluginCodeGen::CodeGenCmake(const std::set& devices) { - ffi::Map flags; - flags.Set("PLUGIN_SUPPORT_TORCH", ""); - CodeGenPreCmake(devices, flags); - stack_.line() - .line("set(CMAKE_CXX_STANDARD 17)") - .line("list(APPEND CMAKE_PREFIX_PATH \"" + config()->torch_prefix + "\")") - .line("find_package(Torch REQUIRED)"); - ffi::Array includes, libs; - libs.push_back("${TORCH_LIBRARIES}"); - CodeGenPostCmake(devices, includes, libs); -} - -void TorchPluginCodeGen::CodeGenManagerDepends() { - BasePluginCodeGen::CodeGenManagerDepends(); - stack_.line("import torch") - .line() - .func_def("to_string", "str") - .func_arg("value", "Any") - .func_start() - .switch_start("isinstance(value, (list, tuple))") - .assign("str_value", "\",\".join([str(len(value))] + [to_string(v) for v in value])") - .switch_case("isinstance(value, bool)") - .assign("str_value", "\"1\" if value else \"0\"") - .switch_case() - .assign("str_value", "str(value)") - .switch_end() - .func_end("str_value"); -} - -void TorchPluginCodeGen::CodeGenManagerMethods() { - BasePluginCodeGen::CodeGenManagerMethods(); - // libs_loaded method - stack_.func_def("libs_loaded") - .func_arg("self", "object") - .func_start() - .assign("loaded_libs", "set()") - .assign("loaded", DocUtils::ToDoc(false)) - .for_start("lib", "torch.classes.loaded_libraries") - .func_call("add", "", "loaded_libs") - .inplace_start("os.path.basename") - .call_arg("lib") - .inplace_end() - .for_end() - .for_start("lib", "os.listdir(self._lib_folder)") - .cond_if("lib in loaded_libs") - .assign("loaded", DocUtils::ToDoc(true)) - .line("break") - .cond_end() - .for_end() - .func_end("loaded"); - // setup method - stack_.func_def("setup") - .func_arg("self", "object") - .func_start() - .for_start("lib", "os.listdir(self._lib_folder)") - .assign("lib_file", "os.path.join(self._lib_folder, lib)") - .cond_if("\"" + config()->project_name + "\" in lib") - .func_call("load_library", "", "torch.classes") - .call_arg("lib_file") - .cond_else() - .func_call("CDLL", "", "ctypes") - .call_arg("lib_file") - .cond_end() - .for_end() - .func_end(); -} - -void TorchPluginCodeGen::CodeGenOpBuilder(const Plugin& plugin) { - const auto& entry_name = EntryName(plugin); - stack_.func_def(plugin->name).func_arg("self", "object"); - for (const auto& attr : plugin->attrs) { - stack_.func_arg(attr->name, attr->type, attr->default_value); - } - stack_.func_arg("name", "str", "\"" + plugin->name + "\"") - .func_arg("layouts", "List[str]", "None") - .func_start() - .class_def(plugin->name + "(torch.nn.Module)") - .class_start(); - // init method - stack_.func_def("__init__").func_arg("self", "torch.nn.Module"); - for (const auto& attr : plugin->attrs) { - stack_.func_arg(attr->name, attr->type, attr->default_value); - } - stack_.func_arg("name", "str", "\"" + plugin->name + "\"") - .func_arg("layouts", "List[str]", "None") - .func_start() - .func_call("__init__", "", "super()"); - for (const auto& attr : plugin->attrs) { - stack_.assign(DocUtils::ToAttrAccess("self", attr->name), attr->name); - } - stack_.assign(DocUtils::ToAttrAccess("self", "name"), "name") - .cond_if("layouts is None") - .assign(DocUtils::ToAttrAccess("self", "layouts"), - "[\"\"] * " + std::to_string(plugin->inputs.size())) - .cond_else() - .assign(DocUtils::ToAttrAccess("self", "layouts"), "layouts") - .cond_end() - .line() - .assign("attr_strs", "[]"); - for (const auto& attr : plugin->attrs) { - stack_.func_call("append", "", "attr_strs") - .inplace_start("to_string") - .call_arg(attr->name) - .inplace_end(); - } - stack_.func_call("append", "", "attr_strs") - .call_arg("name") - .func_call("extend", "", "attr_strs") - .call_arg(DocUtils::ToAttrAccess("self", "layouts")) - .line() - .func_call(plugin->name + "." + plugin->name, "self._inner_class", "torch.classes") - .call_arg("attr_strs") - .func_end(); - // forward method - stack_.func_def("forward", "List[torch.Tensor]").func_arg("self", "torch.nn.Module"); - for (const auto& t : plugin->inputs) { - stack_.func_arg(t->name, "torch.Tensor"); - } - stack_.func_start() - .func_call(plugin->name + "." + entry_name, "outputs", "torch.ops") - .call_arg("self._inner_class"); - for (const auto& t : plugin->inputs) { - stack_.call_arg(t->name); - } - for (const auto& a : plugin->attrs) { - stack_.call_arg(DocUtils::ToAttrAccess("self", a->name)); - } - stack_.call_arg(DocUtils::ToAttrAccess("self", "name")); - if (plugin->outputs.size() == 1) { - stack_.func_end(DocUtils::ToIndex("outputs", 0)); - } else { - stack_.func_end("outputs"); - } - // end of inner class - stack_.class_end(); - stack_.func_call(plugin->name, "op"); - for (const auto& attr : plugin->attrs) { - stack_.call_arg(attr->name); - } - stack_.call_arg("name").call_arg("layouts").func_end("op").comment(GetPyComment(plugin), true); -} - -void TorchPluginCodeGen::CodeGenConvertDepends() { - BasePluginCodeGen::CodeGenConvertDepends(); - stack_.line("from torch import fx") - .line("from tvm.relax.frontend.torch.fx_translator import TorchFXImporter") - .line(); -} - -const ffi::String TorchPluginCodeGen::CodeGenOpConvert(const Plugin& plugin) { - stack_.func_def(ConverterName(plugin), "relax.Var") - .func_arg("node", "fx.node.Node") - .func_arg("ctx", "TorchFXImporter") - .func_start() - .func_call("retrieve_args", "args", "ctx") - .call_arg("node"); - ffi::Array args; - for (size_t i = 0; i < plugin->inputs.size(); i++) { - const auto& tensor = plugin->inputs[i]; - stack_.assign(tensor->name, DocUtils::ToIndex("args", i + 1)); - args.push_back(tensor->name); - } - for (size_t i = 0; i < plugin->attrs.size(); i++) { - const auto& attr = plugin->attrs[i]; - stack_.func_call("plugin_utils.to_expr", attr->name) - .call_arg(DocUtils::ToIndex("args", i + plugin->inputs.size() + 1)); - args.push_back(attr->name); - } - stack_.assign("name", - DocUtils::ToIndex("args", 1 + plugin->inputs.size() + plugin->attrs.size())); - stack_.func_call("relax.Tuple", "args") - .call_arg(DocUtils::ToList(args)) - .func_call("InferStructInfo" + plugin->name, "out_sinfo", "_plugin_api"); - for (const auto& t : plugin->inputs) { - stack_.call_arg(t->name); - } - for (const auto& a : plugin->attrs) { - stack_.call_arg(a->name); - } - stack_.func_call("call_dps_packed", "op") - .call_arg(DocUtils::ToStr(plugin->name)) - .call_arg("args", "args") - .call_arg("list(out_sinfo)", "out_sinfo") - .func_call("msc_utils.set_expr_name", "op") - .call_arg("op") - .call_arg("name") - .func_call("emit", "var", "ctx.block_builder") - .call_arg("op") - .call_arg("name"); - if (plugin->outputs.size() == 1) { - stack_.func_end(DocUtils::ToList(ffi::Array{"var"})); - } else { - ffi::Array outputs; - for (size_t i = 0; i < plugin->outputs.size(); i++) { - const auto& tensor = plugin->outputs[i]; - stack_.func_call("relax.TupleGetItem", tensor->name).call_arg("var").call_arg(i); - outputs.push_back(tensor->name); - } - stack_.func_end(DocUtils::ToList(outputs)); - } - return EntryName(plugin); -} - -void TorchPluginCodeGen::CodeGenMalloc(const Plugin& plugin, - const ffi::Array& tensors, - const ffi::String& collect) { - ffi::Array call_args{"input_metas", "meta_attr_", "true"}; - stack_.line().comment("malloc " + collect).declare("std::vector", collect + "_metas"); - CodeGenSafeCall(plugin->externs["infer_" + collect], call_args, collect + "_metas"); - for (size_t i = 0; i < tensors.size(); i++) { - stack_.func_call("push_back", "", collect + "_tensors") - .inplace_start("TorchUtils::MallocTorchTensor") - .call_arg(DocUtils::ToIndex(collect + "_metas", i)); - int device_idx = plugin->FindDeviceRefIdx(tensors[i]); - if (device_idx >= 0) { - const auto& input_doc = DocUtils::ToIndex("input_tensors", device_idx); - stack_.inplace_start("device", std::nullopt, input_doc).inplace_end(); - } else { - stack_.inplace_start("TorchUtils::ToTorchDevice") - .call_arg(DocUtils::ToStr(tensors[i]->device)) - .inplace_end(); - } - stack_.inplace_end(); - } -} - -void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& device) { - auto prepare_tensor = [this](const PluginTensor& tensor, - const ffi::Map& dtypes, size_t idx, - const ffi::String& collect) { - const ffi::String& t_name = "d_" + tensor->name; - const ffi::String& t_dtype = dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; - const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">"; - const ffi::String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; - stack_.func_call("TorchUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)) - .call_arg(DocUtils::ToIndex(collect + "_tensors", idx)) - .call_arg(DocUtils::ToIndex(collect + "_metas", idx)) - .call_arg(collect == "input"); - return t_name; - }; - - if (plugin->externs.count(device + "_compute")) { - for (const auto& dtypes : GetDtypeMatrix(plugin)) { - const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); - ffi::Array compute_args; - ffi::String dtype_cond = ""; - for (size_t i = 0; i < plugin->inputs.size(); i++) { - dtype_cond = dtype_cond + "input_metas[" + std::to_string(i) + - "].data_type() == DataUtils::ToMetaType(\"" + dtypes.at(i) + "\")"; - dtype_cond = dtype_cond + (i == plugin->inputs.size() - 1 ? "" : " && "); - } - // prepare compute datas - stack_.cond_if(dtype_cond).comment("prepare compute datas"); - for (size_t i = 0; i < plugin->inputs.size(); i++) { - const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); - compute_args.push_back(t_name); - } - for (size_t i = 0; i < plugin->outputs.size(); i++) { - const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); - compute_args.push_back(t_name); - } - for (size_t i = 0; i < plugin->buffers.size(); i++) { - const ffi::String& t_name = prepare_tensor(plugin->buffers[i], tensor_dtypes, i, "buffer"); - compute_args.push_back(t_name); - } - compute_args.push_back("meta_attr_"); - if (device == "cuda") { - stack_.func_call("at::cuda::getCurrentCUDAStream", - DocUtils::ToDeclare("cudaStream_t", "stream")); - compute_args.push_back("stream"); - } - CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args); - stack_.cond_end(); - } - } else { - stack_.comment("Skip compute on " + device); - } -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("msc.plugin.GetTorchPluginSources", - [](const ffi::String& codegen_config, const ffi::String& print_config, - const ffi::String& codegen_type) -> ffi::Map { - TorchPluginCodeGen codegen = TorchPluginCodeGen(codegen_config); - if (codegen_type == "build") { - return codegen.GetBuildSources(print_config); - } - if (codegen_type == "manager") { - return codegen.GetManagerSources(print_config); - } - return ffi::Map(); - }); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/plugin/torch_codegen.h b/src/contrib/msc/plugin/torch_codegen.h deleted file mode 100644 index 35bf16737dda..000000000000 --- a/src/contrib/msc/plugin/torch_codegen.h +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/plugin/torch_codegen.h - * \brief Codegen for torch plugin. - */ -#ifndef TVM_CONTRIB_MSC_PLUGIN_TORCH_CODEGEN_H_ -#define TVM_CONTRIB_MSC_PLUGIN_TORCH_CODEGEN_H_ - -#include -#include - -#include "base_codegen.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief CodeGen config for torch plugin - */ -struct TorchPluginCodeGenConfig { - bool is_training{false}; - std::string torch_prefix{"torch"}; - PLUGIN_CODEGEN_CONFIG_MEMBERS - void Load(ffi::json::Object obj) { - if (auto it = obj.find(ffi::String("is_training")); it != obj.end()) { - is_training = (*it).second.cast(); - } - if (auto it = obj.find(ffi::String("torch_prefix")); it != obj.end()) { - torch_prefix = std::string((*it).second.cast()); - } - PLUGIN_CODEGEN_CONFIG_PARSE - } -}; - -class TorchPluginCodeGen : public BasePluginCodeGen { - public: - /*! - * \brief The constructor of TorchPluginCodeGen - * \param config the options for codegen. - */ - explicit TorchPluginCodeGen(const std::string& config = "") - : BasePluginCodeGen(config) {} - - protected: - /*! \brief Codegen plugin attr declare*/ - void CodeGenAttrDeclare(const Plugin& plugin) final; - - /*! \brief Codegen plugin attr define*/ - void CodeGenAttrDefine(const Plugin& plugin) final; - - /*! \brief Codegen plugin op declare*/ - void CodeGenOpDeclare(const Plugin& plugin) final; - - /*! \brief Codegen plugin op define*/ - void CodeGenOpDefine(const Plugin& plugin) final; - - /*! \brief Codegen CMake file*/ - void CodeGenCmake(const std::set& devices) final; - - /*! \brief Codegen manager depends*/ - void CodeGenManagerDepends() final; - - /*! \brief Codegen manager methods*/ - void CodeGenManagerMethods() final; - - /*! \brief Codegen manager member for plugin*/ - void CodeGenOpBuilder(const Plugin& plugin) final; - - /*! \brief Codegen convert depends*/ - void CodeGenConvertDepends() final; - - /*! \brief Codegen convert function for plugin*/ - const ffi::String CodeGenOpConvert(const Plugin& plugin) final; - - private: - /*! \brief Codegen malloc for outputs/buffers*/ - void CodeGenMalloc(const Plugin& plugin, const ffi::Array& tensors, - const ffi::String& collect); - - /*! \brief Codegen compute*/ - void CodeGenCompute(const Plugin& plugin, const ffi::String& device); - - /*! \brief Entry name of torch function*/ - const ffi::String EntryName(const Plugin& plugin) { - std::string lower_name; - const std::string& name = std::string(plugin->name); - for (size_t i = 0; i < name.size(); i++) { - const char& lower_c = tolower(name[i]); - if (lower_c != name[i] && i > 0) { - lower_name += "_"; - } - lower_name += lower_c; - } - return lower_name + "_entry"; - } - - /*! \brief Type name in torch*/ - const ffi::String ToTorchType(const ffi::String& type) { - if (type == "float") { - return "double"; - } - if (IsListType(type)) { - const auto& ele_type = GetEleType(type); - return "c10::arrayRef<" + ToTorchType(ele_type) + ">"; - } - return BasePluginCodeGen::ToCppType(type); - } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_PLUGIN_TORCH_CODEGEN_H_ diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc deleted file mode 100644 index f33d45e16eed..000000000000 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ /dev/null @@ -1,417 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/plugin/tvm_codegen.cc - */ -#include "tvm_codegen.h" - -#include - -namespace tvm { -namespace contrib { -namespace msc { - -void TVMPluginCodeGen::CodeGenAttrDeclare(const Plugin& plugin) { - BasePluginCodeGen::CodeGenAttrDeclare(plugin); - const auto& attr_name = MetaAttrCls(plugin); - // exprs to meta_attr - stack_.comment("convert exprs to meta attrs method") - .func_def(attr_name + "_from_exprs", "const " + attr_name); - for (const auto& a : plugin->attrs) { - const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; - stack_.func_arg(a->name, "const " + anno + "&"); - } - // args to meta_attr - stack_.comment("convert args to meta attrs method") - .func_def(attr_name + "_from_args", "const " + attr_name) - .func_arg("args", "ffi::PackedArgs") - .func_arg("pos", "size_t&"); -} - -void TVMPluginCodeGen::CodeGenAttrDefine(const Plugin& plugin) { - const auto& attr_name = MetaAttrCls(plugin); - // exprs to meta_attr - stack_.func_def(attr_name + "_from_exprs", "const " + attr_name); - for (const auto& a : plugin->attrs) { - const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; - stack_.func_arg(a->name, "const " + anno + "&"); - } - stack_.func_start().declare(attr_name, "meta_attr"); - for (const auto& a : plugin->attrs) { - const ffi::String& convert = IsListType(a->type) ? "AttrFromPrims" : "AttrFromPrim"; - stack_.func_call("TVMUtils::" + convert) - .call_arg(a->name) - .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); - } - stack_.func_end("meta_attr"); - // args to meta_attr - stack_.comment("convert args to meta attrs method") - .func_def(attr_name + "_from_args", "const " + attr_name) - .func_arg("args", "ffi::PackedArgs") - .func_arg("pos", "size_t&") - .func_start() - .declare(attr_name, "meta_attr"); - for (const auto& a : plugin->attrs) { - if (IsListType(a->type)) { - // TODO(meng.tong): support list atribute - LOG_FATAL << "ListType argument is not supported for tvm runtime"; - stack_.func_call("TVMUtils::AttrFromArg", a->name + "_size") - .call_arg(DocUtils::ToIndex("args", "pos")) - .func_call("TVMUtils::AttrFromArgs") - .call_arg("args") - .call_arg("pos") - .call_arg(a->name + "_size") - .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)) - .assign("pos", "pos + 1 + " + a->name + "_size"); - } else { - stack_.func_call("TVMUtils::AttrFromArg") - .call_arg(DocUtils::ToIndex("args", "pos")) - .call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)) - .assign("pos", "pos + 1"); - } - } - stack_.func_end("meta_attr"); -} - -void TVMPluginCodeGen::CodeGenOpDeclare(const Plugin& plugin) { - // infer struct info - stack_.func_def("InferStructInfo" + plugin->name, "ffi::Array"); - for (const auto& t : plugin->inputs) { - stack_.func_arg(t->name, "const Expr&"); - } - for (const auto& a : plugin->attrs) { - const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; - stack_.func_arg(a->name, "const " + anno + "&"); - } - // infer layout - stack_.func_def("InferLayout" + plugin->name, "InferLayoutOutput") - .func_arg("inputs", "const ffi::Array&") - .func_arg("var_layout_map", "const VarLayoutMap&"); -} - -void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { - const auto& attr_name = MetaAttrCls(plugin); - // infer struct info - ffi::Array infer_args{"input_metas", "meta_attr", "false"}; - stack_.func_def("InferStructInfo" + plugin->name, "ffi::Array"); - for (const auto& t : plugin->inputs) { - stack_.func_arg(t->name, "const Expr&"); - } - for (const auto& a : plugin->attrs) { - const ffi::String& anno = IsListType(a->type) ? "Tuple" : "PrimValue"; - stack_.func_arg(a->name, "const " + anno + "&"); - } - stack_.func_start() - .comment("extract meta attrs") - .func_call(attr_name + "_from_exprs", DocUtils::ToDeclare("const auto&", "meta_attr")); - for (const auto& a : plugin->attrs) { - stack_.call_arg(a->name); - } - stack_.comment("extract meta inputs").declare("std::vector", "input_metas"); - for (const auto& t : plugin->inputs) { - stack_.func_call("push_back", "", "input_metas") - .inplace_start("TVMUtils::ToMetaTensor") - .call_arg(t->name) - .inplace_end(); - } - stack_.declare("std::vector", "output_metas"); - CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas"); - stack_.declare("ffi::Array", "output_sinfo"); - for (size_t i = 0; i < plugin->outputs.size(); i++) { - stack_.func_call("push_back", "", "output_sinfo") - .inplace_start("TVMUtils::ToTensorStructInfo") - .call_arg(DocUtils::ToIndex("output_metas", i)); - int device_idx = plugin->FindDeviceRefIdx(plugin->outputs[i]); - if (device_idx >= 0) { - stack_.call_arg(plugin->inputs[device_idx]->name); - } else { - stack_.inplace_start("TVMUtils::ToTVMDevice") - .call_arg(plugin->outputs[i]->device) - .inplace_end(); - } - stack_.inplace_end(); - } - stack_.func_end("output_sinfo"); - - // infer layout - stack_.func_def("InferLayout" + plugin->name, "InferLayoutOutput") - .func_arg("inputs", "const ffi::Array&") - .func_arg("var_layout_map", "const VarLayoutMap&") - .func_start() - .comment("define attrs"); - for (size_t i = 0; i < plugin->attrs.size(); i++) { - const auto& attr = plugin->attrs[i]; - const ffi::String& anno = IsListType(attr->type) ? "Tuple" : "PrimValue"; - stack_ - .func_call("Downcast<" + anno + ">", - DocUtils::ToDeclare("const auto&", "attr_" + attr->name)) - .call_arg(DocUtils::ToIndex("inputs", i + plugin->inputs.size())); - } - stack_.declare("ffi::Array", "arg_layouts") - .declare("ffi::Array", "output_layouts") - .comment("extract meta attrs") - .func_call(attr_name + "_from_exprs", "const " + attr_name + "& meta_attr"); - for (const auto& a : plugin->attrs) { - stack_.call_arg("attr_" + a->name); - } - stack_.comment("extract meta inputs") - .declare("std::vector", "input_metas") - .for_start("i", 0, plugin->inputs.size()) - .func_call("LayoutUtils::InferLayoutDecision", - DocUtils::ToDeclare("const auto&", "in_layout")) - .call_arg(DocUtils::ToIndex("inputs", "i")) - .call_arg("var_layout_map") - .func_call("push_back", "", "arg_layouts") - .call_arg("in_layout") - .func_call("push_back", "", "input_metas") - .inplace_start("TVMUtils::ToMetaTensor") - .call_arg(DocUtils::ToIndex("inputs", "i")) - .call_arg("in_layout") - .inplace_end() - .for_end() - .comment("add fake layout for attrs") - .for_start("i", 0, plugin->attrs.size()) - .func_call("push_back", "", "arg_layouts") - .inplace_start("LayoutDecision") - .call_arg(DocUtils::ToStr("")) - .inplace_end() - .for_end(); - stack_.declare("std::vector", "output_metas"); - CodeGenSafeCall(plugin->externs["infer_output"], infer_args, "output_metas"); - stack_.for_start("i", 0, plugin->outputs.size()) - .func_call("push_back", "", "output_layouts") - .inplace_start("LayoutDecision") - .call_arg(DocUtils::ToAttrAccess(DocUtils::ToIndex("output_metas", "i"), "layout_name()")) - .inplace_end() - .for_end() - .declare("ffi::Array", "input_layouts") - .func_call("push_back", "", "input_layouts") - .inplace_start("LayoutDecision") - .call_arg(DocUtils::ToStr("")) - .inplace_end() - .func_call("push_back", "", "input_layouts") - .call_arg("arg_layouts") - .func_call("InferLayoutOutput", DocUtils::ToDeclare("const auto&", "infer_output")) - .call_arg("input_layouts") - .call_arg("output_layouts") - .call_arg("Attrs()"); - stack_.func_end("infer_output"); - - // register funcs - stack_.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF") - .call_arg(DocUtils::ToStr("msc.plugin.op.InferStructInfo" + plugin->name)) - .call_arg("InferStructInfo" + plugin->name) - .line() - .func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF") - .call_arg(DocUtils::ToStr("msc.plugin.op.InferLayout" + plugin->name)) - .call_arg("InferLayout" + plugin->name) - .line(); -} - -void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { - TVM_FFI_ICHECK(!plugin->externs.count("infer_buffer")) - << "infer_buffer is not supported for tvm runtime"; - const auto& attr_name = MetaAttrCls(plugin); - const auto& func_name = ComputeName(plugin); - ffi::String device_cond = ""; - ffi::String device_index = ""; - for (size_t i = 0; i < plugin->inputs.size(); i++) { - ffi::String device_type = ""; - if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") { - device_type = "DLDeviceType::kDLCUDA"; - } else { - device_type = "DLDeviceType::kDLCPU"; - } - device_cond = device_cond + "TVMUtils::OnDevice(" + plugin->inputs[i]->name + ", " + - device_type + ")" + (i == plugin->inputs.size() - 1 ? "" : " && "); - } - stack_.func_def(func_name).func_arg("args", "ffi::PackedArgs").func_arg("ret", "ffi::Any*"); - stack_.func_start().comment("define tensors"); - for (size_t i = 0; i < plugin->inputs.size(); i++) { - stack_.assign(plugin->inputs[i]->name, DocUtils::ToIndex("args", i), "DLTensor*"); - } - stack_.comment("extract meta attrs") - .assign("pos", plugin->inputs.size(), "size_t") - .func_call(attr_name + "_from_args", "const " + attr_name + "& meta_attr") - .call_arg("args") - .call_arg("pos"); - for (size_t i = 0; i < plugin->outputs.size(); i++) { - stack_.assign(plugin->outputs[i]->name, DocUtils::ToIndex("args", "pos + " + std::to_string(i)), - "DLTensor*"); - } - stack_.comment("do the compute").cond_if(device_cond); - CodeGenCompute(plugin, "cuda"); - stack_.cond_else(); - CodeGenCompute(plugin, "cpu"); - stack_.cond_end().func_end(); - // register the compute - stack_.func_call("TVM_MSC_PLUGIN_REGISTER_GLOBAL_DEF_PACKED") - .call_arg(DocUtils::ToStr(plugin->name)) - .call_arg(func_name) - .line(); -} - -void TVMPluginCodeGen::CodeGenCmake(const std::set& devices) { - ffi::Map flags; - flags.Set("PLUGIN_SUPPORT_TVM", ""); - CodeGenPreCmake(devices, flags); - stack_.line("set(CMAKE_CXX_STANDARD 17)") - .line("set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -Wno-macro-redefined\")") - .line() - .line("set(TVM_ROOT " + config()->tvm_root + ")") - .line("find_library(TVM_LIB NAMES tvm HINTS ${TVM_ROOT}/build NO_DEFAULT_PATH)"); - ffi::Array includes, libs; - includes.push_back("${TVM_ROOT}/include"); - includes.push_back("${TVM_ROOT}/3rdparty/dlpack/include"); - includes.push_back("${TVM_ROOT}/3rdparty/compiler-rt"); - libs.push_back("${TVM_LIB}"); - CodeGenPostCmake(devices, includes, libs); -} - -void TVMPluginCodeGen::CodeGenManagerDepends() { - BasePluginCodeGen::CodeGenManagerDepends(); - stack_.line("from tvm import relax") - .line("from tvm.relax import call_dps_packed") - .line("from tvm.contrib.msc.plugin import utils as plugin_utils") - .line("from tvm.contrib.msc.core import utils as msc_utils") - .line(); -} - -void TVMPluginCodeGen::CodeGenManagerMethods() { - BasePluginCodeGen::CodeGenManagerMethods(); - stack_.func_def("setup") - .func_arg("self", "object") - .func_start() - .for_start("lib", "os.listdir(self._lib_folder)") - .assign("lib_file", "os.path.join(self._lib_folder, lib)") - .func_call("CDLL", "", "ctypes") - .call_arg("lib_file") - .for_end() - .line("from tvm.contrib.msc.plugin.op import _ffi_api") - .assign(DocUtils::ToAttrAccess("self", "_ffi_api"), "_ffi_api") - .func_end(); -} - -void TVMPluginCodeGen::CodeGenOpBuilder(const Plugin& plugin) { - stack_.func_def(plugin->name).func_arg("self", "object"); - for (const auto& t : plugin->inputs) { - stack_.func_arg(t->name, "relax.Expr"); - } - for (const auto& attr : plugin->attrs) { - stack_.func_arg(attr->name, ToPyType(attr->type), attr->default_value); - } - stack_.func_arg("name", "str", "\"" + plugin->name + "\"").func_start(); - ffi::Array args; - for (const auto& t : plugin->inputs) { - args.push_back(t->name); - } - for (const auto& a : plugin->attrs) { - stack_.func_call("plugin_utils.to_expr", a->name).call_arg(a->name); - args.push_back(a->name); - } - stack_.func_call("relax.Tuple", "args") - .call_arg(DocUtils::ToList(args)) - .func_call("InferStructInfo" + plugin->name, "out_sinfo", "self._ffi_api"); - for (const auto& t : plugin->inputs) { - stack_.call_arg(t->name); - } - for (const auto& a : plugin->attrs) { - stack_.call_arg(a->name); - } - stack_.func_call("call_dps_packed", "op") - .call_arg(DocUtils::ToStr(plugin->name)) - .call_arg("args", "args") - .call_arg("list(out_sinfo)", "out_sinfo") - .func_call("msc_utils.set_expr_name", "op") - .call_arg("op") - .call_arg("name"); - stack_.func_end("op").comment(GetPyComment(plugin), true); -} - -void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& device) { - if (plugin->externs.count(device + "_compute")) { - // compute with dtype - auto prepare_tensor = [this](const PluginTensor& tensor, - const ffi::Map& dtypes, size_t idx, - const ffi::String& collect) { - const ffi::String& t_name = "d_" + tensor->name; - const ffi::String& t_dtype = - dtypes.count(tensor->name) ? dtypes[tensor->name] : tensor->dtype; - const ffi::String& tensor_type = "DataTensor<" + t_dtype + ">"; - const ffi::String& anno = collect == "input" ? "const " + tensor_type + "&" : tensor_type; - stack_.func_call("TVMUtils::To" + tensor_type, DocUtils::ToDeclare(anno, t_name)) - .call_arg(tensor->name) - .call_arg(collect == "input"); - return t_name; - }; - for (const auto& dtypes : GetDtypeMatrix(plugin)) { - const auto& tensor_dtypes = GetTensorDtypes(plugin, dtypes); - ffi::Array compute_args; - ffi::String dtype_cond = ""; - for (size_t i = 0; i < plugin->inputs.size(); i++) { - const auto& t_name = plugin->inputs[i]->name; - dtype_cond = dtype_cond + "TVMUtils::ToMetaType(" + t_name + - "->dtype) == DataUtils::ToMetaType(\"" + dtypes.at(i) + "\")"; - dtype_cond = dtype_cond + (i == plugin->inputs.size() - 1 ? "" : " && "); - } - // prepare compute datas - stack_.cond_if(dtype_cond).comment("prepare compute datas"); - for (size_t i = 0; i < plugin->inputs.size(); i++) { - const ffi::String& t_name = prepare_tensor(plugin->inputs[i], tensor_dtypes, i, "input"); - compute_args.push_back(t_name); - } - for (size_t i = 0; i < plugin->outputs.size(); i++) { - const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); - compute_args.push_back(t_name); - } - TVM_FFI_ICHECK(plugin->buffers.size() == 0) << "Plugin with buffers is not supported in tvm"; - compute_args.push_back("meta_attr"); - if (device == "cuda") { - // TODO(tvm-team): update to support get stream from device id - stack_.assign("stream", "TVMFFIEnvGetStream(kDLCUDA, 0)", "auto"); - compute_args.push_back("stream"); - } - CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args); - stack_.cond_end(); - } - } else { - stack_.comment("Skip compute on " + device); - } -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("msc.plugin.GetTVMPluginSources", - [](const ffi::String& codegen_config, const ffi::String& print_config, - const ffi::String& codegen_type) -> ffi::Map { - TVMPluginCodeGen codegen = TVMPluginCodeGen(codegen_config); - if (codegen_type == "build") { - return codegen.GetBuildSources(print_config); - } - if (codegen_type == "manager") { - return codegen.GetManagerSources(print_config); - } - return ffi::Map(); - }); -} - -} // namespace msc -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/msc/plugin/tvm_codegen.h b/src/contrib/msc/plugin/tvm_codegen.h deleted file mode 100644 index 5311cd86325d..000000000000 --- a/src/contrib/msc/plugin/tvm_codegen.h +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/contrib/msc/plugin/tvm_codegen.h - * \brief Codegen for tvm plugin. - */ -#ifndef TVM_CONTRIB_MSC_PLUGIN_TVM_CODEGEN_H_ -#define TVM_CONTRIB_MSC_PLUGIN_TVM_CODEGEN_H_ - -#include -#include - -#include "base_codegen.h" -#include "codegen_utils.h" - -namespace tvm { -namespace contrib { -namespace msc { - -/*! - * \brief CodeGen config for tvm plugin - */ -struct TVMPluginCodeGenConfig { - bool as_relay{false}; - std::string tvm_root{"tvm"}; - PLUGIN_CODEGEN_CONFIG_MEMBERS - void Load(ffi::json::Object obj) { - if (auto it = obj.find(ffi::String("as_relay")); it != obj.end()) { - as_relay = (*it).second.cast(); - } - if (auto it = obj.find(ffi::String("tvm_root")); it != obj.end()) { - tvm_root = std::string((*it).second.cast()); - } - PLUGIN_CODEGEN_CONFIG_PARSE - } -}; - -class TVMPluginCodeGen : public BasePluginCodeGen { - public: - /*! - * \brief The constructor of TVMPluginCodeGen - * \param config the options for codegen. - */ - explicit TVMPluginCodeGen(const std::string& config = "") - : BasePluginCodeGen(config) {} - - protected: - /*! \brief Codegen plugin attr declare*/ - void CodeGenAttrDeclare(const Plugin& plugin) final; - - /*! \brief Codegen plugin attr define*/ - void CodeGenAttrDefine(const Plugin& plugin) final; - - /*! \brief Codegen plugin op declare*/ - void CodeGenOpDeclare(const Plugin& plugin) final; - - /*! \brief Codegen plugin op define*/ - void CodeGenOpDefine(const Plugin& plugin) final; - - /*! \brief Codegen plugin runtime*/ - void CodeGenOpRuntime(const Plugin& plugin) final; - - /*! \brief Codegen CMake file*/ - void CodeGenCmake(const std::set& devices) final; - - /*! \brief Codegen manager depends*/ - void CodeGenManagerDepends() final; - - /*! \brief Codegen manager methods*/ - void CodeGenManagerMethods() final; - - /*! \brief Codegen manager member for plugin*/ - void CodeGenOpBuilder(const Plugin& plugin) final; - - private: - /*! \brief Func name of compute*/ - const ffi::String ComputeName(const Plugin& plugin) { return plugin->name + "_compute"; } - - /*! \brief Codegen compute*/ - void CodeGenCompute(const Plugin& plugin, const ffi::String& device); - - /*! \brief Type name in tvm*/ - const ffi::String ToTVMType(const ffi::String& type) { - if (type == "string") { - return "StringImm"; - } - if (StringUtils::StartsWith(type, "float")) { - return "FloatImm"; - } - if (type == "bool" || StringUtils::StartsWith(type, "int")) { - return "IntImm"; - } - if (IsListType(type)) { - return "Tuple"; - } - return BasePluginCodeGen::ToCppType(type); - } -}; - -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // TVM_CONTRIB_MSC_PLUGIN_TVM_CODEGEN_H_ diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc deleted file mode 100644 index e44ed1df0390..000000000000 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ /dev/null @@ -1,364 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/runtime/contrib/tensorrt/tensorrt_runtime.cc - * \brief JSON runtime implementation for TensorRT. - */ - -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "../json/json_runtime.h" - -#ifdef TVM_GRAPH_EXECUTOR_TENSORRT -#include "../../../runtime/cuda/cuda_common.h" -#include "../tensorrt/tensorrt_logger.h" -#include "../tensorrt/tensorrt_utils.h" -#endif - -namespace tvm { -namespace runtime { -namespace contrib { - -using namespace tvm::runtime::json; - -#ifdef TVM_GRAPH_EXECUTOR_TENSORRT -using namespace nvinfer1; -#endif - -class MSCTensorRTRuntime : public JSONRuntimeBase { - public: - /*! - * \brief The MSC TensorRT runtime module. Deserialize the provided functions - * on creation and store in the layer cache. - * - * \param symbol_name The name of the function. - * \param graph_json serialized JSON representation of a sub-graph. - * \param const_names The names of each constant in the sub-graph. - */ - explicit MSCTensorRTRuntime(const std::string& symbol_name, const std::string& graph_json, - const ffi::Array& const_names) - : JSONRuntimeBase(symbol_name, graph_json, const_names) {} - - ~MSCTensorRTRuntime() { - VLOG(1) << "Destroying MSC TensorRT runtime"; - DestroyEngine(); - } - - /*! - * \brief The type key of the module. - * - * \return module type key. - */ - const char* kind() const final { return "msc_tensorrt"; } - - /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { - return ffi::Module::kBinarySerializable | ffi::Module::kRunnable; - } - - /*! - * \brief Initialize runtime. - * - * \param consts The constant params from compiled model. - */ - void Init(const ffi::Array& consts) override { - TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) - << "The number of input constants must match the number of required."; - LoadGlobalOptions(); - for (size_t nid = 0; nid < nodes_.size(); nid++) { - for (size_t oid = 0; oid < nodes_[nid].GetNumOutput(); oid++) { - const auto& t_name = nodes_[nid].GetOpName() + ":" + std::to_string(oid); - tensor_ids_[t_name] = std::make_pair(nid, oid); - } - } - LoadEngine(engine_file_); - } - - void LoadGlobalOptions() { - // These settings are global to the entire subgraph. Codegen will add them as attributes to all - // op nodes. Read from first one. - for (size_t i = 0; i < nodes_.size(); ++i) { - if (nodes_[i].HasAttr("msc_global_options_num")) { - engine_file_ = std::string(nodes_[i].GetAttr("msc_global_engine")); - graph_name_ = std::string(nodes_[i].GetAttr("msc_global_graph_name")); - if (nodes_[i].HasAttr("msc_global_tool_tag")) { - tool_tag_ = std::string(nodes_[i].GetAttr("msc_global_tool_tag")); - } else { - tool_tag_ = ""; - } - } - } - } - -#ifdef TVM_GRAPH_EXECUTOR_TENSORRT - void Run() override { - SetInputOutputBinds(); - if (tool_tag_.size() > 0) { - const auto pf = tvm::ffi::Function::GetGlobal("msc_tool.callback_step"); - TVM_FFI_ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; - ffi::Map input_datas; - int device_id = 0; - for (const auto& pair : input_bindings_) { - const auto& tensor_name = engine_->getBindingName(pair.first); - input_datas.Set(tensor_name, device_buffers_[pair.first]); - device_id = data_entry_[pair.first]->device.device_id; - } - ffi::Map> context; - context.Set("datas", input_datas); - (*pf)(context, "before_forward", graph_name_, tool_tag_); - } - auto tvm_stream = TVMFFIEnvGetStream(kDLCUDA, device_id); -#if TRT_VERSION_GE(6, 0, 1) - TVM_FFI_ICHECK(context_->enqueueV2(bindings_.data(), tvm_stream, nullptr)) - << "Running TensorRT failed."; -#else - LOG_FATAL << "Only support tensorrt with version >=6.0.0"; -#endif - // Copy outputs from GPU buffers if needed. - for (size_t i = 0; i < outputs_.size(); ++i) { - auto nid = outputs_[i].id_; - uint32_t eid = EntryID(outputs_[i]); - const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(outputs_[i].index_); - int binding_index = engine_->getBindingIndex(name.c_str()); - TVM_FFI_ICHECK_NE(binding_index, -1); - if (data_entry_[eid]->device.device_type != kDLCUDA || tool_tag_.size() > 0) { - auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); - device_buffer.CopyTo(const_cast(data_entry_[eid])); - } - } - if (tool_tag_.size() > 0) { - const auto pf = tvm::ffi::Function::GetGlobal("msc_tool.callback_step"); - TVM_FFI_ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; - ffi::Map output_datas; - for (int bid = 0; bid < engine_->getNbBindings(); bid++) { - if (input_bindings_.count(bid)) { - continue; - } - const auto& tensor_name = engine_->getBindingName(bid); - output_datas.Set(tensor_name, device_buffers_[bid]); - } - ffi::Map> context; - context.Set("datas", output_datas); - (*pf)(context, "after_forward", graph_name_, tool_tag_); - } - } - - bool LoadEngine(const ffi::String& engine_file) { - IRuntime* runtime = createInferRuntime(logger_); - // build engine - std::ifstream input(engine_file_, std::ifstream::binary); - if (!input.is_open() || !input.good()) { - LOG_ERROR << "Failed to open engine file " << engine_file_; - return false; - } - std::vector stream; - size_t size = 0; - input.seekg(0, input.end); - size = input.tellg(); - input.seekg(0, input.beg); - stream.resize(size); - input.read(stream.data(), size); - input.close(); - -#if TRT_VERSION_GE(8, 0, 0) - engine_ = runtime->deserializeCudaEngine(stream.data(), size); -#else - engine_ = runtime->deserializeCudaEngine(stream.data(), size, nullptr); -#endif - if (!engine_) { - LOG_ERROR << "Failed to load engine"; - return false; - } - // create context - context_ = engine_->createExecutionContext(); - if (!context_) { - LOG_ERROR << "Failed to create context"; - return false; - } - // resize bindings - size_t num_binding = static_cast(engine_->getNbBindings()); - bindings_.resize(num_binding); - binding_sizes_.resize(num_binding); - for (size_t i = 0; i < num_binding; i++) { - bindings_[i] = nullptr; - binding_sizes_[i] = 0; - } - // destroy runtime -#if TRT_VERSION_GE(8, 0, 0) - delete runtime; -#else - runtime->destroy(); -#endif - return true; - } - - void DestroyEngine() { -#if TRT_VERSION_GE(8, 0, 0) - delete context_; - delete engine_; -#else - context_->destroy(); - engine_->destroy(); -#endif - engine_ = nullptr; - context_ = nullptr; - } - - void SetInputOutputBinds() { - // Setup input bindings - std::set binded; - for (size_t i = 0; i < input_nodes_.size(); ++i) { - auto nid = input_nodes_[i]; - if (nodes_[nid].GetOpType() == "input") { - for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { - uint32_t eid = EntryID(nid, j); - const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(j); - int binding_index = engine_->getBindingIndex(name.c_str()); - TVM_FFI_ICHECK_NE(binding_index, -1); -#if TRT_VERSION_GE(6, 0, 1) - std::vector shape(data_entry_[eid]->shape, - data_entry_[eid]->shape + data_entry_[eid]->ndim); - TVM_FFI_ICHECK(context_->setBindingDimensions(binding_index, VectorToTrtDims(shape))); -#endif - if (data_entry_[eid]->device.device_type == kDLCUDA && tool_tag_.size() == 0) { - bindings_[binding_index] = data_entry_[eid]->data; - } else { - auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); - device_buffer.CopyFrom(data_entry_[eid]); - bindings_[binding_index] = device_buffer->data; - } - auto dims = engine_->getBindingDimensions(binding_index); - int num_elements = 1; - for (int i = 0; i < dims.nbDims; ++i) num_elements *= dims.d[i]; - binding_sizes_[binding_index] = num_elements; - input_bindings_[binding_index] = eid; - binded.insert(binding_index); - } - } - } - // Setup output bindings. - for (size_t i = 0; i < outputs_.size(); ++i) { - auto nid = outputs_[i].id_; - uint32_t eid = EntryID(outputs_[i]); - const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(outputs_[i].index_); - int binding_index = engine_->getBindingIndex(name.c_str()); - TVM_FFI_ICHECK_NE(binding_index, -1); - if (data_entry_[eid]->device.device_type == kDLCUDA && tool_tag_.size() == 0) { - bindings_[binding_index] = data_entry_[eid]->data; - } else { - auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); - bindings_[binding_index] = device_buffer->data; - } - output_bindings_[binding_index] = eid; - binded.insert(binding_index); - } - // Setup tool bindings - for (int bid = 0; bid < engine_->getNbBindings(); bid++) { - if (binded.count(bid)) { - continue; - } - if (!device_buffers_.count(bid)) { - const auto& tensor_name = engine_->getBindingName(bid); - TVM_FFI_ICHECK(tensor_ids_.count(tensor_name)) - << "Can not find tensor_name " << tensor_name; - const auto& pair = tensor_ids_[tensor_name]; - auto shape = nodes_[pair.first].GetOpShape()[pair.second]; - auto dtype = nodes_[pair.first].GetOpDataType()[pair.second]; - device_buffers_[bid] = runtime::Tensor::Empty(shape, dtype, {kDLCUDA, 0}); - } - bindings_[bid] = device_buffers_[bid]->data; - binded.insert(bid); - } - } - - Tensor GetOrAllocateDeviceBuffer(int entry_id, int binding_index) { - std::vector shape(data_entry_[entry_id]->shape, - data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); - if (device_buffers_.count(binding_index)) { - // Buffer is already initialized. - if (shape[0] > device_buffers_[binding_index]->shape[0]) { - // Buffer is too small. Need to allocate bigger buffer. - device_buffers_[binding_index] = - runtime::Tensor::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); - } else if (shape[0] < device_buffers_[binding_index]->shape[0]) { - // Buffer is too large. Create view. - return device_buffers_[binding_index].CreateView(shape, data_entry_[entry_id]->dtype); - } - } else { - // Buffer not initialized yet. - device_buffers_[binding_index] = - runtime::Tensor::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); - } - return device_buffers_.at(binding_index); - } - -#else // TVM_GRAPH_EXECUTOR_TENSORRT - void Run() override { - TVM_FFI_THROW(InternalError) << "TensorRT runtime is not enabled. " - << "Please build with USE_TENSORRT_RUNTIME."; - } - - bool LoadEngine(const ffi::String& engine_file) { return false; } - - void DestroyEngine() {} -#endif // TVM_GRAPH_EXECUTOR_TENSORRT - - private: - ffi::String engine_file_; - ffi::String tool_tag_; - ffi::String graph_name_; - std::unordered_map> tensor_ids_; -#ifdef TVM_GRAPH_EXECUTOR_TENSORRT - TensorRTLogger logger_; - ICudaEngine* engine_{nullptr}; - IExecutionContext* context_{nullptr}; - std::unordered_map input_bindings_; - std::unordered_map output_bindings_; - std::vector bindings_; - std::vector binding_sizes_; - std::unordered_map device_buffers_; -#endif -}; - -ffi::Module MSCTensorRTRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, - const ffi::Array& const_names) { - auto n = ffi::make_object(symbol_name, graph_json, const_names); - return ffi::Module(n); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("runtime.msc_tensorrt_runtime_create", MSCTensorRTRuntimeCreate) - .def("ffi.Module.load_from_bytes.msc_tensorrt", - JSONRuntimeBase::LoadFromBytes); -} - -} // namespace contrib -} // namespace runtime -} // namespace tvm diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index c7f740b9b4ed..b7c05844d0a6 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -299,7 +299,6 @@ TVM_DLL ffi::Map GetLibInfo() { {"USE_CLML", TVM_INFO_USE_CLML}, {"TVM_CLML_VERSION", TVM_INFO_USE_TVM_CLML_VERSION}, {"USE_CLML_GRAPH_EXECUTOR", TVM_INFO_USE_CLML_GRAPH_EXECUTOR}, - {"USE_MSC", TVM_INFO_USE_MSC}, {"USE_CCACHE", TVM_INFO_USE_CCACHE}, {"USE_NVSHMEM", TVM_INFO_USE_NVSHMEM}, {"USE_NNAPI_CODEGEN", TVM_INFO_USE_NNAPI_CODEGEN}, diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py deleted file mode 100644 index 8c933b18dced..000000000000 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ /dev/null @@ -1,2636 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name - -"""Test graph builder && graph.""" - -import pytest -import torch -from torch.nn import Module - -import tvm.testing -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.framework.torch.frontend import translate - - -def verify_model(torch_model, input_info, expected): - graph, _ = translate.from_torch(torch_model, input_info) - inspect = graph.inspect() - assert msc_utils.dict_equal(inspect, expected), ( - f"Inspect {inspect} mismatch with expected {expected}" - ) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_conv1d(dynamic: bool): - """test graph builder for conv1d""" - - class Conv1D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - class Conv1D2(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) - - def forward(self, data): - return self.conv(data) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [{"name": "inp_0", "shape": [bz, 3, 10], "dtype": "float32", "layout": "NCW"}], - "outputs": [{"name": "conv1d", "shape": [bz, 6, 4], "dtype": "float32", "layout": "NCW"}], - "nodes": {"total": 2, "input": 1, "msc.conv1d_bias": 1}, - } - expected2 = { - "inputs": [{"name": "inp_0", "shape": [bz, 3, 10], "dtype": "float32", "layout": "NCW"}], - "outputs": [{"name": "conv1d", "shape": [bz, 6, 4], "dtype": "float32", "layout": "NCW"}], - "nodes": {"total": 2, "input": 1, "nn.conv1d": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10], "float32")] - verify_model(Conv1D1(), input_info, expected1) - verify_model(Conv1D2(), input_info, expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_conv2d(dynamic: bool): - """test graph builder for conv2d""" - - class Conv2D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - class Conv2D2(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) - - def forward(self, data): - return self.conv(data) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - { - "name": "conv2d", - "shape": [bz, 6, 4, 4], - "dtype": "float32", - "layout": "NCHW", - } - ], - "nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1}, - } - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "conv2d", "shape": [bz, 6, 4, 4], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.conv2d": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Conv2D1(), input_info, expected1) - verify_model(Conv2D2(), input_info, expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_linear(dynamic: bool): - """test graph builder for linear""" - - class Dense1(Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 7, bias=True) - - def forward(self, data): - return self.linear(data) - - class Dense2(Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 7, bias=False) - - def forward(self, data): - return self.linear(data) - - class MatMul1(Module): - def forward(self, x, y): - return torch.matmul(x, y) - - bz = "bz" if dynamic else 1 - mdim = "mdim" if dynamic else 10 - ndim = "ndim" if dynamic else 20 - kdim = "kdim" if dynamic else 30 - - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - { - "name": "matmul", - "shape": [bz, 3, 10, 7], - "dtype": "float32", - "layout": "NCHW", - } - ], - "nodes": {"total": 2, "input": 1, "msc.linear_bias": 1}, - } - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "matmul", "shape": [bz, 3, 10, 7], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "msc.linear": 1}, - } - expected3 = { - "inputs": [ - {"name": "inp_0", "shape": [mdim, kdim], "dtype": "float32", "layout": "NC"}, - {"name": "inp_1", "shape": [kdim, ndim], "dtype": "float32", "layout": "IO"}, - ], - "outputs": [{"name": "matmul", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}], - "nodes": {"total": 3, "input": 2, "matmul": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - expected3["prims"] = {"total": 3, "shape": 3} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Dense1(), input_info, expected1) - verify_model(Dense2(), input_info, expected2) - verify_model(MatMul1(), [([mdim, kdim], "float32"), ([kdim, ndim], "float32")], expected3) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_bmm(dynamic: bool): - """test graph builder for bmm""" - - class BMM(Module): - def forward(self, x, y): - return torch.bmm(x, y) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_1", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"}, - ], - "outputs": [ - {"name": "matmul", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"} - ], - "nodes": {"total": 3, "input": 2, "matmul": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [((bz, 128, 256), "float32"), ((bz, 256, 512), "float32")] - verify_model(BMM(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_baddbmm(dynamic: bool): - """test graph builder for baddbmm""" - - class BAddBMM1(Module): - def forward(self, c, x, y): - return torch.baddbmm(c, x, y) - - class BAddBMM2(Module): - def forward(self, c, x, y): - return torch.baddbmm(c, x, y, alpha=2, beta=0) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_1", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_2", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"}, - ], - "outputs": [{"name": "add", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"}], - "nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1}, - } - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 128, 512], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_2", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"}, - ], - "outputs": [ - {"name": "multiply", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"} - ], - "nodes": {"total": 6, "input": 3, "matmul": 1, "constant": 1, "multiply": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - input_info = [ - ((bz, 128, 512), "float32"), - ((bz, 128, 256), "float32"), - ((bz, 256, 512), "float32"), - ] - verify_model(BAddBMM1(), input_info, expected1) - verify_model(BAddBMM2(), input_info, expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_relu(dynamic: bool): - """test graph builder for relu""" - - class ReLU(Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, data): - return self.relu(data) - - class ReLU1(Module): - def forward(self, data): - return torch.nn.functional.relu(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "relu", "shape": [bz, 10], "dtype": "float32", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "nn.relu": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 10], "float32")] - verify_model(ReLU(), input_info, expected) - verify_model(ReLU1(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_relu6(dynamic: bool): - """test graph builder for relu6""" - - class ReLU6(Module): - def __init__(self): - super().__init__() - self.relu6 = torch.nn.ReLU6() - - def forward(self, data): - return self.relu6(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "clip", "shape": [bz, 10], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "clip": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 10], "float32")] - verify_model(ReLU6(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_maxpool2d(dynamic: bool): - """test graph builder for maxpool2d""" - - class MaxPool2d(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) - - def forward(self, data): - return self.pool(data) - - class MaxPool2d2(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]) - - def forward(self, data): - return self.pool(data) - - class MaxPool2d3(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) - - def forward(self, data): - return self.pool(data) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "max_pool2d", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, - } - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "max_pool2d", "shape": [bz, 3, 4, 4], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, - } - expected3 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "max_pool2d", "shape": [bz, 3, 6, 6], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - expected3["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(MaxPool2d(), input_info, expected1) - verify_model(MaxPool2d2(), input_info, expected2) - verify_model(MaxPool2d3(), input_info, expected3) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_avgpool2d(dynamic: bool): - """test graph builder for avgpool2d""" - - class AvgPool2d(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1]) - - def forward(self, data): - return self.pool(data) - - class AvgPool2d2(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True) - - def forward(self, data): - return self.pool(data) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "avg_pool2d", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1}, - } - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "avg_pool2d", "shape": [bz, 3, 6, 6], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(AvgPool2d(), input_info, expected1) - verify_model(AvgPool2d2(), input_info, expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_adaptive_avgpool2d(dynamic: bool): - """test graph builder for adaptive_avgpool2d""" - - class AdaptiveAvgPool2d0(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) - - def forward(self, data): - return self.pool(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - { - "name": "adaptive_avg_pool2d", - "shape": [bz, 3, 10, 10], - "dtype": "float32", - "layout": "NCHW", - } - ], - "nodes": {"total": 2, "input": 1, "nn.adaptive_avg_pool2d": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(AdaptiveAvgPool2d0(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_flatten(dynamic: bool): - """test graph builder for flatten""" - - class Flatten(Module): - def __init__(self): - super().__init__() - self.f = torch.nn.Flatten(2, -1) - - def forward(self, data): - return self.f(data) - - bz = "bz" if dynamic else 1 - dim = "dim" if dynamic else 10 - out_dim = "MUL_3" if dynamic else 100 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, dim], "dtype": "float32", "layout": ""}], - "outputs": [ - {"name": "reshape", "shape": [bz, 3, out_dim], "dtype": "float32", "layout": ""} - ], - "nodes": {"total": 2, "input": 1, "reshape": 1}, - } - if dynamic: - expected["prims"] = {"total": 4, "shape": 2, "Int": 1, "Mul": 1} - - input_info = [([bz, 3, 10, dim], "float32")] - verify_model(Flatten(), input_info, expected) - verify_model(torch.nn.Flatten(2, -1), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_batchnorm2d(dynamic: bool): - """test graph builder for batchnorm2d""" - - class BatchNorm2d(Module): - def __init__(self): - super().__init__() - self.batchnorm = torch.nn.BatchNorm2d(3) - - def forward(self, data): - return self.batchnorm(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - { - "name": "batch_norm.0", - "shape": [bz, 3, 10, 10], - "dtype": "float32", - "layout": "NCHW", - } - ], - "nodes": {"total": 3, "input": 1, "nn.batch_norm": 1, "get_item": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(BatchNorm2d(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_embedding(dynamic: bool): - """test graph builder for embedding""" - - class Embedding(Module): - def __init__(self): - super().__init__() - self.embedding = torch.nn.Embedding(10, 3) - - def forward(self, data): - return self.embedding(data) - - vocab = "vocab" if dynamic else 4 - expected1 = { - "inputs": [{"name": "inp_0", "shape": [vocab], "dtype": "int64", "layout": "A"}], - "outputs": [{"name": "take", "shape": [vocab, 3], "dtype": "float32", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "msc.embedding": 1}, - } - expected2 = { - "inputs": [{"name": "inp_0", "shape": [vocab, 5], "dtype": "int64", "layout": "AB"}], - "outputs": [ - { - "name": "take", - "shape": [vocab, 5, 3], - "dtype": "float32", - "layout": "" if dynamic else "CBA", - } - ], - "nodes": {"total": 2, "input": 1, "msc.embedding": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} - - verify_model(Embedding(), [([vocab], "int64")], expected1) - verify_model(Embedding(), [([vocab, 5], "int64")], expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_dropout(dynamic: bool): - """test graph builder for dropout""" - - class Dropout1(Module): - def __init__(self): - super().__init__() - self.dropout = torch.nn.Dropout(0.5) - - def forward(self, data): - return self.dropout(data) - - class Dropout2(Module): - def forward(self, data): - return torch.dropout(data, 0.5, train=True) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], - "nodes": {"total": 1, "input": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Dropout1(), input_info, expected) - verify_model(Dropout2(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_layernorm(dynamic: bool): - """test graph builder for layernorm""" - - class LayerNorm(Module): - def __init__(self): - super().__init__() - self.layernorm = torch.nn.LayerNorm((10, 10)) - - def forward(self, data): - return self.layernorm(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "layer_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.layer_norm": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(LayerNorm(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_functional_layernorm(dynamic: bool): - """test graph builder for functional_layernorm""" - - class LayerNorm(Module): - def __init__(self, shape): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(shape)) - self.bias = torch.nn.Parameter(torch.zeros(shape)) - - def forward(self, data): - return torch.nn.functional.layer_norm( - data, self.weight.shape, self.weight, self.bias, 1e-5 - ) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "layer_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.layer_norm": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(LayerNorm((10, 10)), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_cross_entropy(dynamic: bool): - """test graph builder for cross_entropy""" - - class CrossEntropy1(Module): - def __init__(self): - super().__init__() - self.loss = torch.nn.CrossEntropyLoss() - - def forward(self, logits, targets): - return self.loss(logits, targets) - - class CrossEntropy2(Module): - def __init__(self): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones((2,))) - self.loss = torch.nn.CrossEntropyLoss(weight=self.weight) - - def forward(self, logits, targets): - return self.loss(logits, targets) - - class CrossEntropy3(Module): - def __init__(self): - super().__init__() - self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum") - - def forward(self, logits, targets): - return self.loss(logits, targets) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, - ], - "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], - "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, - } - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, - ], - "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], - "nodes": {"total": 5, "input": 2, "nn.log_softmax": 1, "constant": 1, "nn.nll_loss": 1}, - } - expected3 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, - ], - "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], - "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - expected3["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 2], "float32"), ([bz], "int32")] - verify_model(CrossEntropy1(), input_info, expected1) - verify_model(CrossEntropy2(), input_info, expected2) - verify_model(CrossEntropy3(), input_info, expected3) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_functional_cross_entropy(dynamic: bool): - """test graph builder for functional_cross_entropy""" - - class CrossEntropy(Module): - def forward(self, logits, targets): - return torch.nn.functional.cross_entropy(logits, targets) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, - ], - "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], - "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 10], "float32"), ([bz], "int32")] - verify_model(CrossEntropy(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_silu(dynamic: bool): - """test graph builder for silu""" - - class SiLU(Module): - def __init__(self): - super().__init__() - self.silu = torch.nn.SiLU() - - def forward(self, data): - return self.silu(data) - - class SiLU2(Module): - def forward(self, data): - return torch.nn.functional.silu(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "silu", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "nn.silu": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(SiLU(), input_info, expected) - verify_model(SiLU2(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_groupnorm(dynamic: bool): - """test graph builder for groupnorm""" - - class GroupNorm(Module): - def __init__(self): - super().__init__() - self.groupnorm = torch.nn.GroupNorm(3, 3) - - def forward(self, data): - return self.groupnorm(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "group_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.group_norm": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(GroupNorm(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_softmax(dynamic: bool): - """test graph builder for softmax""" - - class Softmax(Module): - def __init__(self): - super().__init__() - self.softmax = torch.nn.Softmax(dim=1) - - def forward(self, data): - return self.softmax(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "softmax", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "nn.softmax": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Softmax(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_binary(dynamic: bool): - """test graph builder for binary""" - - bz = "bz" if dynamic else 1 - input_info1 = [([bz, 3, 10, 10], "float32"), ([bz, 3, 10, 10], "float32")] - input_info2 = [([bz, 3, 10, 10], "float32")] - - # Add - class Add1(Module): - def forward(self, lhs, rhs): - return lhs + rhs - - class Add2(Module): - def forward(self, lhs): - return lhs + 1.0 - - expected_add1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - ], - "outputs": [ - {"name": "add", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 3, "input": 2, "add": 1}, - } - expected_add2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "add", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 3, "input": 1, "constant": 1, "add": 1}, - } - if dynamic: - expected_add1["prims"] = {"total": 1, "shape": 1} - expected_add2["prims"] = {"total": 1, "shape": 1} - - verify_model(Add1(), input_info1, expected_add1) - verify_model(Add2(), input_info2, expected_add2) - - # Sub - class Sub1(Module): - def forward(self, lhs, rhs): - return lhs - rhs - - class Sub2(Module): - def forward(self, lhs): - return lhs - 1.0 - - expected_sub1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - ], - "outputs": [ - {"name": "subtract", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 3, "input": 2, "subtract": 1}, - } - expected_sub2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "subtract", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 3, "input": 1, "constant": 1, "subtract": 1}, - } - if dynamic: - expected_sub1["prims"] = {"total": 1, "shape": 1} - expected_sub2["prims"] = {"total": 1, "shape": 1} - - verify_model(Sub1(), input_info1, expected_sub1) - verify_model(Sub2(), input_info2, expected_sub2) - - # Mul - class Mul1(Module): - def forward(self, lhs, rhs): - return lhs * rhs - - class Mul2(Module): - def forward(self, lhs): - return lhs * 1.0 - - expected_mul1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - ], - "outputs": [ - {"name": "multiply", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 3, "input": 2, "multiply": 1}, - } - expected_mul2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "multiply", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 3, "input": 1, "constant": 1, "multiply": 1}, - } - if dynamic: - expected_mul1["prims"] = {"total": 1, "shape": 1} - expected_mul2["prims"] = {"total": 1, "shape": 1} - - verify_model(Mul1(), input_info1, expected_mul1) - verify_model(Mul2(), input_info2, expected_mul2) - - # True div - class TrueDiv1(Module): - def forward(self, lhs, rhs): - return lhs / rhs - - class TrueDiv2(Module): - def forward(self, lhs): - return lhs / 1.0 - - expected_div1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - ], - "outputs": [ - {"name": "divide", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 3, "input": 2, "divide": 1}, - } - expected_div2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "divide", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 3, "input": 1, "constant": 1, "divide": 1}, - } - if dynamic: - expected_div1["prims"] = {"total": 1, "shape": 1} - expected_div2["prims"] = {"total": 1, "shape": 1} - - verify_model(TrueDiv1(), input_info1, expected_div1) - verify_model(TrueDiv2(), input_info2, expected_div2) - - # Floor div - class FloorDiv1(Module): - def forward(self, lhs, rhs): - return lhs // rhs - - class FloorDiv2(Module): - def forward(self, lhs): - return lhs // 1.0 - - expected_floordiv1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - ], - "outputs": [ - { - "name": "floor_divide", - "shape": [bz, 3, 10, 10], - "dtype": "float32", - "layout": "ABCD", - } - ], - "nodes": {"total": 3, "input": 2, "floor_divide": 1}, - } - expected_floordiv2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - { - "name": "floor_divide", - "shape": [bz, 3, 10, 10], - "dtype": "float32", - "layout": "ABCD", - } - ], - "nodes": {"total": 3, "input": 1, "constant": 1, "floor_divide": 1}, - } - if dynamic: - expected_floordiv1["prims"] = {"total": 1, "shape": 1} - expected_floordiv2["prims"] = {"total": 1, "shape": 1} - - verify_model(FloorDiv1(), input_info1, expected_floordiv1) - verify_model(FloorDiv2(), input_info2, expected_floordiv2) - - # Power - class Power1(Module): - def forward(self, lhs, rhs): - return lhs**rhs - - class Power2(Module): - def forward(self, lhs): - return lhs**1.0 - - expected_power1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - ], - "outputs": [ - {"name": "power", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 3, "input": 2, "power": 1}, - } - expected_power2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "power", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 3, "input": 1, "constant": 1, "power": 1}, - } - if dynamic: - expected_power1["prims"] = {"total": 1, "shape": 1} - expected_power2["prims"] = {"total": 1, "shape": 1} - - verify_model(Power1(), input_info1, expected_power1) - verify_model(Power2(), input_info2, expected_power2) - - # LT - class LT1(Module): - def forward(self, lhs, rhs): - return lhs < rhs - - class LT2(Module): - def forward(self, lhs): - return lhs < 1.0 - - expected_lt1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - ], - "outputs": [{"name": "less", "shape": [bz, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}], - "nodes": {"total": 3, "input": 2, "less": 1}, - } - expected_lt2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [{"name": "less", "shape": [bz, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}], - "nodes": {"total": 3, "input": 1, "constant": 1, "less": 1}, - } - if dynamic: - expected_lt1["prims"] = {"total": 1, "shape": 1} - expected_lt2["prims"] = {"total": 1, "shape": 1} - - verify_model(LT1(), input_info1, expected_lt1) - verify_model(LT2(), input_info2, expected_lt2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_size(dynamic: bool): - """test graph builder for size""" - - class Size(Module): - def forward(self, data): - return data.size() - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "shape", "shape": [4], "dtype": "int32", "layout": "O"}], - "nodes": {"total": 2, "input": 1, "shape": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Size(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_squeeze(dynamic: bool): - """test graph builder for squeeze""" - - class Squeeze1(Module): - def forward(self, data): - return data.squeeze(1) - - class Squeeze2(Module): - def forward(self, data): - return data.squeeze() - - bz = "bz" if dynamic else 10 - expected1 = { - "inputs": [{"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ADBC"}], - "outputs": [{"name": "squeeze", "shape": [bz, 4, 1], "dtype": "float32", "layout": "ABC"}], - "nodes": {"total": 2, "input": 1, "squeeze": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ACBD"} - ], - "outputs": [{"name": "squeeze", "shape": [], "dtype": "float32", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "squeeze": 1}, - "prims": {"total": 1, "shape": 1}, - } - else: - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ACBD"} - ], - "outputs": [{"name": "squeeze", "shape": [bz, 4], "dtype": "float32", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "squeeze": 1}, - } - input_info = [([bz, 1, 4, 1], "float32")] - verify_model(Squeeze1(), input_info, expected1) - verify_model(Squeeze2(), input_info, expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_unsqueeze(dynamic: bool): - """test graph builder for unsqueeze""" - - class Unsqueeze1(Module): - def forward(self, data): - return data.unsqueeze(1) - - class Unsqueeze2(Module): - def forward(self, data): - return data.unsqueeze(-1) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ACDE"} - ], - "outputs": [ - { - "name": "expand_dims", - "shape": [bz, 1, 3, 10, 10], - "dtype": "float32", - "layout": "ABCDE", - } - ], - "nodes": {"total": 2, "input": 1, "expand_dims": 1}, - } - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCE"} - ], - "outputs": [ - { - "name": "expand_dims", - "shape": [bz, 3, 10, 10, 1], - "dtype": "float32", - "layout": "ABCDE", - } - ], - "nodes": {"total": 2, "input": 1, "expand_dims": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Unsqueeze1(), input_info, expected1) - verify_model(Unsqueeze2(), input_info, expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_getattr(dynamic: bool): - """test graph builder for getattr""" - - class GetAttr1(Module): - def forward(self, data): - return data.shape - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "shape", "shape": [4], "dtype": "int32", "layout": "O"}], - "nodes": {"total": 2, "input": 1, "shape": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(GetAttr1(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_getitem(dynamic: bool): - """test graph builder for getitem""" - - class Slice1(Module): - def forward(self, x): - return x[0, 1::2, :, :3] - - class Slice2(Module): - def forward(self, x): - return x[:, None, None, :, None] - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - { - "name": "reshape", - "shape": ["MIN_2" if dynamic else 1, 1, 10, 3], - "dtype": "float32", - "layout": "ABCD", - } - ], - "nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1}, - } - expected2 = { - "inputs": [{"name": "inp_0", "shape": [bz, 16], "dtype": "float32", "layout": "AB"}], - "outputs": [ - {"name": "reshape", "shape": [bz, 1, 1, 16, 1], "dtype": "float32", "layout": "CDAEB"} - ], - "nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1}, - } - if dynamic: - expected1["prims"] = {"total": 3, "shape": 1, "Int": 1, "Min": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - verify_model(Slice1(), [([bz, 3, 10, 10], "float32")], expected1) - verify_model(Slice2(), [([bz, 16], "float32")], expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_unary(dynamic: bool): - """test graph builder for unary""" - - bz = "bz" if dynamic else 1 - input_info = [([bz, 3, 10, 10], "float32")] - - # sin - class Sin(Module): - def forward(self, data): - return torch.sin(data) - - expected_sin = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "sin", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "sin": 1}, - } - if dynamic: - expected_sin["prims"] = {"total": 1, "shape": 1} - - verify_model(Sin(), input_info, expected_sin) - - # cos - class Cos(Module): - def forward(self, data): - return torch.cos(data) - - expected_cos = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "cos", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "cos": 1}, - } - if dynamic: - expected_cos["prims"] = {"total": 1, "shape": 1} - - verify_model(Cos(), input_info, expected_cos) - - # exp - class Exp(Module): - def forward(self, data): - return torch.exp(data) - - expected_exp = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "exp", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "exp": 1}, - } - if dynamic: - expected_exp["prims"] = {"total": 1, "shape": 1} - - verify_model(Exp(), input_info, expected_exp) - - # sqrt - class Sqrt(Module): - def forward(self, data): - return torch.sqrt(data) - - expected_sqrt = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "sqrt", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "sqrt": 1}, - } - if dynamic: - expected_sqrt["prims"] = {"total": 1, "shape": 1} - - verify_model(Sqrt(), input_info, expected_sqrt) - - # sigmoid - class Sigmoid(Module): - def forward(self, data): - return torch.sigmoid(data) - - expected_sigmoid = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "sigmoid", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "sigmoid": 1}, - } - if dynamic: - expected_sigmoid["prims"] = {"total": 1, "shape": 1} - - verify_model(Sigmoid(), input_info, expected_sigmoid) - - # round - class Round(Module): - def forward(self, data): - return torch.round(data) - - expected_round = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "round", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "round": 1}, - } - if dynamic: - expected_round["prims"] = {"total": 1, "shape": 1} - - verify_model(Round(), input_info, expected_round) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_gelu(dynamic: bool): - """test graph builder for gelu""" - - class Gelu(Module): - def forward(self, data): - return torch.nn.functional.gelu(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "gelu", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "nn.gelu": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Gelu(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_tanh(dynamic: bool): - """test graph builder for tanh""" - - class Tanh(Module): - def forward(self, data): - return torch.tanh(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "tanh", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "tanh": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Tanh(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_clamp(dynamic: bool): - """test graph builder for clamp""" - - class Clamp(Module): - def forward(self, data): - return torch.clamp(data, min=0.1, max=0.5) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "clip", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "clip": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Clamp(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_interpolate(dynamic: bool): - """test graph builder for interpolate""" - - class Interpolate(Module): - def forward(self, data): - return torch.nn.functional.interpolate(data, (5, 5)) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "resize2d", "shape": [bz, 3, 5, 5], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "image.resize2d": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Interpolate(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_addmm(dynamic: bool): - """test graph builder for addmm""" - - class Addmm(Module): - def forward(self, x_1, x_2, x_3): - return torch.addmm(x_1, x_2, x_3) - - mdim = "mdim" if dynamic else 10 - ndim = "ndim" if dynamic else 20 - kdim = "kdim" if dynamic else 30 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}, - {"name": "inp_1", "shape": [mdim, kdim], "dtype": "float32", "layout": "NC"}, - {"name": "inp_2", "shape": [kdim, ndim], "dtype": "float32", "layout": "IO"}, - ], - "outputs": [{"name": "add", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}], - "nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1}, - } - if dynamic: - expected["prims"] = {"total": 3, "shape": 3} - - input_info = [([mdim, ndim], "float32"), ([mdim, kdim], "float32"), ([kdim, ndim], "float32")] - verify_model(Addmm(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_split(dynamic: bool): - """test graph builder for split""" - - class Split1(Module): - def forward(self, data): - return torch.split(data, 1, dim=1) - - class Split2(Module): - def forward(self, data): - return torch.split(data, [1, 2], dim=1) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_1", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_2", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - ], - "nodes": {"total": 2, "input": 1, "split": 1}, - } - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_1", "shape": [bz, 2, 10, 10], "dtype": "float32", "layout": "ABCD"}, - ], - "nodes": {"total": 2, "input": 1, "split": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Split1(), input_info, expected1) - verify_model(Split2(), input_info, expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_unbind(dynamic: bool): - """test graph builder for unbind""" - - class Unbind(Module): - def forward(self, data): - return torch.unbind(data, dim=1) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "tuple_0", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"}, - {"name": "tuple_1", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"}, - {"name": "tuple_2", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"}, - ], - "nodes": {"total": 9, "input": 1, "split": 1, "get_item": 3, "squeeze": 3, "tuple": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Unbind(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_cumsum(dynamic: bool): - """test graph builder for cumsum""" - - class Cumsum(Module): - def forward(self, data): - return torch.cumsum(data, dim=1, dtype=torch.int32) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "cumsum", "shape": [bz, 2, 3, 4], "dtype": "int32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "cumsum": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 2, 3, 4], "float32")] - verify_model(Cumsum(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_chunk(dynamic: bool): - """test graph builder for chunk""" - - class Chunk(Module): - def forward(self, data): - return torch.chunk(data, 3, dim=1) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "outputs": [ - {"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_1", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_2", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - ], - "nodes": {"total": 2, "input": 1, "split": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 3, 10, 10], "float32")] - verify_model(Chunk(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_inplace_fill(dynamic: bool): - """test graph builder for inplace_fill""" - - class InplaceFill(Module): - def forward(self, data): - data.fill_(1.5) - return data - - bz = "bz" if dynamic else 1 - if dynamic: - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "full", "shape": [bz, 10], "dtype": "float32", "layout": ""}], - "nodes": {"total": 3, "input": 1, "constant": 1, "full": 1}, - "prims": {"total": 1, "shape": 1}, - } - else: - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "const", "shape": [bz, 10], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "constant": 1}, - } - verify_model(InplaceFill(), [([bz, 10], "float32")], expected) - - -def test_arange(): - """test graph builder for arange""" - - class Arange(Module): - def forward(self): - return torch.arange(0, 20, dtype=torch.int32) - - expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "const", "shape": [20], "dtype": "int32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "constant": 1}, - } - - verify_model(Arange(), [([10, 10], "float32")], expected) - - -def test_empty(): - """test graph builder for empty""" - - class Empty(Module): - def forward(self): - return torch.empty((10, 10), dtype=torch.float32) - - expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "const", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "constant": 1}, - } - - verify_model(Empty(), [([10, 10], "float32")], expected) - - -def test_tensor(): - """test graph builder for tensor""" - - class Empty1(Module): - def forward(self): - return torch.tensor(3, dtype=torch.float32) - - expected1 = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "const", "shape": [], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "constant": 1}, - } - - class Empty2(Module): - def forward(self): - return torch.tensor(3) - - expected2 = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "const", "shape": [], "dtype": "int64", "layout": ""}], - "nodes": {"total": 2, "input": 1, "constant": 1}, - } - - verify_model(Empty1(), [([10, 10], "float32")], expected1) - verify_model(Empty2(), [([10, 10], "float32")], expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_tril(dynamic: bool): - """test graph builder for tril""" - - class Tril(Module): - def forward(self, data): - return torch.tril(data, 1) - - class InplaceTril(Module): - def forward(self, data): - data.tril_(1) - return data - - row = "row" if dynamic else 10 - col = "col" if dynamic else 10 - expected = { - "inputs": [{"name": "inp_0", "shape": [row, col], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "tril", "shape": [row, col], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "tril": 1}, - } - if dynamic: - expected["prims"] = {"total": 2, "shape": 2} - - input_info = [([row, col], "float32")] - verify_model(Tril(), input_info, expected) - verify_model(InplaceTril(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_triu(dynamic: bool): - """test graph builder for triu""" - - class Triu(Module): - def forward(self, data): - return torch.triu(data, 1) - - class InplaceTriu(Module): - def forward(self, data): - data.triu_(1) - return data - - row = "row" if dynamic else 10 - col = "col" if dynamic else 10 - expected = { - "inputs": [{"name": "inp_0", "shape": [row, col], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "triu", "shape": [row, col], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "triu": 1}, - } - if dynamic: - expected["prims"] = {"total": 2, "shape": 2} - - input_info = [([row, col], "float32")] - verify_model(Triu(), input_info, expected) - verify_model(InplaceTriu(), input_info, expected) - - -def test_new_ones(): - """test graph builder for new_ones""" - - class NewOnes(Module): - def forward(self, x): - return x.new_ones(1, 2, 3) - - expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "const", "shape": [1, 2, 3], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "constant": 1}, - } - - input_info = [([1, 2, 3], "float32")] - verify_model(NewOnes(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_expand(dynamic: bool): - """test graph builder for expand""" - - class Expand1(Module): - def forward(self, x): - return x.expand(4, 2, 3, 4) - - class Expand2(Module): - def forward(self, x): - return x.expand(4, -1, -1, 4) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], - "outputs": [ - {"name": "broadcast_to", "shape": [4, 2, 3, 4], "dtype": "float32", "layout": ""} - ], - "nodes": {"total": 2, "input": 1, "broadcast_to": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 2, 3, 4], "float32")] - verify_model(Expand1(), input_info, expected) - verify_model(Expand2(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_reduce(dynamic: bool): - """test graph builder for reduce""" - - # sum - class Sum(Module): - def forward(self, x): - return torch.sum(x, (2, 1)) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ACDB"}], - "outputs": [{"name": "sum", "shape": [bz, 4], "dtype": "float32", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "sum": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz, 2, 3, 4], "float32")] - verify_model(Sum(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_datatype(dynamic: bool): - """test graph builder for datatype""" - - bz = "bz" if dynamic else 1 - input_info = [([bz, 2, 3, 4], "float32")] - - # float - class ToFloat(Module): - def forward(self, x): - return x.float() - - expected1 = { - "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], - "outputs": [ - {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "astype": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - - verify_model(ToFloat(), input_info, expected1) - - # half - class ToHalf(Module): - def forward(self, x): - return x.half() - - expected2 = { - "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], - "outputs": [ - {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float16", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "astype": 1}, - } - if dynamic: - expected2["prims"] = {"total": 1, "shape": 1} - - verify_model(ToHalf(), input_info, expected2) - - # type - class Type(Module): - def forward(self, x): - return x.type(torch.float32) - - expected3 = { - "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], - "outputs": [ - {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "astype": 1}, - } - if dynamic: - expected3["prims"] = {"total": 1, "shape": 1} - - # type - class TypeFromAttr(Module): - def forward(self, x): - return x.type(x.getattr("dtype")) - - expected4 = { - "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], - "outputs": [ - {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "astype": 1}, - } - if dynamic: - expected4["prims"] = {"total": 1, "shape": 1} - - # astype - class AsType(Module): - def forward(self, x): - return x.astype(torch.float32) - - expected5 = { - "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], - "outputs": [ - {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 2, "input": 1, "astype": 1}, - } - if dynamic: - expected5["prims"] = {"total": 1, "shape": 1} - - verify_model(Type(), input_info, expected3) - verify_model(TypeFromAttr(), input_info, expected4) - verify_model(AsType(), input_info, expected5) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_permute(dynamic: bool): - """test graph builder for permute""" - - class Permute(Module): - def forward(self, x): - return x.permute(0, 3, 2, 1) - - bz = "bz" if dynamic else 1 - channel = "channel" if dynamic else 2 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, channel, 3, 4], "dtype": "float32", "layout": "ADCB"} - ], - "outputs": [ - { - "name": "permute_dims", - "shape": [bz, 4, 3, channel], - "dtype": "float32", - "layout": "ABCD", - } - ], - "nodes": {"total": 2, "input": 1, "permute_dims": 1}, - } - if dynamic: - expected["prims"] = {"total": 2, "shape": 2} - - input_info = [([bz, channel, 3, 4], "float32")] - verify_model(Permute(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_reshape(dynamic: bool): - """test graph builder for reshape""" - - class Reshape(Module): - def forward(self, x): - return x.reshape(-1, 12) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], - "outputs": [ - { - "name": "reshape", - "shape": ["MUL_2" if dynamic else 2, 12], - "dtype": "float32", - "layout": "", - } - ], - "nodes": {"total": 2, "input": 1, "reshape": 1}, - } - if dynamic: - expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} - - input_info = [([bz, 2, 3, 4], "float32")] - verify_model(Reshape(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_transpose(dynamic: bool): - """test graph builder for transpose""" - - class Transpose(Module): - def forward(self, x): - return x.transpose(1, 3) - - bz = "bz" if dynamic else 1 - hidden = "hidden" if dynamic else 4 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 2, 3, hidden], "dtype": "float32", "layout": "ADCB"} - ], - "outputs": [ - { - "name": "permute_dims", - "shape": [bz, hidden, 3, 2], - "dtype": "float32", - "layout": "ABCD", - } - ], - "nodes": {"total": 2, "input": 1, "permute_dims": 1}, - } - if dynamic: - expected["prims"] = {"total": 2, "shape": 2} - - input_info = [([bz, 2, 3, hidden], "float32")] - verify_model(Transpose(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_view(dynamic: bool): - """test graph builder for view""" - - class View(Module): - def forward(self, x): - return x.view(-1, 12) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], - "outputs": [ - { - "name": "reshape", - "shape": ["MUL_2" if dynamic else 2, 12], - "dtype": "float32", - "layout": "", - } - ], - "nodes": {"total": 2, "input": 1, "reshape": 1}, - } - if dynamic: - expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} - - input_info = [([bz, 2, 3, 4], "float32")] - verify_model(View(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_keep_params(dynamic: bool): - """test graph builder for keep_params""" - - class Conv2D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - { - "name": "conv2d", - "shape": [bz, 6, 4, 4], - "dtype": "float32", - "layout": "NCHW", - } - ], - "nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - verify_model(Conv2D1(), [([bz, 3, 10, 10], "float32")], expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_unwrap_unit_return_tuple(dynamic: bool): - """test graph builder for unwrap_unit_return_tuple""" - - class Identity(Module): - def forward(self, x): - return (x,) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "tuple", "shape": [bz, 256], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "tuple": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - verify_model(Identity(), [([bz, 256], "float32")], expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_no_bind_return_tuple(dynamic: bool): - """test graph builder for no_bind_return_tuple""" - - class Identity(Module): - def forward(self, x, y): - return (x, y) - - bz_x = "bz" if dynamic else 1 - bz_y = "bz" if dynamic else 2 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz_x, 256], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [bz_y, 256], "dtype": "float32", "layout": ""}, - ], - "outputs": [ - {"name": "tuple_0", "shape": [bz_x, 256], "dtype": "float32", "layout": ""}, - {"name": "tuple_1", "shape": [bz_y, 256], "dtype": "float32", "layout": ""}, - ], - "nodes": {"total": 3, "input": 2, "tuple": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - input_info = [([bz_x, 256], "float32"), ([bz_y, 256], "float32")] - verify_model(Identity(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_argmax(dynamic: bool): - """test graph builder for argmax""" - - class Argmax1(Module): - def forward(self, data): - return torch.argmax(data, dim=-1) - - class Argmax2(Module): - def forward(self, data): - return torch.argmax(data, dim=-1, keepdim=True) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "argmax", "shape": [bz], "dtype": "int64", "layout": ""}], - "nodes": {"total": 2, "input": 1, "argmax": 1}, - } - expected2 = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "argmax", "shape": [bz, 1], "dtype": "int64", "layout": ""}], - "nodes": {"total": 2, "input": 1, "argmax": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - verify_model(Argmax1(), [([bz, 256], "float32")], expected1) - verify_model(Argmax2(), [([bz, 256], "float32")], expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_argmin(dynamic: bool): - """test graph builder for argmin""" - - class Argmin1(Module): - def forward(self, data): - return torch.argmin(data) - - class Argmin2(Module): - def forward(self, data): - return torch.argmin(data, keepdim=True) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "argmin", "shape": [], "dtype": "int64", "layout": ""}], - "nodes": {"total": 2, "input": 1, "argmin": 1}, - } - expected2 = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "argmin", "shape": [1, 1], "dtype": "int64", "layout": ""}], - "nodes": {"total": 2, "input": 1, "argmin": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - verify_model(Argmin1(), [([bz, 256], "float32")], expected1) - verify_model(Argmin2(), [([bz, 256], "float32")], expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_to(dynamic: bool): - """test graph builder for to""" - - class To1(Module): - def forward(self, data): - return data.to(torch.float16) - - class To2(Module): - def forward(self, data): - return data.to("cpu") - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "astype", "shape": [bz, 256], "dtype": "float16", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "astype": 1}, - } - expected2 = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], - "nodes": {"total": 1, "input": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - verify_model(To1(), [([bz, 256], "float32")], expected1) - verify_model(To2(), [([bz, 256], "float32")], expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_mean(dynamic: bool): - """test graph builder for mean""" - - class Mean(Module): - def forward(self, data): - return data.mean(-1) - - class MeanKeepDim(Module): - def forward(self, data): - return data.mean(-1, keepdim=True) - - bz = "bz" if dynamic else 1 - expected1 = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "mean", "shape": [bz], "dtype": "float32", "layout": "A"}], - "nodes": {"total": 2, "input": 1, "mean": 1}, - } - expected2 = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "mean", "shape": [bz, 1], "dtype": "float32", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "mean": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - verify_model(Mean(), [([bz, 256], "float32")], expected1) - verify_model(MeanKeepDim(), [([bz, 256], "float32")], expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_rsqrt(dynamic: bool): - """test graph builder for rsqrt""" - - class Rsqrt(Module): - def forward(self, data): - return torch.rsqrt(data) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "rsqrt", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "rsqrt": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - verify_model(Rsqrt(), [([bz, 256], "float32")], expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_neg(dynamic: bool): - """test graph builder for neg""" - - class Neg(Module): - def forward(self, data): - return -data - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "negative", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "negative": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - verify_model(Neg(), [([bz, 256], "float32")], expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_max(dynamic: bool): - """test graph builder for max""" - - class Max(Module): - def forward(self, x, y): - return torch.max(x, y) - - bz = "bz" if dynamic else 1 - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}, - {"name": "inp_1", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}, - ], - "outputs": [{"name": "maximum", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], - "nodes": {"total": 3, "input": 2, "maximum": 1}, - } - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - verify_model(Max(), [([bz, 256], "float32"), ([bz, 256], "float32")], expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_cat(dynamic: bool): - """test graph builder for cat""" - - class Cat1(Module): - def forward(self, data, data1, data2): - return torch.cat((data, data1, data2), dim=1) - - class Cat2(Module): - def forward(self, data): - const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - return torch.cat((data, const1, const2), dim=1) - - bz = "bz" if dynamic else 1 - dim = "dim" if dynamic else 3 - input_info = [ - ([bz, dim, 10, 10], "float32"), - ([bz, dim, 10, 10], "float32"), - ([bz, dim, 10, 10], "float32"), - ] - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, dim, 10, 10], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [bz, dim, 10, 10], "dtype": "float32", "layout": ""}, - {"name": "inp_2", "shape": [bz, dim, 10, 10], "dtype": "float32", "layout": ""}, - ], - "outputs": [ - { - "name": "concat", - "shape": [bz, "MUL_3" if dynamic else 9, 10, 10], - "dtype": "float32", - "layout": "ABCD", - } - ], - "nodes": {"total": 4, "input": 3, "concat": 1}, - } - expected2 = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], - "outputs": [ - {"name": "concat", "shape": [1, 9, 10, 10], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 4, "input": 1, "constant": 2, "concat": 1}, - } - if dynamic: - expected1["prims"] = {"total": 4, "shape": 2, "Int": 1, "Mul": 1} - - verify_model(Cat1(), input_info, expected1) - verify_model(Cat2(), [([1, 3, 10, 10], "float32")], expected2) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_stack(dynamic: bool): - """Test graph builder for stack.""" - - bz = "bz" if dynamic else 1 - - class Stack(Module): - def forward(self, data, data1, data2): - return torch.stack((data, data1, data2), dim=0) - - input_info = [ - ([bz, 3, 10, 10], "float32"), - ([bz, 3, 10, 10], "float32"), - ([bz, 3, 10, 10], "float32"), - ] - - expected = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}, - {"name": "inp_2", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}, - ], - "outputs": [ - { - "name": "stack", - "shape": [3, bz, 3, 10, 10], - "dtype": "float32", - "layout": "SABCD", - } - ], - "nodes": {"total": 4, "input": 3, "stack": 1}, - } - - if dynamic: - expected["prims"] = {"total": 1, "shape": 1} - - verify_model(Stack(), input_info, expected) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_scatter(dynamic: bool): - """test graph builder for scatter""" - - bz = "bz" if dynamic else 20 - - class Scatter1(Module): - def __init__(self): - super().__init__() - self.index = msc_utils.random_data([(2, 5), "int64"], MSCFramework.TORCH, max_val=5) - - def forward(self, data, src): - return data.scatter(dim=0, index=self.index, src=src) - - class Scatter2(Module): - def forward(self, data, index, src): - return data.scatter(0, index, src) - - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": "AB"}, - {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": "AB"}, - ], - "outputs": [ - {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": "AB"} - ], - "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_elements": 1}, - } - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": "AB"}, - {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": "AB"}, - {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": "AB"}, - ], - "outputs": [ - {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": "AB"} - ], - "nodes": {"total": 4, "input": 3, "scatter_elements": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 1, "shape": 1} - - verify_model(Scatter1(), [([bz, 20], "float32"), ([2, 5], "float32")], expected1) - verify_model( - Scatter2(), [([bz, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")], expected2 - ) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_masked_scatter(dynamic: bool): - """test graph builder for masked_scatter""" - - dim = "dim" if dynamic else 5 - - class MaskedScatter1(Module): - def forward(self, data, mask, src): - return data.masked_scatter(mask, src) - - class MaskedScatter2(Module): - def forward(self, data, mask, src): - return data.masked_scatter(mask, src) - - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [dim], "dtype": "float32", "layout": "A"}, - {"name": "inp_1", "shape": [dim], "dtype": "bool", "layout": "A"}, - {"name": "inp_2", "shape": [10], "dtype": "float32", "layout": "A"}, - ], - "outputs": [{"name": "where", "shape": [dim], "dtype": "float32", "layout": "A"}], - "nodes": { - "total": 8, - "input": 3, - "cumsum": 1, - "constant": 1, - "subtract": 1, - "take": 1, - "where": 1, - }, - } - expected2 = { - "inputs": [ - { - "name": "inp_0", - "shape": [2, dim], - "dtype": "float32", - "layout": "" if dynamic else "BA", - }, - { - "name": "inp_1", - "shape": [2, dim], - "dtype": "bool", - "layout": "" if dynamic else "BA", - }, - { - "name": "inp_2", - "shape": [3, dim], - "dtype": "float32", - "layout": "" if dynamic else "BA", - }, - ], - "outputs": [ - { - "name": "where", - "shape": [2, dim], - "dtype": "float32", - "layout": "" if dynamic else "BA", - } - ], - "nodes": { - "total": 11, - "input": 3, - "reshape": 3, - "cumsum": 1, - "constant": 1, - "subtract": 1, - "take": 1, - "where": 1, - }, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - expected2["prims"] = {"total": 5, "shape": 1, "Int": 2, "Mul": 2} - - verify_model( - MaskedScatter1(), [([dim], "float32"), ([dim], "bool"), ([10], "float32")], expected1 - ) - verify_model( - MaskedScatter2(), - [([2, dim], "float32"), ([2, dim], "bool"), ([3, dim], "float32")], - expected2, - ) - - -@pytest.mark.parametrize("dynamic", [True, False]) -def test_attention(dynamic: bool): - """test graph builder for attention""" - - # pylint: disable=import-outside-toplevel - import torch.nn.functional as F - - seq = "seq" if dynamic else 128 - - class Attention1(Module): - def forward(self, q_data, k_data, v_data): - return F.scaled_dot_product_attention(q_data, k_data, v_data) - - class Attention2(Module): - def forward(self, q_data, k_data, v_data): - return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True) - - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_1", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, - ], - "outputs": [ - {"name": "attention", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ABCD"} - ], - "nodes": {"total": 4, "input": 3, "msc.attention": 1}, - } - if dynamic: - expected1["prims"] = {"total": 1, "shape": 1} - - input_info = [ - ([1, 8, seq, 64], "float32"), - ([1, 8, seq, 64], "float32"), - ([1, 8, seq, 64], "float32"), - ] - verify_model(Attention1(), input_info, expected1) - verify_model(Attention2(), input_info, expected1) - - class Attention3(Module): - def forward(self, q_data, k_data, v_data, mask): - return F.scaled_dot_product_attention(q_data, k_data, v_data, mask) - - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_1", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_3", "shape": [1, 8, seq, seq], "dtype": "float32", "layout": "ABCD"}, - ], - "outputs": [ - { - "name": "attention_bias", - "shape": [1, 8, seq, 64], - "dtype": "float32", - "layout": "ABCD", - } - ], - "nodes": {"total": 5, "input": 4, "msc.attention": 1}, - } - if dynamic: - expected2["prims"] = {"total": 1, "shape": 1} - - verify_model( - Attention3(), - [ - ([1, 8, seq, 64], "float32"), - ([1, 8, seq, 64], "float32"), - ([1, 8, seq, 64], "float32"), - ([1, 8, seq, seq], "float32"), - ], - expected2, - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_pipeline.py b/tests/python/contrib/test_msc/test_pipeline.py deleted file mode 100644 index e7d44fe90d89..000000000000 --- a/tests/python/contrib/test_msc/test_pipeline.py +++ /dev/null @@ -1,216 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E722 - -"""Test Pipeline in MSC.""" - -import json - -import pytest -import torch - -import tvm.testing -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.pipeline import MSCManager, TorchDynamic - -requires_tensorrt = pytest.mark.skipif( - tvm.get_global_func("relax.ext.tensorrt", True) is None, - reason="TENSORRT is not enabled", -) - - -def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1): - """Get msc config""" - - path = "test_pipe_{}_{}_{}".format(model_type, compile_type, "dynamic" if dynamic else "static") - return { - "workspace": msc_utils.msc_dir(path, keep_history=False), - "verbose": "critical", - "model_type": model_type, - "inputs": inputs, - "outputs": outputs, - "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}}, - "prepare": {"profile": {"benchmark": {"repeat": 10}}}, - "baseline": { - "run_type": model_type, - "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, - }, - "compile": { - "run_type": compile_type, - "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, - }, - } - - -def _get_torch_model(name, training=False): - """Get model from torch vision""" - - # pylint: disable=import-outside-toplevel - try: - import torchvision - - model = getattr(torchvision.models, name)() - if training: - model = model.train() - else: - model = model.eval() - return model - except: # pylint: disable=bare-except - print("please install torchvision package") - return None - - -def _check_pipeline(pipeline, expected_info, dynamic=False): - """Check the pipeline results""" - - passed, err = True, "" - if not pipeline.report["success"]: - passed = False - err = f"Failed to run pipe for {pipeline.model_type} -> {pipeline.compile_type}" - if not dynamic: - model_info = pipeline.get_runtime().model_info - if not msc_utils.dict_equal(model_info, expected_info): - passed = False - err = f"Model info {model_info} mismatch with expected {expected_info}" - pipeline.destory() - if not passed: - raise Exception(f"{err}\nReport:{json.dumps(pipeline.report, indent=2)}") - - -def _test_from_torch( - compile_type, expected_info, training=False, dynamic=False, atol=1e-1, rtol=1e-1 -): - if dynamic and not hasattr(torch, "compile"): - return - - torch_model = _get_torch_model("resnet50", training) - if torch_model: - if torch.cuda.is_available(): - torch_model = torch_model.to(torch.device("cuda:0")) - config = _get_config( - MSCFramework.TORCH, - compile_type, - inputs=[["input_0", [1, 3, 224, 224], "float32"]], - outputs=["output"], - dynamic=dynamic, - atol=atol, - rtol=rtol, - ) - pipeline = TorchDynamic(torch_model, config) if dynamic else MSCManager(torch_model, config) - pipeline.run_pipe() - _check_pipeline(pipeline, expected_info, dynamic) - - -@pytest.mark.parametrize("dynamic", [False]) -def test_tvm_pipeline(dynamic): - """Test pipeline for tvm""" - - model_info = { - "inputs": [ - {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}], - "nodes": { - "total": 229, - "input": 1, - "nn.conv2d": 53, - "nn.batch_norm": 53, - "get_item": 53, - "nn.relu": 49, - "nn.max_pool2d": 1, - "add": 16, - "nn.adaptive_avg_pool2d": 1, - "reshape": 1, - "msc.linear_bias": 1, - }, - } - _test_from_torch(MSCFramework.TVM, model_info, training=False, dynamic=dynamic) - - if not dynamic: - model_info = { - "inputs": [ - {"name": "input", "shape": [1, 224, 224, 3], "dtype": "float32", "layout": "NHWC"} - ], - "outputs": [ - { - "name": "MobilenetV2/Predictions/Reshape_1:0", - "shape": [1, 1001], - "dtype": "float32", - "layout": "NC", - } - ], - "nodes": { - "total": 138, - "input": 1, - "msc.conv2d_bias": 36, - "clip": 35, - "nn.conv2d": 17, - "nn.batch_norm": 17, - "get_item": 17, - "add": 10, - "nn.avg_pool2d": 1, - "squeeze": 1, - "reshape": 2, - "nn.softmax": 1, - }, - } - - -@pytest.mark.parametrize("dynamic", [False]) -def test_torch_pipeline(dynamic): - """Test pipeline for torch""" - - model_info = { - "inputs": [ - {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}], - "nodes": { - "total": 229, - "input": 1, - "nn.conv2d": 53, - "nn.batch_norm": 53, - "get_item": 53, - "nn.relu": 49, - "nn.max_pool2d": 1, - "add": 16, - "nn.adaptive_avg_pool2d": 1, - "reshape": 1, - "msc.linear_bias": 1, - }, - } - _test_from_torch(MSCFramework.TORCH, model_info, training=False, dynamic=dynamic) - - -@requires_tensorrt -@pytest.mark.parametrize("dynamic", [False]) -def test_tensorrt_pipeline(dynamic): - """Test pipeline for tensorrt""" - - model_info = { - "inputs": [ - {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1}, - } - _test_from_torch(MSCFramework.TENSORRT, model_info, training=False, dynamic=dynamic) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_plugin.py b/tests/python/contrib/test_msc/test_plugin.py deleted file mode 100644 index f14d555049ec..000000000000 --- a/tests/python/contrib/test_msc/test_plugin.py +++ /dev/null @@ -1,370 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Test Plugin in MSC.""" - -import numpy as np -import pytest -import torch -from torch import nn - -import tvm.testing -from tvm import relax -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.pipeline import MSCManager -from tvm.contrib.msc.plugin import build_plugins -from tvm.relax.transform import BindParams -from tvm.script import relax as R - - -def _get_externs_header(): - """Get the header source for externs""" - - return """#ifndef EXTERNS_H_ -#define EXTERNS_H_ - -#include "plugin_base.h" - -#ifdef PLUGIN_ENABLE_CUDA -#include -#endif - -namespace tvm { -namespace contrib { -namespace msc { -namespace plugin { - -template -std::vector my_relu_infer(const std::vector& inputs, const TAttr& attrs, - bool is_runtime) { - std::vector outputs; - outputs.push_back(MetaTensor(inputs[0].shape(), inputs[0].data_type(), inputs[0].layout())); - return outputs; -} - -template -void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, T max_val); - -template -void my_relu_cpu_compute(const DataTensor& input, DataTensor& output, const TAttr& attrs) { - my_relu_cpu_kernel(input, output, T(attrs.max_val)); -} - -#ifdef PLUGIN_ENABLE_CUDA -template -void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, T max_val, - const cudaStream_t& stream); - -template -void my_relu_cuda_compute(const DataTensor& input, DataTensor& output, const TAttr& attrs, - const cudaStream_t& stream) { - my_relu_cuda_kernel(input, output, T(attrs.max_val), stream); -} -#endif - -} // namespace plugin -} // namespace msc -} // namespace contrib -} // namespace tvm -#endif // EXTERNS_H_ -""" - - -def _get_externs_cc(): - """Get externs cc source""" - return """#include "externs.h" - -namespace tvm { -namespace contrib { -namespace msc { -namespace plugin { - -template -void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, T max_val) { - const T* input_data = input.const_data(); - T* output_data = output.data(); - for (size_t i = 0; i < output.size(); i++) { - if (input_data[i] >= max_val) { - output_data[i] = max_val; - } else if (input_data[i] <= 0) { - output_data[i] = 0; - } else { - output_data[i] = input_data[i]; - } - } -} - -template void my_relu_cpu_kernel(const DataTensor& input, DataTensor& output, - float max_val); - -} // namespace plugin -} // namespace msc -} // namespace contrib -} // namespace tvm -""" - - -def _get_externs_cu(): - """Get externs cu source""" - - return """#include "externs.h" - -#define CU1DBLOCK 256 -#define KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) - -namespace tvm { -namespace contrib { -namespace msc { -namespace plugin { - -inline int n_blocks(int size, int block_size) { - return size / block_size + (size % block_size == 0 ? 0 : 1); -} - -template -__global__ static void _my_relu(const T* src, T* dst, T max_val, int n) { - KERNEL_LOOP(i, n) { - if (src[i] >= max_val) { - dst[i] = max_val; - } else if (src[i] <= 0) { - dst[i] = 0; - } else { - dst[i] = src[i]; - } - } -} - -template -void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, T max_val, - const cudaStream_t& stream) { - const T* input_data = input.const_data(); - T* output_data = output.data(); - dim3 Bl(CU1DBLOCK); - dim3 Gr(n_blocks(output.size(), CU1DBLOCK)); - _my_relu<<>>(input_data, output_data, max_val, output.size()); -} - -template void my_relu_cuda_kernel(const DataTensor& input, DataTensor& output, - float max_val, const cudaStream_t& stream); - -} // namespace plugin -} // namespace msc -} // namespace contrib -} // namespace tvm -""" - - -def _create_plugin(externs_dir): - """Create sources under source folder""" - with open(externs_dir.relpath("externs.h"), "w") as f: - f.write(_get_externs_header()) - with open(externs_dir.relpath("externs.cc"), "w") as f: - f.write(_get_externs_cc()) - with open(externs_dir.relpath("externs.cu"), "w") as f: - f.write(_get_externs_cu()) - return { - "MyRelu": { - "inputs": [{"name": "input", "dtype": "T"}], - "outputs": [{"name": "output", "dtype": "T"}], - "attrs": [{"name": "max_val", "type": "float"}], - "support_dtypes": {"T": ["float"]}, - "externs": { - "infer_output": {"name": "my_relu_infer", "header": "externs.h"}, - "cpu_compute": { - "name": "my_relu_cpu_compute", - "header": "externs.h", - "source": "externs.cc", - }, - "cuda_compute": { - "name": "my_relu_cuda_compute", - "header": "externs.h", - "source": "externs.cu", - }, - }, - } - } - - -def _get_torch_model(torch_manager): - """Build model with plugin""" - - class MyModel(nn.Module): - """Test model with plugin""" - - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) - self.relu = torch_manager.MyRelu(max_val=0.5) - self.maxpool = nn.MaxPool2d(kernel_size=[1, 1]) - - def forward(self, data): - data = self.conv(data) - data = self.relu(data) - return self.maxpool(data) - - return MyModel() - - -def _get_tvm_model(tvm_manager): - """Build model with plugin""" - - block_builder = relax.BlockBuilder() - weights = np.random.rand(6, 3, 7, 7).astype("float32") - data = relax.Var("data", R.Tensor((1, 3, 224, 224), "float32")) - weight = relax.Var("weight", R.Tensor(weights.shape, weights.dtype.name)) - inputs = [data, weight] - with block_builder.function(name="main", params=inputs.copy()): - with block_builder.dataflow(): - data = relax.op.nn.conv2d(data, weight) - data = block_builder.emit(data, "conv2d") - data = tvm_manager.MyRelu(data, max_val=0.5) - data = block_builder.emit(data, "relu") - data = relax.op.nn.max_pool2d(data) - data = block_builder.emit(data, "max_pool2d") - data = block_builder.emit_output(data) - block_builder.emit_func_output(data) - mod = block_builder.finalize() - return BindParams("main", {"weight": tvm.runtime.tensor(weights)})(mod) - - -def _build_plugin(frameworks, plugin_root): - externs_dir = plugin_root.create_dir("externs") - install_dir = plugin_root.create_dir("install") - plugin = _create_plugin(externs_dir) - managers = build_plugins(plugin, frameworks, install_dir, externs_dir=externs_dir) - return managers - - -def _run_relax(relax_mod, target_name, data): - target = tvm.target.Target(target_name) - relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod) - if target_name == "cuda": - with target: - relax_mod = tvm.s_tir.transform.DefaultGPUSchedule()(relax_mod) - device = tvm.cuda() - else: - device = tvm.cpu() - with tvm.transform.PassContext(opt_level=3): - relax_exec = tvm.compile(relax_mod, target) - runnable = tvm.relax.VirtualMachine(relax_exec, device) - data = tvm.runtime.tensor(data, device) - return runnable["main"](data).numpy() - - -def _test_tvm_plugin(manager, target): - """Test plugin in tvm""" - - model = _get_tvm_model(manager) - data = np.random.rand(1, 3, 224, 224).astype("float32") - outputs = _run_relax(model, target, data) - assert outputs.min() >= 0 and outputs.max() <= 0.5 - - -def _test_torch_plugin(manager): - """Test plugin in torch""" - - model = _get_torch_model(manager) - torch_data = torch.from_numpy(np.random.rand(1, 3, 224, 224).astype("float32")) - if torch.cuda.is_available(): - model = model.to(torch.device("cuda:0")) - torch_data = torch_data.to(torch.device("cuda:0")) - outputs = model(torch_data) - assert outputs.min() >= 0 and outputs.max() <= 0.5 - - -def _test_with_manager(plugins, compile_type, expected_info): - """Test the plugin with manager""" - - path = "test_plugin_" + compile_type - model = _get_torch_model(plugins[MSCFramework.TORCH]) - if torch.cuda.is_available(): - model = model.to(torch.device("cuda:0")) - config = { - "workspace": msc_utils.msc_dir(path), - "model_type": MSCFramework.TORCH, - "verbose": "critical", - "inputs": [["input_0", [1, 3, 224, 224], "float32"]], - "outputs": ["output"], - "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}}, - "prepare": {"profile": {"benchmark": {"repeat": 10}}}, - "baseline": { - "profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": {"repeat": 10}}, - }, - "compile": { - "run_type": compile_type, - "profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": {"repeat": 10}}, - }, - } - manager = MSCManager(model, config, plugins=plugins) - report = manager.run_pipe() - model_info = manager.get_runtime().model_info - manager.destory() - assert report["success"], f"Failed to run pipe for torch -> {compile_type}" - assert msc_utils.dict_equal(model_info, expected_info), ( - f"Model info {model_info} mismatch with expected {expected_info}" - ) - - -@pytest.mark.skip( - reason="skip the test because plugin needs to include ffi folder, can be re-enabled" -) -def test_plugin(): - """Test the plugins""" - - frameworks = [MSCFramework.TORCH, MSCFramework.TVM] - if tvm.get_global_func("relax.ext.tensorrt", True) is not None: - frameworks.append(MSCFramework.TENSORRT) - plugin_root = msc_utils.msc_dir("msc_plugin") - managers = _build_plugin(frameworks, plugin_root) - - # test the plugin load - _test_tvm_plugin(managers[MSCFramework.TVM], "llvm") - if tvm.cuda().exist: - _test_tvm_plugin(managers[MSCFramework.TVM], "cuda") - _test_torch_plugin(managers[MSCFramework.TORCH]) - - # test the plugin with manager - model_info = { - "inputs": [ - {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "output", "shape": [1, 6, 218, 218], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 4, "input": 1, "msc.conv2d_bias": 1, "MyRelu": 1, "nn.max_pool2d": 1}, - } - _test_with_manager(managers, MSCFramework.TORCH, model_info) - _test_with_manager(managers, MSCFramework.TVM, model_info) - if tvm.get_global_func("relax.ext.tensorrt", True) is not None: - byoc_info = { - "inputs": [ - {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "output", "shape": [1, 6, 218, 218], "dtype": "float32", "layout": ""} - ], - "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1}, - } - _test_with_manager(managers, MSCFramework.TENSORRT, byoc_info) - - plugin_root.destory() - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_runner.py b/tests/python/contrib/test_msc/test_runner.py deleted file mode 100644 index 0b53d00517c8..000000000000 --- a/tests/python/contrib/test_msc/test_runner.py +++ /dev/null @@ -1,122 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E722 - -"""Test Runners in MSC.""" - -import numpy as np -import pytest -import torch -from torch import fx - -import tvm.testing -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.framework.tensorrt.runtime import TensorRTRunner -from tvm.contrib.msc.framework.torch.runtime import TorchRunner -from tvm.contrib.msc.framework.tvm.runtime import TVMRunner -from tvm.relax.frontend.torch import from_fx - -requires_tensorrt = pytest.mark.skipif( - tvm.get_global_func("relax.ext.tensorrt", True) is None, - reason="TENSORRT is not enabled", -) - - -def _get_torch_model(name, training=False): - """Get model from torch vision""" - - # pylint: disable=import-outside-toplevel - try: - import torchvision - - model = getattr(torchvision.models, name)() - if training: - model = model.train() - else: - model = model.eval() - return model - except: # pylint: disable=bare-except - print("please install torchvision package") - return None - - -def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1): - """Test runner from torch model""" - - torch_model = _get_torch_model("resnet50", training) - if torch_model: - path = f"test_runner_torch_{runner_cls.__name__}_{device}" - workspace = msc_utils.set_workspace(msc_utils.msc_dir(path, keep_history=False)) - log_path = workspace.relpath("MSC_LOG", keep_history=False) - msc_utils.set_global_logger("critical", log_path) - input_info = [([1, 3, 224, 224], "float32")] - datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info] - torch_datas = [torch.from_numpy(d) for d in datas] - graph_model = fx.symbolic_trace(torch_model) - if training: - input_info = [([tvm.tir.Var("bz", "int64"), 3, 224, 224], "float32")] - with torch.no_grad(): - golden = torch_model(*torch_datas) - mod = from_fx(graph_model, input_info) - runner = runner_cls(mod, device=device, training=training) - runner.build() - outputs = runner.run(datas, ret_type="list") - golden = [msc_utils.cast_array(golden)] - workspace.destory() - for gol_r, out_r in zip(golden, outputs): - tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=atol, rtol=rtol) - - -@pytest.mark.parametrize("training", [True, False]) -def test_tvm_runner_cpu(training): - """Test runner for tvm on cpu""" - - _test_from_torch(TVMRunner, "cpu", training=training) - - -@tvm.testing.requires_cuda -@pytest.mark.parametrize("training", [True, False]) -def test_tvm_runner_cuda(training): - """Test runner for tvm on CUDA""" - - _test_from_torch(TVMRunner, "cuda", training=training) - - -@pytest.mark.parametrize("training", [True, False]) -def test_torch_runner_cpu(training): - """Test runner for torch on cpu""" - - _test_from_torch(TorchRunner, "cpu", training=training) - - -@tvm.testing.requires_cuda -@pytest.mark.parametrize("training", [True, False]) -def test_torch_runner_cuda(training): - """Test runner for torch on CUDA""" - - _test_from_torch(TorchRunner, "cuda", training=training, atol=1e-1, rtol=1e-1) - - -@requires_tensorrt -def test_tensorrt_runner(): - """Test runner for tensorrt""" - - _test_from_torch(TensorRTRunner, "cuda", atol=1e-1, rtol=1e-1) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py deleted file mode 100644 index 59315b605df1..000000000000 --- a/tests/python/contrib/test_msc/test_tools.py +++ /dev/null @@ -1,312 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E722 - -"""Test Tools in MSC.""" - -import json - -import pytest -import torch - -import tvm.testing -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.tools import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.pipeline import MSCManager - -requires_tensorrt = pytest.mark.skipif( - tvm.get_global_func("relax.ext.tensorrt", True) is None, - reason="TENSORRT is not enabled", -) - - -def _get_config( - model_type, - compile_type, - tools, - inputs, - outputs, - atol=1e-2, - rtol=1e-2, - optimize_type=None, -): - """Get msc config""" - - path = "_".join(["test_tools", model_type, compile_type] + [t["tool_type"] for t in tools]) - return { - "workspace": msc_utils.msc_dir(path, keep_history=False), - "verbose": "critical", - "model_type": model_type, - "inputs": inputs, - "outputs": outputs, - "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}}, - "tools": tools, - "prepare": {"profile": {"benchmark": {"repeat": 10}}}, - "baseline": { - "run_type": model_type, - "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, - }, - "optimize": { - "run_type": optimize_type or model_type, - "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, - }, - "compile": { - "run_type": compile_type, - "profile": {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}, - }, - } - - -def get_tools(tool_type, use_distill=False, run_type=MSCFramework.MSC): - """Get config for the tool""" - - tools = [] - if tool_type == ToolType.PRUNER: - config = { - "plan_file": "msc_pruner.json", - "strategys": [ - { - "methods": { - "weights": {"method_name": "per_channel", "density": 0.8}, - "output": {"method_name": "per_channel", "density": 0.8}, - } - } - ], - } - tools.append({"tool_type": ToolType.PRUNER, "tool_config": config}) - elif tool_type == ToolType.QUANTIZER: - # pylint: disable=import-outside-toplevel - from tvm.contrib.msc.core.tools.quantize import QuantizeStage - - if run_type == MSCFramework.TENSORRT: - config = {"plan_file": "msc_quantizer.json", "strategys": []} - else: - op_types = ["nn.conv2d", "msc.conv2d_bias", "msc.linear", "msc.linear_bias"] - config = { - "plan_file": "msc_quantizer.json", - "strategys": [ - { - "methods": { - "input": "gather_maxmin", - "output": "gather_maxmin", - "weights": "gather_max_per_channel", - }, - "op_types": op_types, - "stages": [QuantizeStage.GATHER], - }, - { - "methods": {"input": "calibrate_maxmin", "output": "calibrate_maxmin"}, - "op_types": op_types, - "stages": [QuantizeStage.CALIBRATE], - }, - { - "methods": { - "input": "quantize_normal", - "weights": "quantize_normal", - "output": "dequantize_normal", - }, - "op_types": op_types, - }, - ], - } - tools.append({"tool_type": ToolType.QUANTIZER, "tool_config": config}) - elif tool_type == ToolType.TRACKER: - # pylint: disable=import-outside-toplevel - from tvm.contrib.msc.core.utils import MSCStage - - config = { - "plan_file": "msc_tracker.json", - "strategys": [ - { - "methods": { - "output": { - "method_name": "save_compared", - "compare_to": { - MSCStage.OPTIMIZE: [MSCStage.BASELINE], - MSCStage.COMPILE: [MSCStage.OPTIMIZE, MSCStage.BASELINE], - }, - } - }, - "op_types": ["nn.relu"], - } - ], - } - tools.append({"tool_type": ToolType.TRACKER, "tool_config": config}) - if use_distill: - config = { - "plan_file": "msc_distiller.json", - "strategys": [ - { - "methods": {"mark": "loss_lp_norm"}, - "marks": ["loss"], - }, - ], - } - tools.append({"tool_type": ToolType.DISTILLER, "tool_config": config}) - return tools - - -def _get_torch_model(name, training=False): - """Get model from torch vision""" - - # pylint: disable=import-outside-toplevel - try: - import torchvision - - model = getattr(torchvision.models, name)() - if training: - model = model.train() - else: - model = model.eval() - return model - except: # pylint: disable=bare-except - print("please install torchvision package") - return None - - -def _check_manager(manager, expected_info): - """Check the manager results""" - - model_info = manager.get_runtime().model_info - passed, err = True, "" - if not manager.report["success"]: - passed = False - err = f"Failed to run pipe for {manager.model_type} -> {manager.compile_type}" - if not msc_utils.dict_equal(model_info, expected_info): - passed = False - err = f"Model info {model_info} mismatch with expected {expected_info}" - manager.destory() - if not passed: - raise Exception(f"{err}\nReport:{json.dumps(manager.report, indent=2)}") - - -def _test_from_torch( - compile_type, - tools, - expected_info, - training=False, - atol=1e-1, - rtol=1e-1, - optimize_type=None, -): - torch_model = _get_torch_model("resnet50", training) - if torch_model: - if torch.cuda.is_available(): - torch_model = torch_model.to(torch.device("cuda:0")) - config = _get_config( - MSCFramework.TORCH, - compile_type, - tools, - inputs=[["input_0", [1, 3, 224, 224], "float32"]], - outputs=["output"], - atol=atol, - rtol=rtol, - optimize_type=optimize_type, - ) - manager = MSCManager(torch_model, config) - manager.run_pipe() - _check_manager(manager, expected_info) - - -def get_model_info(compile_type): - """Get the model info""" - - if compile_type == MSCFramework.TVM: - return { - "inputs": [ - {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}], - "nodes": { - "total": 229, - "input": 1, - "nn.conv2d": 53, - "nn.batch_norm": 53, - "get_item": 53, - "nn.relu": 49, - "nn.max_pool2d": 1, - "add": 16, - "nn.adaptive_avg_pool2d": 1, - "reshape": 1, - "msc.linear_bias": 1, - }, - } - if compile_type == MSCFramework.TENSORRT: - return { - "inputs": [ - {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1}, - } - raise TypeError("Unexpected compile_type " + str(compile_type)) - - -@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER]) -def test_tvm_tool(tool_type): - """Test tools for tvm""" - - tools = get_tools(tool_type) - _test_from_torch(MSCFramework.TVM, tools, get_model_info(MSCFramework.TVM), training=False) - - -@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER]) -def test_tvm_distill(tool_type): - """Test tools for tvm with distiller""" - - tools = get_tools(tool_type, use_distill=True) - _test_from_torch(MSCFramework.TVM, tools, get_model_info(MSCFramework.TVM), training=False) - - -@requires_tensorrt -@pytest.mark.parametrize( - "tool_type", - [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER], -) -def test_tensorrt_tool(tool_type): - """Test tools for tensorrt""" - - tools = get_tools(tool_type, run_type=MSCFramework.TENSORRT) - if tool_type == ToolType.QUANTIZER: - optimize_type = MSCFramework.TENSORRT - else: - optimize_type = None - _test_from_torch( - MSCFramework.TENSORRT, - tools, - get_model_info(MSCFramework.TENSORRT), - training=False, - atol=1e-1, - rtol=1e-1, - optimize_type=optimize_type, - ) - - -@requires_tensorrt -@pytest.mark.parametrize("tool_type", [ToolType.PRUNER]) -def test_tensorrt_distill(tool_type): - """Test tools for tensorrt with distiller""" - - tools = get_tools(tool_type, use_distill=True) - _test_from_torch( - MSCFramework.TENSORRT, tools, get_model_info(MSCFramework.TENSORRT), training=False - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_transform.py b/tests/python/contrib/test_msc/test_transform.py deleted file mode 100644 index 8c1c269bb831..000000000000 --- a/tests/python/contrib/test_msc/test_transform.py +++ /dev/null @@ -1,113 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E722 - -"""Test MSC basic Pass.""" - -import tvm.testing -from tvm.contrib.msc.core import transform as msc_transform -from tvm.contrib.msc.core import utils as msc_utils -from tvm.relax import PyExprVisitor -from tvm.relax.frontend.torch import from_fx - - -def test_relax_layout(): - """Test SetExprLayout for relax""" - - # pylint: disable=import-outside-toplevel - try: - import torch - import torchvision - from torch import fx - except: # pylint: disable=bare-except - print("please install pytorch python package") - return - - class RelaxLayoutChecker(PyExprVisitor): - """Check if name as span attribute is setted.""" - - def check(self, expr): - self._missing_exprs = [] - if isinstance(expr, tvm.relax.Expr): - self.visit_expr(expr) - elif isinstance(expr, tvm.relax.BindingBlock): - self.visit_binding_block(expr) - assert len(self._missing_exprs) == 0, f"Missing {len(self._missing_exprs)} layouts" - - def visit_var_binding_(self, binding) -> None: - super().visit_var_binding_(binding) - if not msc_utils.get_expr_layout(binding.value): - self._missing_exprs.append(binding.value) - - def visit_constant_(self, op) -> None: - super().visit_constant_(op) - if not msc_utils.get_expr_layout(op): - self._missing_exprs.append(op) - - torch_model = torchvision.models.resnet50() - graph_model = fx.symbolic_trace(torch_model) - input_info = [([1, 3, 224, 224], "float32")] - with torch.no_grad(): - mod = from_fx(graph_model, input_info) - mod = msc_transform.SetExprLayout()(mod) - RelaxLayoutChecker().check(mod) - - -def test_relax(): - """Test SetExprName for relax""" - - # pylint: disable=import-outside-toplevel - try: - import torch - import torchvision - from torch import fx - except: # pylint: disable=bare-except - print("please install pytorch python package") - return - - class RelaxNameChecker(PyExprVisitor): - """Check if name as span attribute is setted.""" - - def check(self, expr): - self._missing_exprs = [] - if isinstance(expr, tvm.relax.Expr): - self.visit_expr(expr) - elif isinstance(expr, tvm.relax.BindingBlock): - self.visit_binding_block(expr) - assert len(self._missing_exprs) == 0, f"Missing {len(self._missing_exprs)} names" - - def visit_var_binding_(self, binding) -> None: - super().visit_var_binding_(binding) - if not msc_utils.get_expr_name(binding.value): - self._missing_exprs.append(binding.value) - - def visit_constant_(self, op) -> None: - super().visit_constant_(op) - if not msc_utils.get_expr_name(op): - self._missing_exprs.append(op) - - torch_model = torchvision.models.resnet50() - graph_model = fx.symbolic_trace(torch_model) - input_info = [([1, 3, 224, 224], "float32")] - with torch.no_grad(): - mod = from_fx(graph_model, input_info) - mod = msc_transform.SetExprName()(mod) - RelaxNameChecker().check(mod) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py deleted file mode 100644 index 2b616eb33adb..000000000000 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ /dev/null @@ -1,1255 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Test translate from relax.""" - -import numpy as np -import torch -from torch.nn import Module - -import tvm.testing -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.frontend import translate as core_translate -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.framework.torch.frontend import translate as torch_translate -from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen - - -def verify_model(torch_model, input_info, opt_config=None): - """Compare torch module IR""" - - orig_mod, _ = torch_translate.from_torch(torch_model, input_info, as_msc=False) - target = "llvm" - dev = tvm.cpu() - args = [msc_utils.random_data(i, MSCFramework.TVM) for i in input_info] - - def _tvm_runtime_to_np(obj): - if isinstance(obj, tvm.runtime.Tensor): - return obj.numpy() - elif isinstance(obj, tvm.runtime.ShapeTuple): - return np.array(obj, dtype="int64") - elif isinstance(obj, list | tvm.ir.container.Array): - return [_tvm_runtime_to_np(item) for item in obj] - elif isinstance(obj, tuple): - return tuple(_tvm_runtime_to_np(item) for item in obj) - else: - return obj - - def _run_relax(relax_mod): - relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod) - relax_exec = tvm.compile(relax_mod, target) - vm_runner = tvm.relax.VirtualMachine(relax_exec, dev) - res = vm_runner["main"](*args) - return _tvm_runtime_to_np(res) - - rt_mod = tvm_codegen.to_relax( - *core_translate.from_relax(orig_mod, opt_config=opt_config), - codegen_config={"explicit_name": False}, - ) - - orig_output = _run_relax(orig_mod) - rt_output = _run_relax(rt_mod) - if not isinstance(orig_output, list | tuple): - orig_output = [orig_output] - if not isinstance(rt_output, list | tuple): - rt_output = [rt_output] - for o_out, r_out in zip(orig_output, rt_output): - tvm.testing.assert_allclose(o_out, r_out) - - -def test_conv1d(): - """test relax translator for conv1d""" - - class Conv1D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - class Conv1D2(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) - - def forward(self, data): - return self.conv(data) - - input_info = [([1, 3, 10], "float32")] - verify_model(Conv1D1(), input_info) - verify_model(Conv1D2(), input_info) - - -def test_conv2d(): - """test relax translator for conv2d""" - - class Conv2D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - class Conv2D2(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) - - def forward(self, data): - return self.conv(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Conv2D1(), input_info) - verify_model(Conv2D2(), input_info) - - -def test_linear(): - """test relax translator for linear""" - - class Dense1(Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 7, bias=True) - - def forward(self, data): - return self.linear(data) - - class Dense2(Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 7, bias=False) - - def forward(self, data): - return self.linear(data) - - class MatMul1(Module): - def forward(self, x, y): - return torch.matmul(x, y) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Dense1(), input_info) - verify_model(Dense2(), input_info) - verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) - - -def test_bmm(): - """test relax translator for bmm""" - - class BMM(Module): - def forward(self, x, y): - return torch.bmm(x, y) - - input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")] - verify_model(BMM(), input_info) - - -def test_baddbmm(): - """test relax translator for baddbmm""" - - class BAddBMM1(Module): - def forward(self, c, x, y): - return torch.baddbmm(c, x, y) - - class BAddBMM2(Module): - def forward(self, c, x, y): - return torch.baddbmm(c, x, y, alpha=2, beta=0) - - input_info = [ - ((4, 128, 512), "float32"), - ((4, 128, 256), "float32"), - ((4, 256, 512), "float32"), - ] - verify_model(BAddBMM1(), input_info) - verify_model(BAddBMM2(), input_info) - - -def test_relu(): - """test relax translator for relu""" - - class ReLU(Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, data): - return self.relu(data) - - class ReLU1(Module): - def forward(self, data): - return torch.nn.functional.relu(data) - - input_info = [([10, 10], "float32")] - verify_model(ReLU(), input_info) - verify_model(ReLU1(), input_info) - - -def test_relu6(): - """test relax translator for relu6""" - - class ReLU6(Module): - def __init__(self): - super().__init__() - self.relu6 = torch.nn.ReLU6() - - def forward(self, data): - return self.relu6(data) - - input_info = [([10, 10], "float32")] - verify_model(ReLU6(), input_info) - - -def test_maxpool2d(): - """test relax translator for maxpool2d""" - - class MaxPool2d(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) - - def forward(self, data): - return self.pool(data) - - class MaxPool2d2(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]) - - def forward(self, data): - return self.pool(data) - - class MaxPool2d3(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) - - def forward(self, data): - return self.pool(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(MaxPool2d(), input_info) - verify_model(MaxPool2d2(), input_info) - verify_model(MaxPool2d3(), input_info) - - -def test_avgpool2d(): - """test relax translator for avgpool2d""" - - class AvgPool2d(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1]) - - def forward(self, data): - return self.pool(data) - - class AvgPool2d2(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True) - - def forward(self, data): - return self.pool(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(AvgPool2d(), input_info) - verify_model(AvgPool2d2(), input_info) - - -def test_adaptive_avgpool2d(): - """test relax translator for adaptive_avgpool2d""" - - class AdaptiveAvgPool2d0(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) - - def forward(self, data): - return self.pool(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(AdaptiveAvgPool2d0(), input_info) - - -def test_flatten(): - """test relax translator for flatten""" - - class Flatten(Module): - def __init__(self): - super().__init__() - self.f = torch.nn.Flatten(2, -1) - - def forward(self, data): - return self.f(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Flatten(), input_info) - verify_model(torch.nn.Flatten(2, -1), input_info) - - -def test_batchnorm2d(): - """test relax translator for batchnorm2d""" - - class BatchNorm2d(Module): - def __init__(self): - super().__init__() - self.batchnorm = torch.nn.BatchNorm2d(3) - - def forward(self, data): - return self.batchnorm(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(BatchNorm2d(), input_info) - - -def test_embedding(): - """test relax translator for embedding""" - - class Embedding(Module): - def __init__(self): - super().__init__() - self.embedding = torch.nn.Embedding(10, 3) - - def forward(self, data): - return self.embedding(data) - - verify_model(Embedding(), [([4], "int64")]) - verify_model(Embedding(), [([4, 5], "int64")]) - - -def test_dropout(): - """test relax translator for dropout""" - - class Dropout1(Module): - def __init__(self): - super().__init__() - self.dropout = torch.nn.Dropout(0.5) - - def forward(self, data): - return self.dropout(data) - - class Dropout2(Module): - def forward(self, data): - return torch.dropout(data, 0.5, train=True) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Dropout1(), input_info) - verify_model(Dropout2(), input_info) - - -def test_layernorm(): - """test relax translator for layernorm""" - - class LayerNorm(Module): - def __init__(self): - super().__init__() - self.layernorm = torch.nn.LayerNorm((10, 10)) - - def forward(self, data): - return self.layernorm(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(LayerNorm(), input_info) - - -def test_functional_layernorm(): - """test relax translator for functional_layernorm""" - - class LayerNorm(Module): - def __init__(self, shape): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(shape)) - self.bias = torch.nn.Parameter(torch.zeros(shape)) - - def forward(self, data): - return torch.nn.functional.layer_norm( - data, self.weight.shape, self.weight, self.bias, 1e-5 - ) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(LayerNorm((10, 10)), input_info) - - -def test_cross_entropy(): - """test relax translator for cross_entropy""" - - class CrossEntropy1(Module): - def __init__(self): - super().__init__() - self.loss = torch.nn.CrossEntropyLoss() - - def forward(self, logits, targets): - return self.loss(logits, targets) - - class CrossEntropy2(Module): - def __init__(self): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones((2,))) - self.loss = torch.nn.CrossEntropyLoss(weight=self.weight) - - def forward(self, logits, targets): - return self.loss(logits, targets) - - class CrossEntropy3(Module): - def __init__(self): - super().__init__() - self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum") - - def forward(self, logits, targets): - return self.loss(logits, targets) - - input_info = [([3, 2], "float32"), ([3], "int32")] - verify_model(CrossEntropy1(), input_info) - verify_model(CrossEntropy2(), input_info) - verify_model(CrossEntropy3(), input_info) - - -def test_functional_cross_entropy(): - """test relax translator for functional_cross_entropy""" - - class CrossEntropy(Module): - def forward(self, logits, targets): - return torch.nn.functional.cross_entropy(logits, targets) - - input_info = [([3, 10], "float32"), ([3], "int32")] - verify_model(CrossEntropy(), input_info) - - -def test_silu(): - """test relax translator for silu""" - - class SiLU(Module): - def __init__(self): - super().__init__() - self.silu = torch.nn.SiLU() - - def forward(self, data): - return self.silu(data) - - class SiLU2(Module): - def forward(self, data): - return torch.nn.functional.silu(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(SiLU(), input_info) - verify_model(SiLU2(), input_info) - - -def test_groupnorm(): - """test relax translator for groupnorm""" - - class GroupNorm(Module): - def __init__(self): - super().__init__() - self.groupnorm = torch.nn.GroupNorm(3, 3) - - def forward(self, data): - return self.groupnorm(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(GroupNorm(), input_info) - - -def test_softmax(): - """test relax translator for softmax""" - - class Softmax(Module): - def __init__(self): - super().__init__() - self.softmax = torch.nn.Softmax(dim=1) - - def forward(self, data): - return self.softmax(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Softmax(), input_info) - - -def test_binary(): - """test relax translator for binary""" - - input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] - input_info2 = [([1, 3, 10, 10], "float32")] - - # Add - class Add1(Module): - def forward(self, lhs, rhs): - return lhs + rhs - - class Add2(Module): - def forward(self, lhs): - return lhs + 1.0 - - verify_model(Add1(), input_info1) - verify_model(Add2(), input_info2) - - # Sub - class Sub1(Module): - def forward(self, lhs, rhs): - return lhs - rhs - - class Sub2(Module): - def forward(self, lhs): - return lhs - 1.0 - - verify_model(Sub1(), input_info1) - verify_model(Sub2(), input_info2) - - # Mul - class Mul1(Module): - def forward(self, lhs, rhs): - return lhs * rhs - - class Mul2(Module): - def forward(self, lhs): - return lhs * 1.0 - - verify_model(Mul1(), input_info1) - verify_model(Mul2(), input_info2) - - # True div - class TrueDiv1(Module): - def forward(self, lhs, rhs): - return lhs / rhs - - class TrueDiv2(Module): - def forward(self, lhs): - return lhs / 1.0 - - verify_model(TrueDiv1(), input_info1) - verify_model(TrueDiv2(), input_info2) - - # Floor div - class FloorDiv1(Module): - def forward(self, lhs, rhs): - return lhs // rhs - - class FloorDiv2(Module): - def forward(self, lhs): - return lhs // 1.0 - - verify_model(FloorDiv1(), input_info1) - verify_model(FloorDiv2(), input_info2) - - # Power - class Power1(Module): - def forward(self, lhs, rhs): - return lhs**rhs - - class Power2(Module): - def forward(self, lhs): - return lhs**1.0 - - verify_model(Power1(), input_info1) - verify_model(Power2(), input_info2) - - # LT - class LT1(Module): - def forward(self, lhs, rhs): - return lhs < rhs - - class LT2(Module): - def forward(self, lhs): - return lhs < 1.0 - - verify_model(LT1(), input_info1) - verify_model(LT2(), input_info2) - - -def test_size(): - """test relax translator for size""" - - class Size(Module): - def forward(self, data): - return data.size() - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Size(), input_info) - - -def test_squeeze(): - """test relax translator for squeeze""" - - class Squeeze1(Module): - def forward(self, data): - return data.squeeze(1) - - class Squeeze2(Module): - def forward(self, data): - return data.squeeze() - - input_info = [([3, 1, 4, 1], "float32")] - verify_model(Squeeze1(), input_info) - verify_model(Squeeze2(), input_info) - - -def test_unsqueeze(): - """test relax translator for unsqueeze""" - - class Unsqueeze1(Module): - def forward(self, data): - return data.unsqueeze(1) - - class Unsqueeze2(Module): - def forward(self, data): - return data.unsqueeze(-1) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Unsqueeze1(), input_info) - verify_model(Unsqueeze2(), input_info) - - -def test_getattr(): - """test relax translator for getattr""" - - class GetAttr1(Module): - def forward(self, data): - return data.shape - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(GetAttr1(), input_info) - - -def test_getitem(): - """test relax translator for getitem""" - - class Slice1(Module): - def forward(self, x): - return x[0, 1::2, :, :3] - - class Slice2(Module): - def forward(self, x): - return x[:, None, None, :, None] - - verify_model(Slice1(), [([1, 3, 10, 10], "float32")]) - verify_model(Slice2(), [([8, 16], "float32")]) - - -def test_unary(): - """test relax translator for unary""" - - input_info = [([1, 3, 10, 10], "float32")] - - # sin - class Sin(Module): - def forward(self, data): - return torch.sin(data) - - verify_model(Sin(), input_info) - - # cos - class Cos(Module): - def forward(self, data): - return torch.cos(data) - - verify_model(Cos(), input_info) - - # exp - class Exp(Module): - def forward(self, data): - return torch.exp(data) - - verify_model(Exp(), input_info) - - # sqrt - class Sqrt(Module): - def forward(self, data): - return torch.sqrt(data) - - verify_model(Sqrt(), input_info) - - # sigmoid - class Sigmoid(Module): - def forward(self, data): - return torch.sigmoid(data) - - verify_model(Sigmoid(), input_info) - - # round - class Round(Module): - def forward(self, data): - return torch.round(data) - - verify_model(Round(), input_info) - - -def test_gelu(): - """test relax translator for gelu""" - - class Gelu(Module): - def forward(self, data): - return torch.nn.functional.gelu(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Gelu(), input_info) - - -def test_tanh(): - """test relax translator for tanh""" - - class Tanh(Module): - def forward(self, data): - return torch.tanh(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Tanh(), input_info) - - -def test_clamp(): - """test relax translator for clamp""" - - class Clamp(Module): - def forward(self, data): - return torch.clamp(data, min=0.1, max=0.5) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Clamp(), input_info) - - -def test_interpolate(): - """test relax translator for interpolate""" - - class Interpolate(Module): - def forward(self, data): - return torch.nn.functional.interpolate(data, (5, 5)) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Interpolate(), input_info) - - -def test_addmm(): - """test relax translator for addmm""" - - class Addmm(Module): - def forward(self, x_1, x_2, x_3): - return torch.addmm(x_1, x_2, x_3) - - input_info = [ - ([10, 10], "float32"), - ([10, 10], "float32"), - ([10, 10], "float32"), - ] - verify_model(Addmm(), input_info) - - -def test_split(): - """test relax translator for split""" - - class Split1(Module): - def forward(self, data): - return torch.split(data, 1, dim=1) - - class Split2(Module): - def forward(self, data): - return torch.split(data, [1, 2], dim=1) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Split1(), input_info) - verify_model(Split2(), input_info) - - -def test_unbind(): - """test relax translator for unbind""" - - class Unbind1(Module): - def forward(self, data): - return torch.unbind(data) - - class Unbind2(Module): - def forward(self, data): - return torch.unbind(data, dim=1) - - input_info = [([3, 3, 10, 10], "float32")] - verify_model(Unbind1(), input_info) - verify_model(Unbind2(), input_info) - - -def test_cumsum(): - """test relax translator for cumsum""" - - class Cumsum(Module): - def forward(self, data): - return torch.cumsum(data, dim=1, dtype=torch.int32) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Cumsum(), input_info) - - -def test_chunk(): - """test relax translator for chunk""" - - class Chunk(Module): - def forward(self, data): - return torch.chunk(data, 3, dim=1) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Chunk(), input_info) - - -def test_inplace_fill(): - """test relax translator for inplace_fill""" - - class InplaceFill(Module): - def forward(self, data): - data.fill_(1.5) - return data - - verify_model(InplaceFill(), [([10, 10], "float32")], opt_config={"opt_level": 0}) - - -def test_arange(): - """test relax translator for arange""" - - class Arange(Module): - def forward(self): - return torch.arange(0, 20, dtype=torch.int32) - - verify_model(Arange(), [([10, 10], "float32")]) - - -def test_empty(): - """test relax translator for empty""" - - class Empty(Module): - def forward(self): - return torch.empty((10, 10), dtype=torch.float32) - - verify_model(Empty(), [([10, 10], "float32")]) - - -def test_tensor(): - """test relax translator for tensor""" - - class Empty1(Module): - def forward(self): - return torch.tensor(3, dtype=torch.float32) - - class Empty2(Module): - def forward(self): - return torch.tensor(3) - - verify_model(Empty1(), [([10, 10], "float32")]) - verify_model(Empty2(), [([10, 10], "float32")]) - - -def test_tril(): - """test relax translator for tril""" - - class Tril(Module): - def forward(self, data): - return torch.tril(data, 1) - - class InplaceTril(Module): - def forward(self, data): - data.tril_(1) - return data - - input_info = [([10, 10], "float32")] - verify_model(Tril(), input_info) - verify_model(InplaceTril(), input_info) - - -def test_triu(): - """test relax translator for triu""" - - class Triu(Module): - def forward(self, data): - return torch.triu(data, 1) - - class InplaceTriu(Module): - def forward(self, data): - data.triu_(1) - return data - - input_info = [([10, 10], "float32")] - verify_model(Triu(), input_info) - verify_model(InplaceTriu(), input_info) - - -def test_new_ones(): - """test relax translator for new_ones""" - - class NewOnes(Module): - def forward(self, x): - return x.new_ones(1, 2, 3) - - input_info = [([1, 2, 3], "float32")] - verify_model(NewOnes(), input_info, opt_config={"opt_level": 0}) - - -def test_expand(): - """test relax translator for expand""" - - class Expand1(Module): - def forward(self, x): - return x.expand(4, 2, 3, 4) - - class Expand2(Module): - def forward(self, x): - return x.expand(4, -1, -1, 4) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Expand1(), input_info) - verify_model(Expand2(), input_info) - - -def test_reduce(): - """test relax translator for reduce""" - - # sum - class Sum(Module): - def forward(self, x): - return torch.sum(x, (2, 1)) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Sum(), input_info) - - -def test_datatype(): - """test relax translator for datatype""" - - input_info = [([1, 2, 3, 4], "float32")] - - # float - class ToFloat(Module): - def forward(self, x): - return x.float() - - verify_model(ToFloat(), input_info) - - # half - class ToHalf(Module): - def forward(self, x): - return x.half() - - verify_model(ToHalf(), input_info) - - # type - class Type(Module): - def forward(self, x): - return x.type(torch.float32) - - # type - class TypeFromAttr(Module): - def forward(self, x): - return x.type(x.getattr("dtype")) - - # astype - class AsType(Module): - def forward(self, x): - return x.astype(torch.float32) - - verify_model(Type(), input_info) - verify_model(TypeFromAttr(), input_info) - verify_model(AsType(), input_info) - - -def test_permute(): - """test relax translator for permute""" - - class Permute(Module): - def forward(self, x): - return x.permute(0, 3, 2, 1) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Permute(), input_info) - - -def test_reshape(): - """test relax translator for reshape""" - - class Reshape(Module): - def forward(self, x): - return x.reshape(2, 12) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Reshape(), input_info) - - -def test_transpose(): - """test relax translator for transpose""" - - class Transpose(Module): - def forward(self, x): - return x.transpose(1, 3) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Transpose(), input_info) - - -def test_view(): - """test relax translator for view""" - - class View(Module): - def forward(self, x): - return x.view(2, 12) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(View(), input_info) - - -def test_keep_params(): - """test relax translator for keep_params""" - - class Conv2D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")]) - - -def test_unwrap_unit_return_tuple(): - """test relax translator for unwrap_unit_return_tuple""" - - class Identity(Module): - def forward(self, x): - return (x,) - - verify_model(Identity(), [([256, 256], "float32")]) - - -def test_no_bind_return_tuple(): - """test relax translator for no_bind_return_tuple""" - - class Identity(Module): - def forward(self, x, y): - return (x, y) - - input_info = [([256, 256], "float32"), ([256, 256], "float32")] - verify_model(Identity(), input_info) - - -def test_argmax(): - """test relax translator for argmax""" - - class Argmax1(Module): - def forward(self, data): - return torch.argmax(data, dim=-1) - - class Argmax2(Module): - def forward(self, data): - return torch.argmax(data, dim=-1, keepdim=True) - - verify_model(Argmax1(), [([256, 256], "float32")]) - verify_model(Argmax2(), [([256, 256], "float32")]) - - -def test_argmin(): - """test relax translator for argmin""" - - class Argmin1(Module): - def forward(self, data): - return torch.argmin(data) - - class Argmin2(Module): - def forward(self, data): - return torch.argmin(data, keepdim=True) - - verify_model(Argmin1(), [([256, 256], "float32")]) - verify_model(Argmin2(), [([256, 256], "float32")]) - - -def test_to(): - """test relax translator for to""" - - class To1(Module): - def forward(self, data): - return data.to(torch.float16) - - class To2(Module): - def forward(self, data): - return data.to("cpu") - - verify_model(To1(), [([256, 256], "float32")]) - verify_model(To2(), [([256, 256], "float32")]) - - -def test_mean(): - """test relax translator for mean""" - - class Mean(Module): - def forward(self, data): - return data.mean(-1) - - class MeanKeepDim(Module): - def forward(self, data): - return data.mean(-1, keepdim=True) - - verify_model(Mean(), [([256, 256], "float32")]) - verify_model(MeanKeepDim(), [([256, 256], "float32")]) - - -def test_rsqrt(): - """test relax translator for rsqrt""" - - class Rsqrt(Module): - def forward(self, data): - return torch.rsqrt(data) - - verify_model(Rsqrt(), [([256, 256], "float32")]) - - -def test_neg(): - """test relax translator for neg""" - - class Neg(Module): - def forward(self, data): - return -data - - verify_model(Neg(), [([256, 256], "float32")]) - - -def test_max(): - """test relax translator for max""" - - class Max(Module): - def forward(self, x, y): - return torch.max(x, y) - - verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) - - -def test_cat(): - """test relax translator for cat""" - - class Cat1(Module): - def forward(self, data, data1, data2): - return torch.cat((data, data1, data2), dim=1) - - class Cat2(Module): - def forward(self, data): - const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - return torch.cat((data, const1, const2), dim=1) - - input_info = [ - ([1, 3, 10, 10], "float32"), - ([1, 3, 10, 10], "float32"), - ([1, 3, 10, 10], "float32"), - ] - verify_model(Cat1(), input_info) - verify_model(Cat2(), [([1, 3, 10, 10], "float32")]) - - -def test_stack(): - """test relax translator for stack""" - - class Stack1(Module): - def forward(self, data, data1, data2): - return torch.stack((data, data1, data2), dim=0) - - class Stack2(Module): - def forward(self, data): - const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - return torch.stack((data, const1, const2), dim=1) - - input_info = [ - ([1, 3, 10, 10], "float32"), - ([1, 3, 10, 10], "float32"), - ([1, 3, 10, 10], "float32"), - ] - verify_model(Stack1(), input_info) - verify_model(Stack2(), [([1, 3, 10, 10], "float32")]) - - -def test_scatter(): - """test relax translator for scatter""" - - class Scatter1(Module): - def __init__(self): - super().__init__() - self.index = msc_utils.random_data([(2, 5), "int64"], MSCFramework.TORCH, max_val=5) - - def forward(self, data, src): - return data.scatter(dim=0, index=self.index, src=src) - - class Scatter2(Module): - def forward(self, data, index, src): - return data.scatter(0, index, src) - - verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")]) - verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")]) - - -def test_masked_scatter(): - """test relax translator for masked_scatter""" - - class MaskedScatter1(Module): - def __init__(self): - super().__init__() - self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH) - - def forward(self, data, src): - return data.masked_scatter(self.mask, src) - - class MaskedScatter2(Module): - def __init__(self): - super().__init__() - self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH) - - def forward(self, data, src): - return data.masked_scatter(self.mask, src) - - verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")]) - verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")]) - - -def test_attention(): - """test relax translator for attention""" - - # pylint: disable=import-outside-toplevel - import torch.nn.functional as F - - class Attention1(Module): - def forward(self, q_data, k_data, v_data): - return F.scaled_dot_product_attention(q_data, k_data, v_data) - - class Attention2(Module): - def forward(self, q_data, k_data, v_data): - return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True) - - input_info = [ - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ] - verify_model(Attention1(), input_info) - verify_model(Attention2(), input_info) - - class Attention3(Module): - def forward(self, q_data, k_data, v_data, mask): - return F.scaled_dot_product_attention(q_data, k_data, v_data, mask) - - verify_model( - Attention3(), - [ - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 128], "float32"), - ], - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py deleted file mode 100644 index 422411fa3576..000000000000 --- a/tests/python/contrib/test_msc/test_translate_tensorrt.py +++ /dev/null @@ -1,918 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Test translate for TensorrRT.""" - -import pytest -import torch -from torch import fx -from torch.nn import Module - -import tvm.testing -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.framework.tensorrt import codegen -from tvm.contrib.msc.framework.tensorrt.frontend import translate -from tvm.relax import PyExprVisitor -from tvm.relax.frontend.torch import from_fx - -requires_tensorrt = pytest.mark.skipif( - tvm.get_global_func("relax.ext.tensorrt", True) is None, - reason="TENSORRT is not enabled", -) - - -def build_and_run(mod, inputs): - """Build and run the virtual machine""" - - target = tvm.target.Target("cuda") - mod = tvm.relax.transform.LegalizeOps()(mod) - with target: - mod = tvm.s_tir.transform.DefaultGPUSchedule()(mod) - with tvm.transform.PassContext(opt_level=3): - rt_mod = tvm.compile(mod, target) - runnable = tvm.relax.VirtualMachine(rt_mod, tvm.cuda()) - res = runnable["main"](*inputs) - if isinstance(res, tvm.runtime.Tensor): - return [res.numpy()] - return [e.numpy() for e in res] - - -def check_names(mod): - """Check the byoc name and unique_name""" - - @tvm.relax.expr_functor.visitor - class NameChecker(PyExprVisitor): - """Checker to check if any non-target ops exist""" - - def check(self, expr): - self._recorded_names = set() - if isinstance(expr, tvm.relax.Expr): - self.visit_expr(expr) - elif isinstance(expr, tvm.relax.BindingBlock): - self.visit_binding_block(expr) - - def visit_function_(self, op: tvm.relax.Function) -> None: - if "Composite" in op.attrs: - assert "Unique" in op.attrs, "Can not find unique_name for func " + str(op) - name = str(op.attrs["Unique"]) - assert name not in self._recorded_names, f"Name {name} is already in use" - self._recorded_names.add(name) - super().visit_function_(op) - - def _is_target_func(func): - if "Codegen" not in func.attrs: - return False - return func.attrs["Codegen"] == "msc_tensorrt" - - for _, func in mod.functions.items(): - if not _is_target_func(func): - continue - assert "Unique" in func.attrs, "Can not find Unique from function attributes" - NameChecker().check(func) - - -def verify_model(torch_model, input_info, **trans_config): - """Build model and verify results""" - - graph_model = fx.symbolic_trace(torch_model) - datas = [msc_utils.random_data(i) for i in input_info] - torch_datas = [torch.from_numpy(i) for i in datas] - with torch.no_grad(): - golden = torch_model(*torch_datas) - mod = from_fx(graph_model, input_info) - if not isinstance(golden, list | tuple): - golden = [golden] - golden = [g.detach().cpu().numpy() for g in golden] - # partition module for tensorrt - mod, graphs, weights = translate.partition_for_tensorrt(mod, trans_config=trans_config) - check_names(mod) - output_folder = msc_utils.msc_dir() - # tranalte to tensorrt - mod = codegen.to_tensorrt(mod, graphs, weights, output_folder=output_folder) - tvm_datas = [tvm.runtime.tensor(i, device=tvm.cuda()) for i in datas] - results = build_and_run(mod, tvm_datas) - for gol, res in zip(golden, results): - tvm.testing.assert_allclose(gol, res, atol=1e-3, rtol=1e-3) - output_folder.destory() - - -@requires_tensorrt -def test_conv1d(): - """test tensorrt translator for conv1d""" - - class Conv1D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - class Conv1D2(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) - - def forward(self, data): - return self.conv(data) - - input_info = [([1, 3, 10], "float32")] - verify_model(Conv1D1(), input_info) - verify_model(Conv1D2(), input_info) - - -@requires_tensorrt -def test_conv2d(): - """test tensorrt translator for conv2d""" - - class Conv2D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - class Conv2D2(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) - - def forward(self, data): - return self.conv(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Conv2D1(), input_info) - verify_model(Conv2D2(), input_info) - - -@requires_tensorrt -def test_linear(): - """test tensorrt translator for linear""" - - class Dense1(Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 7, bias=True) - - def forward(self, data): - return self.linear(data) - - class Dense2(Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 7, bias=False) - - def forward(self, data): - return self.linear(data) - - class MatMul1(Module): - def forward(self, x, y): - return torch.matmul(x, y) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Dense1(), input_info) - verify_model(Dense2(), input_info) - verify_model(Dense1(), input_info, linear_to_conv=True) - verify_model(Dense2(), input_info, linear_to_conv=True) - verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) - - -@requires_tensorrt -def test_bmm(): - """test tensorrt translator for bmm""" - - class BMM(Module): - def forward(self, x, y): - return torch.bmm(x, y) - - input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")] - verify_model(BMM(), input_info) - - -@requires_tensorrt -def test_baddbmm(): - """test tensorrt translator for baddbmm""" - - class BAddBMM1(Module): - def forward(self, c, x, y): - return torch.baddbmm(c, x, y) - - class BAddBMM2(Module): - def forward(self, c, x, y): - return torch.baddbmm(c, x, y, alpha=2, beta=0) - - input_info = [ - ((4, 128, 512), "float32"), - ((4, 128, 256), "float32"), - ((4, 256, 512), "float32"), - ] - verify_model(BAddBMM1(), input_info) - verify_model(BAddBMM2(), input_info) - - -@requires_tensorrt -def test_relu(): - """test tensorrt translator for relu""" - - class ReLU(Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, data): - return self.relu(data) - - input_info = [([10, 10], "float32")] - verify_model(ReLU(), input_info) - - -@requires_tensorrt -def test_relu6(): - """test tensorrt translator for relu6""" - - class ReLU6(Module): - def __init__(self): - super().__init__() - self.relu6 = torch.nn.ReLU6() - - def forward(self, data): - return self.relu6(data) - - input_info = [([10, 10], "float32")] - verify_model(ReLU6(), input_info) - - -@requires_tensorrt -def test_maxpool2d(): - """test tensorrt translator for maxpool2d""" - - class MaxPool2d(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) - - def forward(self, data): - return self.pool(data) - - class MaxPool2d2(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) - - def forward(self, data): - return self.pool(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(MaxPool2d(), input_info) - verify_model(MaxPool2d2(), input_info) - - -@requires_tensorrt -def test_avgpool2d(): - """test tensorrt translator for avgpool2d""" - - class AvgPool2d(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1]) - - def forward(self, data): - return self.pool(data) - - class AvgPool2d2(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True) - - def forward(self, data): - return self.pool(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(AvgPool2d(), input_info) - verify_model(AvgPool2d2(), input_info) - - -@requires_tensorrt -def test_adaptive_avgpool2d(): - """test tensorrt translator for adaptive_avgpool2d""" - - class AdaptiveAvgPool2d0(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) - - def forward(self, data): - return self.pool(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(AdaptiveAvgPool2d0(), input_info) - - -@requires_tensorrt -def test_flatten(): - """test tensorrt translator for flatten""" - - class Flatten(Module): - def __init__(self): - super().__init__() - self.f = torch.nn.Flatten(2, -1) - - def forward(self, data): - return self.f(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Flatten(), input_info) - verify_model(torch.nn.Flatten(2, -1), input_info) - - -@requires_tensorrt -def test_batchnorm2d(): - """test tensorrt translator for batchnorm2d""" - - class BatchNorm2d(Module): - def __init__(self): - super().__init__() - self.batchnorm = torch.nn.BatchNorm2d(3) - - def forward(self, data): - return self.batchnorm(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(BatchNorm2d().eval(), input_info) - - -@requires_tensorrt -def test_embedding(): - """test tensorrt translator for embedding""" - - class Embedding(Module): - def __init__(self): - super().__init__() - self.embedding = torch.nn.Embedding(10, 3) - - def forward(self, data): - return self.embedding(data.to(torch.int64)) - - verify_model(Embedding(), [([4], "int32")]) - verify_model(Embedding(), [([4, 5], "int32")]) - - -@requires_tensorrt -def test_layernorm(): - """test tensorrt translator for layernorm""" - - class LayerNorm(Module): - def __init__(self): - super().__init__() - self.layernorm = torch.nn.LayerNorm((10, 10)) - - def forward(self, data): - return self.layernorm(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(LayerNorm(), input_info) - - -@requires_tensorrt -def test_silu(): - """test tensorrt translator for silu""" - - class SiLU(Module): - def __init__(self): - super().__init__() - self.silu = torch.nn.SiLU() - - def forward(self, data): - return self.silu(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(SiLU(), input_info) - - -@requires_tensorrt -def test_groupnorm(): - """test tensorrt translator for groupnorm""" - - class GroupNorm(Module): - def __init__(self): - super().__init__() - self.groupnorm = torch.nn.GroupNorm(3, 3) - - def forward(self, data): - return self.groupnorm(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(GroupNorm(), input_info) - - -@requires_tensorrt -def test_softmax(): - """test tensorrt translator for softmax""" - - class Softmax(Module): - def __init__(self): - super().__init__() - self.softmax = torch.nn.Softmax(dim=1) - - def forward(self, data): - return self.softmax(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Softmax(), input_info) - - -@requires_tensorrt -def test_binary(): - """test tensorrt translator for binary""" - - input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] - input_info2 = [([1, 3, 10, 10], "float32")] - - # Add - class Add1(Module): - def forward(self, lhs, rhs): - return lhs + rhs - - class Add2(Module): - def forward(self, lhs): - return lhs + 1.0 - - verify_model(Add1(), input_info1) - verify_model(Add2(), input_info2) - - # Sub - class Sub1(Module): - def forward(self, lhs, rhs): - return lhs - rhs - - class Sub2(Module): - def forward(self, lhs): - return lhs - 1.0 - - verify_model(Sub1(), input_info1) - verify_model(Sub2(), input_info2) - - # Mul - class Mul1(Module): - def forward(self, lhs, rhs): - return lhs * rhs - - class Mul2(Module): - def forward(self, lhs): - return lhs * 1.0 - - verify_model(Mul1(), input_info1) - verify_model(Mul2(), input_info2) - - # True div - class TrueDiv1(Module): - def forward(self, lhs, rhs): - return lhs / rhs - - class TrueDiv2(Module): - def forward(self, lhs): - return lhs / 1.0 - - verify_model(TrueDiv1(), input_info1) - verify_model(TrueDiv2(), input_info2) - - # Floor div - class FloorDiv1(Module): - def forward(self, lhs, rhs): - return lhs // rhs - - class FloorDiv2(Module): - def forward(self, lhs): - return lhs // 1.0 - - verify_model(FloorDiv1(), input_info1) - verify_model(FloorDiv2(), input_info2) - - # Power - class Power1(Module): - def forward(self, lhs, rhs): - return lhs**rhs - - class Power2(Module): - def forward(self, lhs): - return lhs**1.0 - - verify_model(Power1(), input_info1) - verify_model(Power2(), input_info2) - - -@requires_tensorrt -def test_squeeze(): - """test tensorrt translator for squeeze""" - - class Squeeze1(Module): - def forward(self, data): - return data.squeeze(1) - - class Squeeze2(Module): - def forward(self, data): - return data.squeeze() - - input_info = [([3, 1, 4, 1], "float32")] - verify_model(Squeeze1(), input_info) - verify_model(Squeeze2(), input_info) - - -@requires_tensorrt -def test_unsqueeze(): - """test tensorrt translator for unsqueeze""" - - class Unsqueeze1(Module): - def forward(self, data): - return data.unsqueeze(1) - - class Unsqueeze2(Module): - def forward(self, data): - return data.unsqueeze(-1) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Unsqueeze1(), input_info) - verify_model(Unsqueeze2(), input_info) - - -@requires_tensorrt -def test_getitem(): - """test tensorrt translator for getitem""" - - class Slice1(Module): - def forward(self, x): - return x[0:1, 1::2, :, :3] - - class Slice2(Module): - def forward(self, x): - return x[:, None, None, :, None] - - verify_model(Slice1(), [([1, 3, 10, 10], "float32")]) - verify_model(Slice2(), [([8, 16], "float32")]) - - -@requires_tensorrt -def test_unary(): - """test tensorrt translator for unary""" - - input_info = [([1, 3, 10, 10], "float32")] - - # sin - class Sin(Module): - def forward(self, data): - return torch.sin(data) - - verify_model(Sin(), input_info) - - # cos - class Cos(Module): - def forward(self, data): - return torch.cos(data) - - verify_model(Cos(), input_info) - - # exp - class Exp(Module): - def forward(self, data): - return torch.exp(data) - - verify_model(Exp(), input_info) - - # sqrt - class Sqrt(Module): - def forward(self, data): - return torch.sqrt(data) - - verify_model(Sqrt(), input_info) - - # sigmoid - class Sigmoid(Module): - def forward(self, data): - return torch.sigmoid(data) - - verify_model(Sigmoid(), input_info) - - # round - class Round(Module): - def forward(self, data): - return torch.round(data) - - verify_model(Round(), input_info) - - -@requires_tensorrt -def test_tanh(): - """test tensorrt translator for tanh""" - - class Tanh(Module): - def forward(self, data): - return torch.tanh(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Tanh(), input_info) - - -@requires_tensorrt -def test_clamp(): - """test tensorrt translator for clamp""" - - class Clamp(Module): - def forward(self, data): - return torch.clamp(data, min=0.1, max=0.5) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Clamp(), input_info) - - -@requires_tensorrt -def test_interpolate(): - """test tensorrt translator for interpolate""" - - class Interpolate(Module): - def forward(self, data): - return torch.nn.functional.interpolate(data, (5, 5)) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Interpolate(), input_info) - - -@requires_tensorrt -def test_addmm(): - """test tensorrt translator for addmm""" - - class Addmm(Module): - def forward(self, x_1, x_2, x_3): - return torch.addmm(x_1, x_2, x_3) - - input_info = [ - ([10, 10], "float32"), - ([10, 10], "float32"), - ([10, 10], "float32"), - ] - verify_model(Addmm(), input_info) - - -@requires_tensorrt -def test_split(): - """test tensorrt translator for split""" - - class Split1(Module): - def forward(self, data): - return torch.split(data, 1, dim=1) - - class Split2(Module): - def forward(self, data): - return torch.split(data, [1, 2], dim=1) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Split1(), input_info) - verify_model(Split2(), input_info) - - -@requires_tensorrt -def test_unbind(): - """test tensorrt to relax for unbind""" - - class Unbind1(Module): - def forward(self, data): - return torch.unbind(data) - - class Unbind2(Module): - def forward(self, data): - return torch.unbind(data, dim=1) - - input_info = [([3, 3, 10, 10], "float32")] - verify_model(Unbind1(), input_info) - verify_model(Unbind2(), input_info) - - -@requires_tensorrt -def test_chunk(): - """test tensorrt translator for chunk""" - - class Chunk(Module): - def forward(self, data): - return torch.chunk(data, 3, dim=1) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Chunk(), input_info) - - -@requires_tensorrt -def test_expand(): - """test tensorrt translator for expand""" - - class Expand1(Module): - def forward(self, x): - x = x + 1.0 - return x.expand(4, 2, 3, 4) - - class Expand2(Module): - def forward(self, x): - x = x + 1.0 - return x.expand(4, -1, -1, 4) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Expand1(), input_info) - verify_model(Expand2(), input_info) - - -@requires_tensorrt -def test_reduce(): - """test tensorrt translator for reduce""" - - # sum - class Sum(Module): - def forward(self, x): - return torch.sum(x, (2, 1)) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Sum(), input_info) - - -@requires_tensorrt -def test_permute(): - """test tensorrt translator for permute""" - - class Permute(Module): - def forward(self, x): - return x.permute(0, 3, 2, 1) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Permute(), input_info) - - -@requires_tensorrt -def test_reshape(): - """test tensorrt translator for reshape""" - - class Reshape(Module): - def forward(self, x): - return x.reshape(2, 12) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Reshape(), input_info) - - -@requires_tensorrt -def test_transpose(): - """test tensorrt translator for transpose""" - - class Transpose(Module): - def forward(self, x): - return x.transpose(1, 3) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Transpose(), input_info) - - -@requires_tensorrt -def test_view(): - """test tensorrt translator for view""" - - class View(Module): - def forward(self, x): - return x.view(2, 12) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(View(), input_info) - - -@requires_tensorrt -def test_argmax(): - """test tensorrt translator for argmax""" - - class Argmax1(Module): - def forward(self, data): - return torch.argmax(data, dim=-1).to(torch.int32) - - class Argmax2(Module): - def forward(self, data): - return torch.argmax(data, dim=-1, keepdim=True).to(torch.int32) - - verify_model(Argmax1(), [([256, 256], "float32")]) - verify_model(Argmax2(), [([256, 256], "float32")]) - - -@requires_tensorrt -def test_argmin(): - """test tensorrt translator for argmin""" - - class Argmin1(Module): - def forward(self, data): - return torch.argmin(data, dim=-1).to(torch.int32) - - class Argmin2(Module): - def forward(self, data): - return torch.argmin(data, dim=-1, keepdim=True).to(torch.int32) - - verify_model(Argmin1(), [([256, 256], "float32")]) - verify_model(Argmin2(), [([256, 256], "float32")]) - - -@requires_tensorrt -def test_mean(): - """test tensorrt translator for mean""" - - class Mean(Module): - def forward(self, data): - return data.mean(-1) - - class MeanKeepDim(Module): - def forward(self, data): - return data.mean(-1, keepdim=True) - - verify_model(Mean(), [([256, 256], "float32")]) - verify_model(MeanKeepDim(), [([256, 256], "float32")]) - - -@requires_tensorrt -def test_rsqrt(): - """test tensorrt translator for rsqrt""" - - class Rsqrt(Module): - def forward(self, data): - return torch.rsqrt(data) - - verify_model(Rsqrt(), [([256, 256], "float32")]) - - -@requires_tensorrt -def test_neg(): - """test tensorrt translator for neg""" - - class Neg(Module): - def forward(self, data): - return -data - - verify_model(Neg(), [([256, 256], "float32")]) - - -@requires_tensorrt -def test_max(): - """test tensorrt translator for max""" - - class Max(Module): - def forward(self, x, y): - return torch.max(x, y) - - verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) - - -@requires_tensorrt -def test_gelu(): - """test tensorrt translator for gelu""" - - class Gelu1(Module): - def forward(self, data): - return torch.nn.functional.gelu(data) - - class Gelu2(Module): - def forward(self, data): - return torch.nn.functional.gelu(data, approximate="tanh") - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Gelu1(), input_info) - verify_model(Gelu2(), input_info) - - -@requires_tensorrt -def test_cat(): - """test tensorrt translator for cat""" - - class Cat1(Module): - def forward(self, data, data1, data2): - return torch.cat((data, data1, data2), dim=1) - - class Cat2(Module): - def forward(self, data): - const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - return torch.cat((data, const1, const2), dim=1) - - input_info = [ - ([1, 3, 10, 10], "float32"), - ([1, 3, 10, 10], "float32"), - ([1, 3, 10, 10], "float32"), - ] - verify_model(Cat1(), input_info) - verify_model(Cat2(), [([1, 3, 10, 10], "float32")]) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py deleted file mode 100644 index 584ca414fbf7..000000000000 --- a/tests/python/contrib/test_msc/test_translate_torch.py +++ /dev/null @@ -1,1154 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Test translate from torch.""" - -import torch -from torch.nn import Module - -import tvm.testing -from tvm.contrib.msc.core import utils as msc_utils -from tvm.contrib.msc.core.utils.namespace import MSCFramework -from tvm.contrib.msc.framework.torch import codegen -from tvm.contrib.msc.framework.torch.frontend import translate - - -def verify_model(torch_model, input_info): - """Compare torch module results""" - - torch_datas = [msc_utils.random_data(i, MSCFramework.TORCH) for i in input_info] - with torch.no_grad(): - golden = torch_model(*torch_datas) - graph, weights = translate.from_torch(torch_model, input_info) - model = codegen.to_torch(graph, weights) - with torch.no_grad(): - if not graph.get_inputs(): - result = model() - else: - result = model(*torch_datas) - if not isinstance(golden, list | tuple): - golden = [golden] - if not isinstance(result, list | tuple): - result = [result] - assert len(golden) == len(result), f"golden {len(golden)} mismatch with result {len(result)}" - for gol_r, new_r in zip(golden, result): - if isinstance(gol_r, torch.Tensor): - tvm.testing.assert_allclose( - gol_r.detach().numpy(), new_r.detach().numpy(), atol=1e-5, rtol=1e-5 - ) - else: - assert gol_r == new_r - - -def test_conv1d(): - """test torch translator for conv1d""" - - class Conv1D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - class Conv1D2(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) - - def forward(self, data): - return self.conv(data) - - input_info = [([1, 3, 10], "float32")] - verify_model(Conv1D1(), input_info) - verify_model(Conv1D2(), input_info) - - -def test_conv2d(): - """test torch translator for conv2d""" - - class Conv2D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - class Conv2D2(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) - - def forward(self, data): - return self.conv(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Conv2D1(), input_info) - verify_model(Conv2D2(), input_info) - - -def test_linear(): - """test torch translator for linear""" - - class Dense1(Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 7, bias=True) - - def forward(self, data): - return self.linear(data) - - class Dense2(Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 7, bias=False) - - def forward(self, data): - return self.linear(data) - - class MatMul1(Module): - def forward(self, x, y): - return torch.matmul(x, y) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Dense1(), input_info) - verify_model(Dense2(), input_info) - verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) - - -def test_bmm(): - """test torch translator for bmm""" - - class BMM(Module): - def forward(self, x, y): - return torch.bmm(x, y) - - input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")] - verify_model(BMM(), input_info) - - -def test_baddbmm(): - """test torch translator for baddbmm""" - - class BAddBMM1(Module): - def forward(self, c, x, y): - return torch.baddbmm(c, x, y) - - class BAddBMM2(Module): - def forward(self, c, x, y): - return torch.baddbmm(c, x, y, alpha=2, beta=0) - - input_info = [ - ((4, 128, 512), "float32"), - ((4, 128, 256), "float32"), - ((4, 256, 512), "float32"), - ] - verify_model(BAddBMM1(), input_info) - verify_model(BAddBMM2(), input_info) - - -def test_relu(): - """test torch translator for relu""" - - class ReLU(Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, data): - return self.relu(data) - - class ReLU1(Module): - def forward(self, data): - return torch.nn.functional.relu(data) - - input_info = [([10, 10], "float32")] - verify_model(ReLU(), input_info) - verify_model(ReLU1(), input_info) - - -def test_relu6(): - """test torch translator for relu6""" - - class ReLU6(Module): - def __init__(self): - super().__init__() - self.relu6 = torch.nn.ReLU6() - - def forward(self, data): - return self.relu6(data) - - input_info = [([10, 10], "float32")] - verify_model(ReLU6(), input_info) - - -def test_maxpool2d(): - """test torch translator for maxpool2d""" - - class MaxPool2d(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) - - def forward(self, data): - return self.pool(data) - - class MaxPool2d2(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]) - - def forward(self, data): - return self.pool(data) - - class MaxPool2d3(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) - - def forward(self, data): - return self.pool(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(MaxPool2d(), input_info) - verify_model(MaxPool2d2(), input_info) - verify_model(MaxPool2d3(), input_info) - - -def test_avgpool2d(): - """test torch translator for avgpool2d""" - - class AvgPool2d(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1]) - - def forward(self, data): - return self.pool(data) - - class AvgPool2d2(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True) - - def forward(self, data): - return self.pool(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(AvgPool2d(), input_info) - verify_model(AvgPool2d2(), input_info) - - -def test_adaptive_avgpool2d(): - """test torch translator for adaptive_avgpool2d""" - - class AdaptiveAvgPool2d0(Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) - - def forward(self, data): - return self.pool(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(AdaptiveAvgPool2d0(), input_info) - - -def test_flatten(): - """test torch translator for flatten""" - - class Flatten(Module): - def __init__(self): - super().__init__() - self.f = torch.nn.Flatten(2, -1) - - def forward(self, data): - return self.f(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Flatten(), input_info) - verify_model(torch.nn.Flatten(2, -1), input_info) - - -def test_batchnorm2d(): - """test torch translator for batchnorm2d""" - - class BatchNorm2d(Module): - def __init__(self): - super().__init__() - self.batchnorm = torch.nn.BatchNorm2d(3) - - def forward(self, data): - return self.batchnorm(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(BatchNorm2d(), input_info) - - -def test_embedding(): - """test torch translator for embedding""" - - class Embedding(Module): - def __init__(self): - super().__init__() - self.embedding = torch.nn.Embedding(10, 3) - - def forward(self, data): - return self.embedding(data) - - verify_model(Embedding(), [([4], "int64")]) - verify_model(Embedding(), [([4, 5], "int64")]) - - -def test_layernorm(): - """test torch translator for layernorm""" - - class LayerNorm(Module): - def __init__(self): - super().__init__() - self.layernorm = torch.nn.LayerNorm((10, 10)) - - def forward(self, data): - return self.layernorm(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(LayerNorm(), input_info) - - -def test_cross_entropy(): - """test torch translator for cross_entropy""" - - class CrossEntropy1(Module): - def __init__(self): - super().__init__() - self.loss = torch.nn.CrossEntropyLoss() - - def forward(self, logits, targets): - return self.loss(logits, targets) - - class CrossEntropy2(Module): - def __init__(self): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones((2,))) - self.loss = torch.nn.CrossEntropyLoss(weight=self.weight) - - def forward(self, logits, targets): - return self.loss(logits, targets) - - class CrossEntropy3(Module): - def __init__(self): - super().__init__() - self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum") - - def forward(self, logits, targets): - return self.loss(logits, targets) - - input_info = [([3, 2], "float32"), ([3], "int64")] - verify_model(CrossEntropy1(), input_info) - verify_model(CrossEntropy2(), input_info) - verify_model(CrossEntropy3(), input_info) - - -def test_silu(): - """test torch translator for silu""" - - class SiLU(Module): - def __init__(self): - super().__init__() - self.silu = torch.nn.SiLU() - - def forward(self, data): - return self.silu(data) - - class SiLU2(Module): - def forward(self, data): - return torch.nn.functional.silu(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(SiLU(), input_info) - verify_model(SiLU2(), input_info) - - -def test_groupnorm(): - """test torch translator for groupnorm""" - - class GroupNorm(Module): - def __init__(self): - super().__init__() - self.groupnorm = torch.nn.GroupNorm(3, 3) - - def forward(self, data): - return self.groupnorm(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(GroupNorm(), input_info) - - -def test_softmax(): - """test torch translator for softmax""" - - class Softmax(Module): - def __init__(self): - super().__init__() - self.softmax = torch.nn.Softmax(dim=1) - - def forward(self, data): - return self.softmax(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Softmax(), input_info) - - -def test_binary(): - """test torch translator for binary""" - - input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] - input_info2 = [([1, 3, 10, 10], "float32")] - - # Add - class Add1(Module): - def forward(self, lhs, rhs): - return lhs + rhs - - class Add2(Module): - def forward(self, lhs): - return lhs + 1.0 - - verify_model(Add1(), input_info1) - verify_model(Add2(), input_info2) - - # Sub - class Sub1(Module): - def forward(self, lhs, rhs): - return lhs - rhs - - class Sub2(Module): - def forward(self, lhs): - return lhs - 1.0 - - verify_model(Sub1(), input_info1) - verify_model(Sub2(), input_info2) - - # Mul - class Mul1(Module): - def forward(self, lhs, rhs): - return lhs * rhs - - class Mul2(Module): - def forward(self, lhs): - return lhs * 1.0 - - verify_model(Mul1(), input_info1) - verify_model(Mul2(), input_info2) - - # True div - class TrueDiv1(Module): - def forward(self, lhs, rhs): - return lhs / rhs - - class TrueDiv2(Module): - def forward(self, lhs): - return lhs / 1.0 - - verify_model(TrueDiv1(), input_info1) - verify_model(TrueDiv2(), input_info2) - - # Floor div - class FloorDiv1(Module): - def forward(self, lhs, rhs): - return lhs // rhs - - class FloorDiv2(Module): - def forward(self, lhs): - return lhs // 1.0 - - verify_model(FloorDiv1(), input_info1) - verify_model(FloorDiv2(), input_info2) - - # Power - class Power1(Module): - def forward(self, lhs, rhs): - return lhs**rhs - - class Power2(Module): - def forward(self, lhs): - return lhs**1.0 - - verify_model(Power1(), input_info1) - verify_model(Power2(), input_info2) - - # LT - class LT1(Module): - def forward(self, lhs, rhs): - return lhs < rhs - - class LT2(Module): - def forward(self, lhs): - return lhs < 1.0 - - verify_model(LT1(), input_info1) - verify_model(LT2(), input_info2) - - -def test_size(): - """test torch translator for size""" - - class Size(Module): - def forward(self, data): - return data.size() - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Size(), input_info) - - -def test_squeeze(): - """test torch translator for squeeze""" - - class Squeeze1(Module): - def forward(self, data): - return data.squeeze(1) - - class Squeeze2(Module): - def forward(self, data): - return data.squeeze() - - input_info = [([3, 1, 4, 1], "float32")] - verify_model(Squeeze1(), input_info) - verify_model(Squeeze2(), input_info) - - -def test_unsqueeze(): - """test torch translator for unsqueeze""" - - class Unsqueeze1(Module): - def forward(self, data): - return data.unsqueeze(1) - - class Unsqueeze2(Module): - def forward(self, data): - return data.unsqueeze(-1) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Unsqueeze1(), input_info) - verify_model(Unsqueeze2(), input_info) - - -def test_getattr(): - """test torch translator for getattr""" - - class GetAttr1(Module): - def forward(self, data): - return data.shape - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(GetAttr1(), input_info) - - -def test_getitem(): - """test torch translator for getitem""" - - # TODO(tong.meng): strided_slice reshape bug for x[0, 1::2, :, :3] - class Slice1(Module): - def forward(self, x): - return x[0:1, 1::2, :, :3] - - class Slice2(Module): - def forward(self, x): - return x[:, None, None, :, None] - - verify_model(Slice1(), [([1, 3, 10, 10], "float32")]) - verify_model(Slice2(), [([8, 16], "float32")]) - - -def test_unary(): - """test torch translator for unary""" - - input_info = [([1, 3, 10, 10], "float32")] - - # sin - class Sin(Module): - def forward(self, data): - return torch.sin(data) - - verify_model(Sin(), input_info) - - # cos - class Cos(Module): - def forward(self, data): - return torch.cos(data) - - verify_model(Cos(), input_info) - - # exp - class Exp(Module): - def forward(self, data): - return torch.exp(data) - - verify_model(Exp(), input_info) - - # sqrt - class Sqrt(Module): - def forward(self, data): - return torch.sqrt(data) - - verify_model(Sqrt(), input_info) - - # sigmoid - class Sigmoid(Module): - def forward(self, data): - return torch.sigmoid(data) - - verify_model(Sigmoid(), input_info) - - # round - class Round(Module): - def forward(self, data): - return torch.round(data) - - verify_model(Round(), input_info) - - -def test_gelu(): - """test torch translator for gelu""" - - class Gelu(Module): - def forward(self, data): - return torch.nn.functional.gelu(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Gelu(), input_info) - - -def test_tanh(): - """test torch translator for tanh""" - - class Tanh(Module): - def forward(self, data): - return torch.tanh(data) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Tanh(), input_info) - - -def test_clamp(): - """test torch translator for clamp""" - - class Clamp(Module): - def forward(self, data): - return torch.clamp(data, min=0.1, max=0.5) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Clamp(), input_info) - - -def test_interpolate(): - """test torch translator for interpolate""" - - class Interpolate(Module): - def forward(self, data): - return torch.nn.functional.interpolate(data, (5, 5)) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Interpolate(), input_info) - - -def test_addmm(): - """test torch translator for addmm""" - - class Addmm(Module): - def forward(self, x_1, x_2, x_3): - return torch.addmm(x_1, x_2, x_3) - - input_info = [ - ([10, 10], "float32"), - ([10, 10], "float32"), - ([10, 10], "float32"), - ] - verify_model(Addmm(), input_info) - - -def test_split(): - """test torch translator for split""" - - class Split1(Module): - def forward(self, data): - return torch.split(data, 1, dim=1) - - class Split2(Module): - def forward(self, data): - return torch.split(data, [1, 2], dim=1) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Split1(), input_info) - verify_model(Split2(), input_info) - - -def test_unbind(): - """test torch translator for unbind""" - - class Unbind1(Module): - def forward(self, data): - return torch.unbind(data) - - class Unbind2(Module): - def forward(self, data): - return torch.unbind(data, dim=1) - - input_info = [([3, 3, 10, 10], "float32")] - verify_model(Unbind1(), input_info) - verify_model(Unbind2(), input_info) - - -def test_cumsum(): - """test torch translator for cumsum""" - - class Cumsum(Module): - def forward(self, data): - return torch.cumsum(data, dim=1, dtype=torch.int32) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Cumsum(), input_info) - - -def test_chunk(): - """test torch translator for chunk""" - - class Chunk(Module): - def forward(self, data): - return torch.chunk(data, 3, dim=1) - - input_info = [([1, 3, 10, 10], "float32")] - verify_model(Chunk(), input_info) - - -def test_inplace_fill(): - """test torch translator for inplace_fill""" - - class InplaceFill(Module): - def forward(self, data): - data.fill_(1.5) - return data - - verify_model(InplaceFill(), [([10, 10], "float32")]) - - -def test_arange(): - """test torch translator for arange""" - - # pylint: disable=unused-argument - class Arange(Module): - def forward(self, data): - return torch.arange(0, 20, dtype=torch.int32) - - verify_model(Arange(), [([10, 10], "float32")]) - - -def test_tril(): - """test torch translator for tril""" - - class Tril(Module): - def forward(self, data): - return torch.tril(data, 1) - - class InplaceTril(Module): - def forward(self, data): - data.tril_(1) - return data - - input_info = [([10, 10], "float32")] - verify_model(Tril(), input_info) - verify_model(InplaceTril(), input_info) - - -def test_triu(): - """test torch translator for triu""" - - class Triu(Module): - def forward(self, data): - return torch.triu(data, 1) - - class InplaceTriu(Module): - def forward(self, data): - data.triu_(1) - return data - - input_info = [([10, 10], "float32")] - verify_model(Triu(), input_info) - verify_model(InplaceTriu(), input_info) - - -def test_new_ones(): - """test torch translator for new_ones""" - - class NewOnes(Module): - def forward(self, x): - return x.new_ones(1, 2, 3) - - input_info = [([1, 2, 3], "float32")] - verify_model(NewOnes(), input_info) - - -def test_expand(): - """test torch translator for expand""" - - class Expand1(Module): - def forward(self, x): - return x.expand(4, 2, 3, 4) - - class Expand2(Module): - def forward(self, x): - return x.expand(4, -1, -1, 4) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Expand1(), input_info) - verify_model(Expand2(), input_info) - - -def test_reduce(): - """test torch translator for reduce""" - - # sum - class Sum(Module): - def forward(self, x): - return torch.sum(x, (2, 1)) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Sum(), input_info) - - -def test_datatype(): - """test torch translator for datatype""" - - input_info = [([1, 2, 3, 4], "float32")] - - # float - class ToFloat(Module): - def forward(self, x): - return x.float() - - verify_model(ToFloat(), input_info) - - # half - class ToHalf(Module): - def forward(self, x): - return x.half() - - verify_model(ToHalf(), input_info) - - # type - class Type(Module): - def forward(self, x): - return x.type(torch.float32) - - verify_model(Type(), input_info) - - -def test_permute(): - """test torch translator for permute""" - - class Permute(Module): - def forward(self, x): - return x.permute(0, 3, 2, 1) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Permute(), input_info) - - -def test_reshape(): - """test torch translator for reshape""" - - class Reshape(Module): - def forward(self, x): - return x.reshape(2, 12) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Reshape(), input_info) - - -def test_transpose(): - """test torch translator for transpose""" - - class Transpose(Module): - def forward(self, x): - return x.transpose(1, 3) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(Transpose(), input_info) - - -def test_view(): - """test torch translator for view""" - - class View(Module): - def forward(self, x): - return x.view(2, 12) - - input_info = [([1, 2, 3, 4], "float32")] - verify_model(View(), input_info) - - -def test_keep_params(): - """test torch translator for keep_params""" - - class Conv2D1(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) - - def forward(self, data): - return self.conv(data) - - verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")]) - - -def test_unwrap_unit_return_tuple(): - """test torch translator for unwrap_unit_return_tuple""" - - class Identity(Module): - def forward(self, x): - return (x,) - - verify_model(Identity(), [([256, 256], "float32")]) - - -def test_no_bind_return_tuple(): - """test torch translator for no_bind_return_tuple""" - - class Identity(Module): - def forward(self, x, y): - return (x, y) - - input_info = [([256, 256], "float32"), ([256, 256], "float32")] - verify_model(Identity(), input_info) - - -def test_argmax(): - """test torch translator for argmax""" - - class Argmax1(Module): - def forward(self, data): - return torch.argmax(data, dim=-1) - - class Argmax2(Module): - def forward(self, data): - return torch.argmax(data, dim=-1, keepdim=True) - - verify_model(Argmax1(), [([256, 256], "float32")]) - verify_model(Argmax2(), [([256, 256], "float32")]) - - -def test_argmin(): - """test torch translator for argmin""" - - class Argmin1(Module): - def forward(self, data): - return torch.argmin(data) - - class Argmin2(Module): - def forward(self, data): - return torch.argmin(data, keepdim=True) - - verify_model(Argmin1(), [([256, 256], "float32")]) - verify_model(Argmin2(), [([256, 256], "float32")]) - - -def test_to(): - """test torch translator for to""" - - class To1(Module): - def forward(self, data): - return data.to(torch.float16) - - class To2(Module): - def forward(self, data): - return data.to("cpu") - - verify_model(To1(), [([256, 256], "float32")]) - verify_model(To2(), [([256, 256], "float32")]) - - -def test_mean(): - """test torch translator for mean""" - - class Mean(Module): - def forward(self, data): - return data.mean(-1) - - class MeanKeepDim(Module): - def forward(self, data): - return data.mean(-1, keepdim=True) - - verify_model(Mean(), [([256, 256], "float32")]) - verify_model(MeanKeepDim(), [([256, 256], "float32")]) - - -def test_rsqrt(): - """test torch translator for rsqrt""" - - class Rsqrt(Module): - def forward(self, data): - return torch.rsqrt(data) - - verify_model(Rsqrt(), [([256, 256], "float32")]) - - -def test_neg(): - """test torch translator for neg""" - - class Neg(Module): - def forward(self, data): - return -data - - verify_model(Neg(), [([256, 256], "float32")]) - - -def test_max(): - """test torch translator for max""" - - class Max(Module): - def forward(self, x, y): - return torch.max(x, y) - - verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) - - -def test_cat(): - """test torch translator for cat""" - - class Cat1(Module): - def forward(self, data, data1, data2): - return torch.cat((data, data1, data2), dim=1) - - class Cat2(Module): - def forward(self, data): - const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - return torch.cat((data, const1, const2), dim=1) - - input_info = [ - ([1, 3, 10, 10], "float32"), - ([1, 3, 10, 10], "float32"), - ([1, 3, 10, 10], "float32"), - ] - verify_model(Cat1(), input_info) - verify_model(Cat2(), [([1, 3, 10, 10], "float32")]) - - -def test_stack(): - """test torch translator for stack""" - - class Stack1(Module): - def forward(self, data, data1, data2): - return torch.stack((data, data1, data2), dim=0) - - class Stack2(Module): - def forward(self, data): - const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32) - return torch.stack((data, const1, const2), dim=1) - - input_info = [ - ([1, 3, 10, 10], "float32"), - ([1, 3, 10, 10], "float32"), - ([1, 3, 10, 10], "float32"), - ] - verify_model(Stack1(), input_info) - verify_model(Stack2(), [([1, 3, 10, 10], "float32")]) - - -def test_scatter(): - """test torch translator for scatter""" - - class Scatter1(Module): - def __init__(self): - super().__init__() - self.index = msc_utils.random_data([(2, 5), "int64"], MSCFramework.TORCH, max_val=5) - - def forward(self, data, src): - return data.scatter(dim=0, index=self.index, src=src) - - class Scatter2(Module): - def forward(self, data, index, src): - return data.scatter(0, index, src) - - verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")]) - verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")]) - - -def test_masked_scatter(): - """test torch translator for masked_scatter""" - - class MaskedScatter1(Module): - def __init__(self): - super().__init__() - self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH) - - def forward(self, data, src): - return data.masked_scatter(self.mask, src) - - class MaskedScatter2(Module): - def __init__(self): - super().__init__() - self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH) - - def forward(self, data, src): - return data.masked_scatter(self.mask, src) - - verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")]) - verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")]) - - -def test_attention(): - """test torch translator for attention""" - - # pylint: disable=import-outside-toplevel - import torch.nn.functional as F - - class Attention1(Module): - def forward(self, q_data, k_data, v_data): - return F.scaled_dot_product_attention(q_data, k_data, v_data) - - class Attention2(Module): - def forward(self, q_data, k_data, v_data): - return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True) - - input_info = [ - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ] - verify_model(Attention1(), input_info) - verify_model(Attention2(), input_info) - - class Attention3(Module): - def forward(self, q_data, k_data, v_data, mask): - return F.scaled_dot_product_attention(q_data, k_data, v_data, mask) - - verify_model( - Attention3(), - [ - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 128], "float32"), - ], - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 0678e8187776..c1ebd23f3e1b 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -43,4 +43,3 @@ echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake echo set\(SUMMARIZE ON\) >> config.cmake -echo set\(USE_MSC ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index dd35f7645379..a16847023a61 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -39,6 +39,5 @@ echo set\(USE_CCACHE OFF\) >> config.cmake echo set\(SUMMARIZE ON\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake echo set\(USE_CUTLASS ON\) >> config.cmake -echo set\(USE_MSC ON\) >> config.cmake echo set\(CMAKE_CUDA_ARCHITECTURES 75\) >> config.cmake echo set\(USE_CLML ON\) >> config.cmake diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 7fb79593c0ac..bfa955ff552d 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -36,8 +36,5 @@ find . -type f -path "*.pyc" | xargs rm -f # setup tvm-ffi into python folder python3 -m pip install -v --target=python ./3rdparty/tvm-ffi/ -# Test for MSC -pytest tests/python/contrib/test_msc - # Test for OpenCLML pytest tests/python/relax/backend/clml/